Unverified Commit 842aba50 authored by dtc's avatar dtc Committed by GitHub
Browse files

[P/D] Introduce Mooncake Transfer Engine as kv_connector (#24718)


Signed-off-by: default avatarTianchen Ding <dtcccc@linux.alibaba.com>
Signed-off-by: default avatardtc <dtcccc@linux.alibaba.com>
Co-authored-by: default avatarNicolò Lucchesi <nicolo.lucchesi@gmail.com>
parent f2f4cea6
# MooncakeConnector Usage Guide
## About Mooncake
Mooncake aims to enhance the inference efficiency of large language models (LLMs), especially in slow object storage environments, by constructing a multi-level caching pool on high-speed interconnected DRAM/SSD resources. Compared to traditional caching systems, Mooncake utilizes (GPUDirect) RDMA technology to transfer data directly in a zero-copy manner, while maximizing the use of multi-NIC resources on a single machine.
For more details about Mooncake, please refer to [Mooncake project](https://github.com/kvcache-ai/Mooncake) and [Mooncake documents](https://kvcache-ai.github.io/Mooncake/).
## Prerequisites
### Installation
Install mooncake through pip: `uv pip install mooncake-transfer-engine`.
Refer to [Mooncake official repository](https://github.com/kvcache-ai/Mooncake) for more installation instructions
## Usage
### Prefiller Node (192.168.0.2)
```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8010 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_producer"}'
```
### Decoder Node (192.168.0.3)
```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8020 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}'
```
### Proxy
```bash
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --prefiller-host 192.168.0.2 --prefiller-port 8010 --decoder-host 192.168.0.3 --decoder-port 8020
```
> NOTE: The Mooncake Connector currently uses the proxy from nixl_integration. This will be replaced with a self-developed proxy in the future.
Now you can send requests to the proxy server through port 8000.
## Environment Variables
- `VLLM_MOONCAKE_BOOTSTRAP_PORT`: Port for Mooncake bootstrap server
- Default: 8998
- Required only for prefiller instances
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank
- Used for the decoder notifying the prefiller
- `VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional)
- Default: 480
- If a request is aborted and the decoder has not yet notified the prefiller, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.
## KV Role Options
- **kv_producer**: For prefiller instances that generate KV caches
- **kv_consumer**: For decoder instances that consume KV caches from prefiller
- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
...@@ -190,3 +190,8 @@ KVConnectorFactory.register_connector( ...@@ -190,3 +190,8 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector", "vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
"DecodeBenchConnector", "DecodeBenchConnector",
) )
KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector",
"MooncakeConnector",
)
...@@ -4,10 +4,13 @@ ...@@ -4,10 +4,13 @@
KV cache helper for store. KV cache helper for store.
""" """
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -181,3 +184,124 @@ def copy_kv_blocks( ...@@ -181,3 +184,124 @@ def copy_kv_blocks(
src_tensor = src_kv_caches[layer_name] src_tensor = src_kv_caches[layer_name]
dst_tensor = dst_kv_caches[layer_name] dst_tensor = dst_kv_caches[layer_name]
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
@dataclass
class TpKVTopology:
"""
Helper class for tensor parallel and KV topology information for
mapping between local and remote TP workers.
"""
tp_rank: int
remote_tp_size: dict[str, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
engine_id: str
remote_block_size: dict[str, int]
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below.
self._is_kv_layout_blocks_first = (
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)
@property
def tp_size(self) -> int:
return self.remote_tp_size[self.engine_id]
@property
def block_size(self) -> int:
return self.remote_block_size[self.engine_id]
def tp_ratio(
self,
remote_tp_size: int,
) -> int:
"""
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
"""
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
def block_size_ratio(
self,
remote_block_size: int,
) -> float:
"""
Calculate the block size ratio between local and remote TP.
"""
assert self.block_size % remote_block_size == 0, (
f"Local block size {self.block_size} is not divisible "
f"by remote block size {remote_block_size} or vice versa."
)
return self.block_size // remote_block_size
def tp_ratio_from_engine_id(
self,
remote_engine_id: str,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: str,
) -> float:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: str) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
"""
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: str) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank(
self,
remote_tp_size: int,
) -> int:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio
def get_target_remote_rank_from_engine_id(
self,
remote_engine_id: str,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size)
...@@ -20,10 +20,10 @@ import torch ...@@ -20,10 +20,10 @@ import torch
import zmq import zmq
from vllm import envs from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp, CopyBlocksOp,
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -668,128 +668,6 @@ class NixlConnectorScheduler: ...@@ -668,128 +668,6 @@ class NixlConnectorScheduler:
class NixlConnectorWorker: class NixlConnectorWorker:
"""Implementation of Worker side methods""" """Implementation of Worker side methods"""
@dataclass
class TpKVTopology:
"""
Helper class for tensor parallel and KV topology information for
mapping between local and remote TP workers.
"""
tp_rank: int
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below.
self._is_kv_layout_blocks_first = (
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (
self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
)
@property
def tp_size(self) -> int:
return self.remote_tp_size[self.engine_id]
@property
def block_size(self) -> int:
return self.remote_block_size[self.engine_id]
def tp_ratio(
self,
remote_tp_size: int,
) -> int:
"""
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
"""
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
def block_size_ratio(
self,
remote_block_size: int,
) -> float:
"""
Calculate the block size ratio between local and remote TP.
"""
assert self.block_size % remote_block_size == 0, (
f"Local block size {self.block_size} is not divisible "
f"by remote block size {remote_block_size} or vice versa."
)
return self.block_size // remote_block_size
def tp_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
) -> float:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
"""
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank(
self,
remote_tp_size: int,
) -> int:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio
def get_target_remote_rank_from_engine_id(
self,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size)
def __init__(self, vllm_config: VllmConfig, engine_id: str): def __init__(self, vllm_config: VllmConfig, engine_id: str):
if NixlWrapper is None: if NixlWrapper is None:
logger.error("NIXL is not available") logger.error("NIXL is not available")
...@@ -958,7 +836,7 @@ class NixlConnectorWorker: ...@@ -958,7 +836,7 @@ class NixlConnectorWorker:
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats() self.xfer_stats = NixlKVConnectorStats()
self.kv_topo = self.TpKVTopology( self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
engine_id=self.engine_id, engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state remote_tp_size=self._tp_size, # shared state
......
...@@ -175,6 +175,7 @@ if TYPE_CHECKING: ...@@ -175,6 +175,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600 VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600
VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998
VLLM_ALL2ALL_BACKEND: Literal[ VLLM_ALL2ALL_BACKEND: Literal[
"naive", "naive",
"pplx", "pplx",
...@@ -197,6 +198,7 @@ if TYPE_CHECKING: ...@@ -197,6 +198,7 @@ if TYPE_CHECKING:
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_CUDNN_PREFILL: bool = False
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False
...@@ -1260,6 +1262,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1260,6 +1262,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int( "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int(
os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600") os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600")
), ),
# Port used for Mooncake handshake between remote agents.
"VLLM_MOONCAKE_BOOTSTRAP_PORT": lambda: int(
os.getenv("VLLM_MOONCAKE_BOOTSTRAP_PORT", "8998")
),
# all2all backend for vllm's expert parallel communication # all2all backend for vllm's expert parallel communication
# Available options: # Available options:
# - "naive": naive all2all implementation using broadcasts # - "naive": naive all2all implementation using broadcasts
...@@ -1369,6 +1375,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1369,6 +1375,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int(
os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480")
), ),
# Timeout (in seconds) for MooncakeConnector in PD disaggregated setup.
"VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int(
os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480")
),
# Controls whether or not to use cudnn prefill # Controls whether or not to use cudnn prefill
"VLLM_USE_CUDNN_PREFILL": lambda: bool( "VLLM_USE_CUDNN_PREFILL": lambda: bool(
int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")) int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment