Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
...@@ -32,7 +32,7 @@ class MMMeta: ...@@ -32,7 +32,7 @@ class MMMeta:
@dataclass @dataclass
class ECSharedStorageConnectorMetadata(ECConnectorMetadata): class ECExampleConnectorMetadata(ECConnectorMetadata):
mm_datas: list[MMMeta] mm_datas: list[MMMeta]
def __init__(self): def __init__(self):
...@@ -42,7 +42,7 @@ class ECSharedStorageConnectorMetadata(ECConnectorMetadata): ...@@ -42,7 +42,7 @@ class ECSharedStorageConnectorMetadata(ECConnectorMetadata):
self.mm_datas.append(mm_data) self.mm_datas.append(mm_data)
class ECSharedStorageConnector(ECConnectorBase): class ECExampleConnector(ECConnectorBase):
# NOTE: This is Simple debug implementation of the EC connector. # NOTE: This is Simple debug implementation of the EC connector.
# It save / load the EC cache to / from the disk. # It save / load the EC cache to / from the disk.
...@@ -76,7 +76,7 @@ class ECSharedStorageConnector(ECConnectorBase): ...@@ -76,7 +76,7 @@ class ECSharedStorageConnector(ECConnectorBase):
# Get the metadata # Get the metadata
metadata: ECConnectorMetadata = self._get_connector_metadata() metadata: ECConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, ECSharedStorageConnectorMetadata) assert isinstance(metadata, ECExampleConnectorMetadata)
assert encoder_cache is not None assert encoder_cache is not None
if metadata is None: if metadata is None:
logger.warning( logger.warning(
...@@ -160,7 +160,7 @@ class ECSharedStorageConnector(ECConnectorBase): ...@@ -160,7 +160,7 @@ class ECSharedStorageConnector(ECConnectorBase):
Args: Args:
scheduler_output (SchedulerOutput): the scheduler output object. scheduler_output (SchedulerOutput): the scheduler output object.
""" """
meta = ECSharedStorageConnectorMetadata() meta = ECExampleConnectorMetadata()
for mm_hash, num_encoder_token in self._mm_datas_need_loads.items(): for mm_hash, num_encoder_token in self._mm_datas_need_loads.items():
meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token)) meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
self._mm_datas_need_loads.clear() self._mm_datas_need_loads.clear()
......
...@@ -79,7 +79,7 @@ class ECConnectorFactory: ...@@ -79,7 +79,7 @@ class ECConnectorFactory:
# only load the files corresponding to the current connector. # only load the files corresponding to the current connector.
ECConnectorFactory.register_connector( ECConnectorFactory.register_connector(
"ECSharedStorageConnector", "ECExampleConnector",
"vllm.distributed.ec_transfer.ec_connector.shared_storage_connector", "vllm.distributed.ec_transfer.ec_connector.example_connector",
"ECSharedStorageConnector", "ECExampleConnector",
) )
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" """Expert parallelism load balancer (EPLB)."""
Expert parallelism load balancer (EPLB).
"""
from .eplb_state import *
from .rebalance_algo import *
...@@ -45,7 +45,7 @@ from vllm.logger import init_logger ...@@ -45,7 +45,7 @@ from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts from vllm.model_executor.models.interfaces import MixtureOfExperts
from .async_worker import start_async_worker from .async_worker import start_async_worker
from .rebalance_algo import rebalance_experts from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -213,18 +213,23 @@ class EplbState: ...@@ -213,18 +213,23 @@ class EplbState:
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.device = device self.device = device
self.model_states: dict[str, EplbModelState] = {} self.model_states: dict[str, EplbModelState] = {}
self.policy: type[AbstractEplbPolicy] = DefaultEplbPolicy
"""
Selected EPLB algorithm class
"""
self.expert_load_window_step: int = 0
""" """
Current step in the sliding window. Current step in the sliding window.
Different from `expert_rearrangement_step`, Different from `expert_rearrangement_step`,
each EP rank may have its own `expert_load_window_step`. each EP rank may have its own `expert_load_window_step`.
""" """
self.expert_load_window_step: int = 0 self.expert_load_window_size: int = 0
""" """
Size of the expert load sliding window. Size of the expert load sliding window.
This is a constant and is taken from the config. This is a constant and is taken from the config.
""" """
self.expert_load_window_size: int = 0 self.expert_rearrangement_step: int = 0
""" """
Steps after last rearrangement. Steps after last rearrangement.
Will trigger a rearrangement if it exceeds the threshold. Will trigger a rearrangement if it exceeds the threshold.
...@@ -415,6 +420,10 @@ class EplbState: ...@@ -415,6 +420,10 @@ class EplbState:
) )
self.expert_rearrangement_step_interval = eplb_step_interval self.expert_rearrangement_step_interval = eplb_step_interval
# Set the policy based on the selected eplb algorithm type.
policy_type = self.parallel_config.eplb_config.policy
self.policy = EPLB_POLICIES[policy_type]
logger.debug("Selected EPLB policy: %d", policy_type)
if global_expert_load is not None: if global_expert_load is not None:
ep_group = get_ep_group().device_group ep_group = get_ep_group().device_group
assert global_expert_load.shape == ( assert global_expert_load.shape == (
...@@ -441,7 +450,7 @@ class EplbState: ...@@ -441,7 +450,7 @@ class EplbState:
new_physical_to_logical_map, new_physical_to_logical_map,
new_logical_to_physical_map, new_logical_to_physical_map,
new_logical_replica_count, new_logical_replica_count,
) = rebalance_experts( ) = self.policy.rebalance_experts(
global_expert_load, global_expert_load,
num_replicas, num_replicas,
num_groups, num_groups,
...@@ -776,6 +785,7 @@ class EplbState: ...@@ -776,6 +785,7 @@ class EplbState:
f"{num_gpus=}, {num_nodes=}" f"{num_gpus=}, {num_nodes=}"
) )
# Get new expert mappings
for eplb_model_state, global_expert_load_window in zip( for eplb_model_state, global_expert_load_window in zip(
self.model_states.values(), global_expert_load_windows self.model_states.values(), global_expert_load_windows
): ):
...@@ -784,7 +794,7 @@ class EplbState: ...@@ -784,7 +794,7 @@ class EplbState:
new_physical_to_logical_map, new_physical_to_logical_map,
new_logical_to_physical_map, new_logical_to_physical_map,
new_logical_replica_count, new_logical_replica_count,
) = rebalance_experts( ) = self.policy.rebalance_experts(
global_expert_load_window, global_expert_load_window,
num_replicas, num_replicas,
num_groups, num_groups,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import get_args
from vllm.config.parallel import EPLBPolicyOption
from .abstract import AbstractEplbPolicy
from .default import DefaultEplbPolicy
EPLB_POLICIES = {"default": DefaultEplbPolicy}
# Ensure that the EPLB_POLICIES keys match the EPLBPolicyOption values
assert set(EPLB_POLICIES.keys()) == set(get_args(EPLBPolicyOption))
__all__ = [
"AbstractEplbPolicy",
"DefaultEplbPolicy",
"EPLB_POLICIES",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
class AbstractEplbPolicy(ABC):
@classmethod
@abstractmethod
def rebalance_experts(
cls,
weight: torch.Tensor,
num_replicas: int,
num_groups: int,
num_nodes: int,
num_ranks: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics
for all logical experts
num_replicas: number of physical experts, must be a multiple of
`num_ranks`
num_groups: number of expert groups
num_nodes: number of server nodes
num_ranks: number of ranks, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map: [layers, num_replicas], the expert
index of each replica
logical_to_physical_map: [layers, num_logical_experts, X],
the replica indices for each expert
expert_count: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.
This module implements the core rearrangement algorithm.
The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
on how the EPLB algorithm works.
"""
import numpy as np
import torch
from .abstract import AbstractEplbPolicy
class DefaultEplbPolicy(AbstractEplbPolicy):
@classmethod
def balanced_packing(
cls, weight: torch.Tensor, num_packs: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly
n/m objects and the weights of all packs are as balanced as possible.
Parameters:
weight: [X, n], the weight of each item
num_packs: number of packs
Returns:
pack_index: [X, n], the pack index of each item
rank_in_pack: [X, n], the rank of the item in the pack
"""
num_layers, num_groups = weight.shape
assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs
device = weight.device
if groups_per_pack == 1:
pack_index = torch.arange(
weight.size(-1), dtype=torch.int64, device=device
).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
return pack_index, rank_in_pack
weight_np = weight.cpu().numpy()
# Sort and get indices in decending order
indices_np = np.argsort(-weight_np, axis=-1)
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
# Run the packing algorithm
for i in range(num_layers):
pack_weights = [0.0] * num_packs
pack_items = [0] * num_packs
for group in indices_np[i]:
# Find a pack with capacity that has the lowest weight
pack = min(
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
key=pack_weights.__getitem__,
)
assert pack_items[pack] < groups_per_pack
pack_index_np[i, group] = pack
rank_in_pack_np[i, group] = pack_items[pack]
pack_weights[pack] += weight_np[i, group]
pack_items[pack] += 1
pack_index = torch.from_numpy(pack_index_np).to(device)
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
return pack_index, rank_in_pack
@classmethod
def replicate_experts(
cls, weight: torch.Tensor, num_phy: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
Parameters:
weight: [X, num_log]
num_phy: total number of experts after replication
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank
logcnt: [X, num_log], number of replicas for each logical expert
"""
n, num_log = weight.shape
num_redundant = num_phy - num_log
assert num_redundant >= 0
device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device)
for i in range(num_log, num_phy):
redundant_indices = (weight / logcnt).max(dim=-1).indices
phy2log[:, i] = redundant_indices
rank[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1
return phy2log, rank, logcnt
@classmethod
def rebalance_experts_hierarchical(
cls,
weight: torch.Tensor,
num_physical_experts: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
phy2log: [layers, num_replicas], the expert
index of each replica
log2phy: [layers, num_logical_experts, X],
the replica indices for each expert
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
num_layers, num_logical_experts = weight.shape
assert num_logical_experts % num_groups == 0
group_size = num_logical_experts // num_groups
assert num_groups % num_nodes == 0
groups_per_node = num_groups // num_nodes
assert num_gpus % num_nodes == 0
assert num_physical_experts % num_gpus == 0
phy_experts_per_gpu = num_physical_experts // num_gpus
def inverse(perm: torch.Tensor) -> torch.Tensor:
inv = torch.empty_like(perm)
inv.scatter_(
1,
perm,
torch.arange(
perm.size(1), dtype=torch.int64, device=perm.device
).expand(perm.shape),
)
return inv
# Step 1: pack groups to nodes
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
group_pack_index, group_rank_in_pack = cls.balanced_packing(
tokens_per_group, num_nodes
)
log2mlog = (
(
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
).unsqueeze(-1)
+ torch.arange(
group_size, dtype=torch.int64, device=group_pack_index.device
)
).flatten(-2)
mlog2log = inverse(log2mlog)
# Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes]
tokens_per_mlog = weight.gather(-1, mlog2log).view(
-1, num_logical_experts // num_nodes
)
phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes
)
# Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
pack_index, rank_in_pack = cls.balanced_packing(
tokens_per_phy, num_gpus // num_nodes
)
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy)
pphy2mlog = phy2mlog.gather(
-1, pphy2phy
) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = (
pphy2mlog.view(num_layers, num_nodes, -1)
+ torch.arange(
0,
num_logical_experts,
num_logical_experts // num_nodes,
device=group_pack_index.device,
).view(1, -1, 1)
).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog)
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
return pphy2log, pphyrank, logcnt
@classmethod
def rebalance_experts(
cls,
weight: torch.Tensor,
num_replicas: int,
num_groups: int,
num_nodes: int,
num_ranks: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics for all
logical experts
num_replicas: number of physical experts, must be a multiple of
`num_gpus`
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster
num_ranks: number of ranks, must be a multiple of `num_nodes`
Returns:
phy2log: [layers, num_replicas], the expert
index of each replica
log2phy: [layers, num_logical_experts, X],
the replica indices for each expert
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
num_layers, num_logical_experts = weight.shape
weight = weight.float()
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_ranks
)
else:
# use global load-balance policy
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_ranks
)
num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1
log2phy: torch.Tensor = torch.full(
(num_layers, num_logical_experts, maxlogcnt),
-1,
dtype=torch.int64,
device=logcnt.device,
)
log2phy.view(num_layers, -1).scatter_(
-1,
phy2log * maxlogcnt + phyrank,
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
num_layers, -1
),
)
return phy2log, log2phy, logcnt
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.
This module implements the core rearrangement algorithm.
The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
on how the EPLB algorithm works.
"""
import numpy as np
import torch
def balanced_packing(
weight: torch.Tensor, num_packs: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly
n/m objects and the weights of all packs are as balanced as possible.
Parameters:
weight: [X, n], the weight of each item
num_packs: number of packs
Returns:
pack_index: [X, n], the pack index of each item
rank_in_pack: [X, n], the rank of the item in the pack
"""
num_layers, num_groups = weight.shape
assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs
device = weight.device
if groups_per_pack == 1:
pack_index = torch.arange(
weight.size(-1), dtype=torch.int64, device=device
).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
return pack_index, rank_in_pack
weight_np = weight.cpu().numpy()
# Sort and get indices in decending order
indices_np = np.argsort(-weight_np, axis=-1)
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
# Run the packing algorithm
for i in range(num_layers):
pack_weights = [0.0] * num_packs
pack_items = [0] * num_packs
for group in indices_np[i]:
# Find a pack with capacity that has the lowest weight
pack = min(
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
key=pack_weights.__getitem__,
)
assert pack_items[pack] < groups_per_pack
pack_index_np[i, group] = pack
rank_in_pack_np[i, group] = pack_items[pack]
pack_weights[pack] += weight_np[i, group]
pack_items[pack] += 1
pack_index = torch.from_numpy(pack_index_np).to(device)
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
return pack_index, rank_in_pack
def replicate_experts(
weight: torch.Tensor, num_phy: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
Parameters:
weight: [X, num_log]
num_phy: total number of experts after replication
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank
logcnt: [X, num_log], number of replicas for each logical expert
"""
n, num_log = weight.shape
num_redundant = num_phy - num_log
assert num_redundant >= 0
device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device)
for i in range(num_log, num_phy):
redundant_indices = (weight / logcnt).max(dim=-1).indices
phy2log[:, i] = redundant_indices
rank[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1
return phy2log, rank, logcnt
def rebalance_experts_hierarchical(
weight: torch.Tensor,
num_physical_experts: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g., NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map (torch.Tensor):
[num_moe_layers, num_physical_experts]
logical_to_physical_map (torch.Tensor):
[num_moe_layers, num_logical_experts, X]
logical_count (torch.Tensor):
[num_moe_layers, num_logical_experts]
"""
num_layers, num_logical_experts = weight.shape
assert num_logical_experts % num_groups == 0
group_size = num_logical_experts // num_groups
assert num_groups % num_nodes == 0
groups_per_node = num_groups // num_nodes
assert num_gpus % num_nodes == 0
assert num_physical_experts % num_gpus == 0
phy_experts_per_gpu = num_physical_experts // num_gpus
def inverse(perm: torch.Tensor) -> torch.Tensor:
inv = torch.empty_like(perm)
inv.scatter_(
1,
perm,
torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(
perm.shape
),
)
return inv
# Step 1: pack groups to nodes
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
log2mlog = (
(
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
).unsqueeze(-1)
+ torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
).flatten(-2)
mlog2log = inverse(log2mlog)
# Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes]
tokens_per_mlog = weight.gather(-1, mlog2log).view(
-1, num_logical_experts // num_nodes
)
phy2mlog, phyrank, mlogcnt = replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes
)
# Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy)
pphy2mlog = phy2mlog.gather(
-1, pphy2phy
) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = (
pphy2mlog.view(num_layers, num_nodes, -1)
+ torch.arange(
0,
num_logical_experts,
num_logical_experts // num_nodes,
device=group_pack_index.device,
).view(1, -1, 1)
).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog)
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
return pphy2log, pphyrank, logcnt
def rebalance_experts(
weight: torch.Tensor,
num_replicas: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics for all
logical experts
num_replicas: number of physical experts, must be a multiple of
`num_gpus`
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map:
[layers, num_replicas], the expert index of each replica
logical_to_physical_map:
[layers, num_logical_experts, X], the replica indices for each
expert
expert_count:
[layers, num_logical_experts], number of physical
replicas for each logical expert
"""
num_layers, num_logical_experts = weight.shape
weight = weight.float()
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
else:
# use global load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_gpus
)
num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1
log2phy: torch.Tensor = torch.full(
(num_layers, num_logical_experts, maxlogcnt),
-1,
dtype=torch.int64,
device=logcnt.device,
)
log2phy.view(num_layers, -1).scatter_(
-1,
phy2log * maxlogcnt + phyrank,
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
num_layers, -1
),
)
return phy2log, log2phy, logcnt
__all__ = ["rebalance_experts"]
...@@ -322,9 +322,6 @@ async def transfer_layer( ...@@ -322,9 +322,6 @@ async def transfer_layer(
num_local_physical_experts = next(iter(expert_weights[0])).shape[0] num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
assert num_physical_experts == ep_size * num_local_physical_experts assert num_physical_experts == ep_size * num_local_physical_experts
# A buffer to hold the expert weights in one layer during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer( is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
num_local_experts=num_local_physical_experts, num_local_experts=num_local_physical_experts,
......
...@@ -144,9 +144,9 @@ class KVConnectorFactory: ...@@ -144,9 +144,9 @@ class KVConnectorFactory:
# only load the files corresponding to the current connector. # only load the files corresponding to the current connector.
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
"SharedStorageConnector", "ExampleConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", "vllm.distributed.kv_transfer.kv_connector.v1.example_connector",
"SharedStorageConnector", "ExampleConnector",
) )
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
...@@ -190,3 +190,8 @@ KVConnectorFactory.register_connector( ...@@ -190,3 +190,8 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector", "vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
"DecodeBenchConnector", "DecodeBenchConnector",
) )
KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector",
"MooncakeConnector",
)
...@@ -4,13 +4,14 @@ ...@@ -4,13 +4,14 @@
KV cache helper for store. KV cache helper for store.
""" """
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal
import torch import torch
import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend
from vllm import _custom_ops as ops from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
...@@ -21,89 +22,6 @@ if TYPE_CHECKING: ...@@ -21,89 +22,6 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
class model_aware_kv_ops_helper:
def __init__(self, config: VllmConfig):
self.is_deepseek_mla = config.model_config.is_deepseek_mla
self.use_mla_opt = not envs.VLLM_MLA_DISABLE
self.tp_size = config.parallel_config.tensor_parallel_size
def get_model_args(self, model_executable: torch.nn.Module):
model_config = model_executable.model.config
self.model_executable = model_executable
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
# Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/v1/attention/backends/mla/common.py.
if self.is_deepseek_mla and self.use_mla_opt:
head_size = model_config.kv_lora_rank + model_config.qk_rope_head_dim
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = model_config.qk_nope_head_dim + model_config.qk_rope_head_dim
else:
head_size = getattr(model_config, "head_dim", None)
if head_size is None:
head_size = int(hidden_size // num_attention_heads)
return num_heads, head_size
def get_kv_from_cache(self, kv_cache, num_heads, head_size):
if self.is_deepseek_mla and self.use_mla_opt:
key_cache = kv_cache.reshape(-1, num_heads, head_size)
value_cache = kv_cache.reshape(-1, num_heads, head_size)
else:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
return key_cache, value_cache
def put_kv_to_cache(
self,
model_executable: torch.nn.Module,
keys,
values,
layer,
kv_cache,
slot_mapping,
start_pos,
end_pos,
):
model_config = model_executable.model.config
if self.is_deepseek_mla and self.use_mla_opt:
layer.self_attn.attn = layer.self_attn.mla_attn
k_c_normed_k_pe = keys.squeeze(1)
k_c_normed = k_c_normed_k_pe[:, : model_config.kv_lora_rank]
k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank :]
ops.concat_and_cache_mla(
k_c_normed.to(kv_cache.device),
k_pe.to(kv_cache.device),
kv_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
)
else:
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys.to(key_cache.device),
values.to(value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
def get_kv_connector_cache_layout(): def get_kv_connector_cache_layout():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
# used for faster transfer. # used for faster transfer.
...@@ -266,3 +184,124 @@ def copy_kv_blocks( ...@@ -266,3 +184,124 @@ def copy_kv_blocks(
src_tensor = src_kv_caches[layer_name] src_tensor = src_kv_caches[layer_name]
dst_tensor = dst_kv_caches[layer_name] dst_tensor = dst_kv_caches[layer_name]
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
@dataclass
class TpKVTopology:
"""
Helper class for tensor parallel and KV topology information for
mapping between local and remote TP workers.
"""
tp_rank: int
remote_tp_size: dict[str, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
engine_id: str
remote_block_size: dict[str, int]
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below.
self._is_kv_layout_blocks_first = (
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)
@property
def tp_size(self) -> int:
return self.remote_tp_size[self.engine_id]
@property
def block_size(self) -> int:
return self.remote_block_size[self.engine_id]
def tp_ratio(
self,
remote_tp_size: int,
) -> int:
"""
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
"""
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
def block_size_ratio(
self,
remote_block_size: int,
) -> float:
"""
Calculate the block size ratio between local and remote TP.
"""
assert self.block_size % remote_block_size == 0, (
f"Local block size {self.block_size} is not divisible "
f"by remote block size {remote_block_size} or vice versa."
)
return self.block_size // remote_block_size
def tp_ratio_from_engine_id(
self,
remote_engine_id: str,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: str,
) -> float:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: str) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
"""
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: str) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank(
self,
remote_tp_size: int,
) -> int:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio
def get_target_remote_rank_from_engine_id(
self,
remote_engine_id: str,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size)
...@@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC): ...@@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC):
return return
def register_cross_layers_kv_cache( def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
): ):
""" """
Initialize with a single KV cache tensor used by all layers. Initialize with a single KV cache tensor used by all layers.
...@@ -573,3 +573,17 @@ class KVConnectorBase_V1(ABC): ...@@ -573,3 +573,17 @@ class KVConnectorBase_V1(ABC):
expose connector transfer stats via Prometheus. expose connector transfer stats via Prometheus.
""" """
return None return None
def reset_cache(self) -> bool | None:
"""
Reset the connector's internal cache.
Returns:
bool: True if the cache was successfully reset, False otherwise.
"""
logger.debug(
"Connector cache reset requested, but %s does not implement reset_cache().",
type(self).__name__,
)
return None
...@@ -65,7 +65,7 @@ class ReqMeta: ...@@ -65,7 +65,7 @@ class ReqMeta:
@dataclass @dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata): class ExampleConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] = field(default_factory=list) requests: list[ReqMeta] = field(default_factory=list)
def add_request( def add_request(
...@@ -81,7 +81,7 @@ class SharedStorageConnectorMetadata(KVConnectorMetadata): ...@@ -81,7 +81,7 @@ class SharedStorageConnectorMetadata(KVConnectorMetadata):
) )
class SharedStorageConnector(KVConnectorBase_V1): class ExampleConnector(KVConnectorBase_V1):
# NOTE: This is Simple debug implementation of the KV connector. # NOTE: This is Simple debug implementation of the KV connector.
# It save / load the KV cache to / from the disk. # It save / load the KV cache to / from the disk.
# It does extra work which will overwrite the existing prefix-cache in GPU # It does extra work which will overwrite the existing prefix-cache in GPU
...@@ -157,7 +157,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -157,7 +157,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# Get the metadata # Get the metadata
metadata: KVConnectorMetadata = self._get_connector_metadata() metadata: KVConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, SharedStorageConnectorMetadata) assert isinstance(metadata, ExampleConnectorMetadata)
if metadata is None: if metadata is None:
logger.warning( logger.warning(
...@@ -241,7 +241,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -241,7 +241,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...] return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
connector_metadata = self._get_connector_metadata() connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, SharedStorageConnectorMetadata) assert isinstance(connector_metadata, ExampleConnectorMetadata)
for request in connector_metadata.requests: for request in connector_metadata.requests:
if request.is_store: if request.is_store:
filename = self._generate_filename_debug( filename = self._generate_filename_debug(
...@@ -315,7 +315,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -315,7 +315,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
Args: Args:
scheduler_output (SchedulerOutput): the scheduler output object. scheduler_output (SchedulerOutput): the scheduler output object.
""" """
meta = SharedStorageConnectorMetadata() meta = ExampleConnectorMetadata()
total_need_load = 0 total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs: for new_req in scheduler_output.scheduled_new_reqs:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import threading
import time
import uuid
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import msgspec
import numpy as np
import torch
import zmq
import zmq.asyncio
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
try:
from mooncake.engine import TransferEngine
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run VLLM with MooncakeTransferEngine."
) from e
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
EngineId = str
ReqId = str
TRANS_DONE = b"trans_done"
TRANS_ERROR = b"trans_error"
logger = init_logger(__name__)
class MooncakeAgentMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True,
):
remote_hostname: str
remote_port: int
request_ids: list[ReqId]
kv_caches_base_addr: list[int]
block_ids: list[list[int]]
@dataclass
class RecvReqMeta:
local_block_ids: list[int]
remote_host: str
remote_port: int
@dataclass
class SendBlockMeta:
local_block_ids: list[int]
ready: threading.Event
expire_time: float = float("inf")
@dataclass
class SendReqMeta:
reqs: dict[ReqId, SendBlockMeta]
lock: threading.Lock
@dataclass
class FinishedSendReqSet:
set: set[ReqId]
lock: threading.Lock
@dataclass
class FinishedReceiveReqSet:
set: set[ReqId]
lock: asyncio.Lock
class MooncakeConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
self.reqs_to_send: dict[ReqId, list[int]] = {}
def add_new_req(
self,
request_id: ReqId,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
load_remote_cache: bool = True,
):
if load_remote_cache:
self.reqs_to_recv[request_id] = RecvReqMeta(
local_block_ids=local_block_ids,
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
)
else:
self.reqs_to_send[request_id] = local_block_ids
class MooncakeConnector(KVConnectorBase_V1):
def __init__(
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: MooncakeConnectorScheduler | None = (
MooncakeConnectorScheduler(vllm_config, self.engine_id)
)
self.connector_worker: MooncakeConnectorWorker | None = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens
)
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens
)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
return self.connector_worker.get_finished()
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
"""MooncakeConnector does not do layerwise saving."""
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""MooncakeConnector does not save explicitly."""
pass
def wait_for_save(self):
pass
class MooncakeConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self.engine_id: EngineId = engine_id
self.side_channel_host = get_ip()
self.side_channel_port = get_mooncake_side_channel_port(vllm_config)
assert vllm_config.kv_transfer_config
self.kv_role = vllm_config.kv_transfer_config.kv_role
logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[ReqId, list[int]] = {}
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
"""
For remote prefill, pull all prompt blocks from remote
asynchronously relative to engine execution.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
* the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if the external KV cache tokens will be loaded
asynchronously (between scheduler steps).
"""
params = request.kv_transfer_params
logger.debug(
"MooncakeConnector get_num_new_matched_tokens: "
"num_computed_tokens=%s, kv_transfer_params=%s",
num_computed_tokens,
params,
)
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
token_ids = request.prompt_token_ids or []
count = len(token_ids) - num_computed_tokens
if count > 0:
return count, True
# No remote prefill for this request.
return 0, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
params = request.kv_transfer_params
logger.debug(
"MooncakeConnector update_state_after_alloc: "
"num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens,
params,
)
if not params:
return
if params.get("do_remote_prefill"):
assert self.kv_role != "kv_producer"
if all(p in params for p in ("remote_host", "remote_port")):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
local_block_ids = (
blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
)
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer",
params,
)
# Only trigger 1 KV transfer per request.
params["do_remote_prefill"] = False
elif params.get("do_remote_decode"):
# Add an empty list to worker to create event.
self._reqs_need_send[request.request_id] = []
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
meta = MooncakeConnectorMetadata()
# Loop through scheduled reqs and convert to RecvReqMeta.
if self.kv_role != "kv_producer":
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
)
self._reqs_need_recv.clear()
if self.kv_role != "kv_consumer":
for req_id, block_ids in self._reqs_need_send.items():
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params={},
load_remote_cache=False,
)
self._reqs_need_send.clear()
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
params = request.kv_transfer_params
logger.debug(
"MooncakeConnector request_finished, request_status=%s, "
"kv_transfer_params=%s",
request.status,
params,
)
if not params:
return False, None
if params.get("do_remote_prefill"):
# If do_remote_prefill is still True when the request is finished,
# update_state_after_alloc must not have been called (the request
# must have been aborted before it was scheduled).
# To avoid stranding the prefill blocks in the prefill instance,
# we must add empty block_ids to _reqs_need_recv so that our
# worker side will notify and free blocks in the prefill instance.
assert self.kv_role != "kv_producer"
self._reqs_need_recv[request.request_id] = (request, [])
params["do_remote_prefill"] = False
return False, None
if (
not params.get("do_remote_decode")
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED
):
return False, None
assert self.kv_role != "kv_consumer"
# TODO: check whether block_ids actually ever be 0. If not we could
# remove the conditional below
delay_free_blocks = len(block_ids) > 0
if delay_free_blocks:
self._reqs_need_send[request.request_id] = block_ids
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
)
class MooncakeConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id)
self.vllm_config = vllm_config
self.engine = TransferEngine()
self.hostname = get_ip()
ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", "rdma", "")
if ret_value != 0:
raise RuntimeError("Mooncake Transfer Engine initialization failed.")
self.rpc_port = self.engine.get_rpc_port()
logger.debug(
"Mooncake Transfer Engine initialized at %s:%d",
self.hostname,
self.rpc_port,
)
# Mooncake handshake port.
self.side_channel_port: int = get_mooncake_side_channel_port(vllm_config)
self.engine_id: EngineId = engine_id
self.tp_rank = get_tensor_model_parallel_rank()
self.world_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group()
self.num_blocks = 0
assert vllm_config.kv_transfer_config
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"num_workers", 10
)
self.kv_caches_base_addr: list[int] = []
self.device_kv_caches: dict[str, torch.Tensor] = {}
self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock())
# For kv_both, we will act both prefiller and decoder.
if self.kv_role != "kv_consumer":
# Background thread for sending kvcaches to D.
self._mooncake_sender_t: threading.Thread | None = None
# Background thread for processing new sending requests.
self._sender_executor = ThreadPoolExecutor(
max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender"
)
logger.debug(
"Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers
)
if self.kv_role != "kv_producer":
self.receiver_loop = asyncio.new_event_loop()
self._mooncake_receiver_t = threading.Thread(
target=self._receiver_loop, args=(self.receiver_loop,), daemon=True
)
self._mooncake_receiver_t.start()
logger.debug("Mooncake Decoder: start receiver thread")
self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet(
set(), threading.Lock()
)
self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet(
set(), asyncio.Lock()
)
self.block_size = vllm_config.cache_config.block_size
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.use_mla = self.model_config.use_mla
backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
use_mla=self.use_mla,
)
self.backend_name = backend.get_name()
self.kv_cache_layout = get_kv_cache_layout()
logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
)
self._use_pallas = self.kv_topo._use_pallas
self.zmq_ctx = zmq.Context()
self.async_zmq_ctx = zmq.asyncio.Context()
self._encoder = msgspec.msgpack.Encoder()
self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
def __del__(self):
self.shutdown()
def shutdown(self):
"""Cleanup background threads on destruction."""
self.zmq_ctx.term()
self.async_zmq_ctx.term()
if self.kv_role != "kv_consumer":
self._sender_executor.shutdown(wait=False)
if self._mooncake_sender_t:
self._mooncake_sender_t.join()
if self.kv_role != "kv_producer" and self.receiver_loop.is_running():
self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
self._mooncake_receiver_t.join()
def _receiver_loop(self, loop: asyncio.AbstractEventLoop):
asyncio.set_event_loop(loop)
loop.run_forever()
def _mooncake_sender(
self, ready_event: threading.Event, base_port: int, tp_rank: int
):
"""
Background thread that listens for Mooncake requests, dispatches them
to a thread pool, and sends acknowledgments upon completion.
"""
frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER)
logger.debug("Mooncake sender starting listening on path: %s", frontend_path)
backend_path = make_zmq_path("inproc", str(uuid.uuid4()))
backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL)
poller = zmq.Poller()
poller.register(frontend, zmq.POLLIN)
poller.register(backend, zmq.POLLIN)
ready_event.set()
try:
while True:
sockets = dict(poller.poll())
if frontend in sockets:
identity, _, metadata_bytes = frontend.recv_multipart()
self._sender_executor.submit(
self._sender_worker,
identity,
metadata_bytes,
backend_path,
)
if backend in sockets:
identity, status = backend.recv_multipart()
frontend.send_multipart((identity, b"", status))
except zmq.ContextTerminated:
logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
except Exception as e:
logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
finally:
frontend.close()
backend.close()
def _sender_worker(
self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str
):
status = TRANS_ERROR
try:
metadata = self._decoder.decode(metadata_bytes)
self.send_kv_to_decode(metadata)
status = TRANS_DONE
except Exception as e:
logger.error("Error processing Mooncake handshake: %s", e)
finally:
pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH)
try:
pusher.send_multipart((identity, status))
except zmq.ZMQError as e:
logger.warning(
"Internal error, maybe the server is shutting down. Error: %s",
e,
)
finally:
pusher.close()
def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
with self.reqs_need_send.lock:
for req_id in meta.request_ids:
send_meta = self.reqs_need_send.reqs.get(req_id)
if send_meta is None:
logger.warning("Request %s not found in reqs_need_send", req_id)
return
# Mark it as not expired. We will send it now.
send_meta.expire_time = float("inf")
send_reqs.append((req_id, send_meta))
self._send_blocks(send_reqs, meta)
with self.reqs_need_send.lock:
for req_id in meta.request_ids:
del self.reqs_need_send.reqs[req_id]
with self.finished_sending_reqs.lock:
self.finished_sending_reqs.set.update(meta.request_ids)
def _send_blocks(
self,
send_reqs: list[tuple[ReqId, SendBlockMeta]],
agent_meta: MooncakeAgentMetadata,
):
src_ptrs = []
dst_ptrs = []
lengths = []
local_base_addr = self.kv_caches_base_addr
remote_base_addr = agent_meta.kv_caches_base_addr
block_len = self.block_len
remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"
assert len(send_reqs) == len(agent_meta.block_ids)
for (req_id, send_meta), remote_block_ids in zip(
send_reqs, agent_meta.block_ids
):
send_meta.ready.wait()
num_remote_blocks = len(remote_block_ids)
if num_remote_blocks == 0:
continue
local_block_ids = send_meta.local_block_ids
# Partial prefix cache hit: just read uncomputed blocks.
num_local_blocks = len(local_block_ids)
assert num_local_blocks >= num_remote_blocks
if num_local_blocks > num_remote_blocks:
local_block_ids = local_block_ids[-num_remote_blocks:]
# Group by indices
group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous(
local_block_ids, remote_block_ids
)
for local_layer_addr, remote_layer_addr in zip(
local_base_addr, remote_base_addr
):
for group_local_block_id, group_remote_block_id in zip(
group_local_block_ids, group_remote_block_ids
):
src_ptrs.append(
local_layer_addr + group_local_block_id[0] * block_len
)
dst_ptrs.append(
remote_layer_addr + group_remote_block_id[0] * block_len
)
lengths.append(block_len * len(group_local_block_id))
logger.debug(
"Sending kv_caches for request %s (%d blocks) to %s",
req_id,
num_remote_blocks,
remote_session,
)
start_time = time.perf_counter()
ret_value = self.engine.batch_transfer_sync_write(
remote_session, src_ptrs, dst_ptrs, lengths
)
if ret_value != 0:
raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
logger.debug(
"Sending to %s done, took %s",
remote_session,
time.perf_counter() - start_time,
)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in mooncake."""
logger.info("Registering KV_Caches. use_mla: %s", self.use_mla)
kv_data_ptrs = []
kv_data_lens = []
seen_base_addresses = []
split_k_and_v = self.kv_topo.split_k_and_v
tensor_size_bytes = None
for layer_name, cache_or_caches in kv_caches.items():
logger.debug(
"registering layer %s with shape %s", layer_name, cache_or_caches.shape
)
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
base_addr = cache.data_ptr()
if base_addr in seen_base_addresses:
continue
seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.nbytes
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0]
assert tensor_size_bytes == curr_tensor_size_bytes, (
"All kv cache tensors must have the same size"
)
kernel_block_size = cache.shape[-2 if self.use_mla else -3]
assert self.block_size == kernel_block_size
kv_data_ptrs.append(base_addr)
kv_data_lens.append(tensor_size_bytes)
self.kv_caches_base_addr = seen_base_addresses
ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens)
if ret_value != 0:
raise RuntimeError("Mooncake batch memory registration failed.")
assert tensor_size_bytes is not None
assert self.num_blocks != 0
assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks
self.device_kv_caches = kv_caches
logger.debug(
"registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
)
# No need to launch server for D node.
if self.kv_role == "kv_consumer":
return
ready_event = threading.Event()
self._mooncake_sender_t = threading.Thread(
target=self._mooncake_sender,
args=(ready_event, self.side_channel_port, self.tp_rank),
daemon=True,
name="mooncake_sender",
)
self._mooncake_sender_t.start()
ready_event.wait() # Wait for listener ZMQ socket to be ready.
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
async with self.finished_recving_reqs.lock:
finished_recving_reqs = self.finished_recving_reqs.set
self.finished_recving_reqs.set = set()
return finished_recving_reqs
def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
"""
Get requests that are done sending or recving on this specific worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
fut = None
if self.kv_role != "kv_producer":
fut = asyncio.run_coroutine_threadsafe(
self.fetch_finished_recving_reqs(), self.receiver_loop
)
if self.kv_role != "kv_consumer":
with self.finished_sending_reqs.lock:
finished_sending_reqs = self.finished_sending_reqs.set
self.finished_sending_reqs.set = set()
else:
finished_sending_reqs = set()
finished_recving_reqs = fut.result() if fut else set()
if finished_sending_reqs or finished_recving_reqs:
logger.debug(
"Rank %s, get_finished: %s requests done sending "
"and %s requests done recving",
self.tp_rank,
len(finished_sending_reqs),
len(finished_recving_reqs),
)
# Handle timeout to avoid stranding blocks on remote.
now = time.perf_counter()
with self.reqs_need_send.lock:
expired_reqs = [
req_id
for req_id, send_meta in self.reqs_need_send.reqs.items()
if send_meta.expire_time < now
]
for req_id in expired_reqs:
logger.warning(
"Request %s timed out after %d seconds without "
"being sent. Freeing its blocks on the producer side.",
req_id,
envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
)
del self.reqs_need_send.reqs[req_id]
if expired_reqs:
finished_sending_reqs.update(expired_reqs)
return finished_sending_reqs or None, finished_recving_reqs or None
async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
req_ids, block_ids = map(list, zip(*req_blocks))
metadata = MooncakeAgentMetadata(
remote_hostname=self.hostname,
remote_port=self.rpc_port,
request_ids=req_ids,
kv_caches_base_addr=self.kv_caches_base_addr,
block_ids=block_ids,
)
encoded_data = self._encoder.encode(metadata)
logger.debug(
"Size of encoded MooncakeAgentMetadata: %d bytes", len(encoded_data)
)
logger.debug("Sending kv transfer request for %s on path: %s", req_ids, path)
# Send query for the request.
sock: zmq.asyncio.Socket = make_zmq_socket(
self.async_zmq_ctx, path, zmq.REQ, bind=False, linger=0
)
sock.setsockopt(zmq.RCVTIMEO, 60000)
try:
await sock.send(encoded_data)
ret_msg = await sock.recv()
if ret_msg != TRANS_DONE:
logger.error(
"Error happens during tranfering kvcache for %s, see logs in prefiller.", # noqa: E501
req_ids,
)
return
except zmq.ContextTerminated:
logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
except Exception as e:
logger.error("MooncakeAgentMetadata transfer failed for %s: %s", req_ids, e)
return
finally:
sock.close()
async with self.finished_recving_reqs.lock:
self.finished_recving_reqs.set.update(req_ids)
logger.debug("pulling kv_caches for %s finished", req_ids)
def group_kv_pull(self, metadata: MooncakeConnectorMetadata):
kv_pulls = defaultdict(list)
for req_id, meta in metadata.reqs_to_recv.items():
logger.debug(
"start_load_kv for request %s from remote engine. "
"Num local_block_ids: %s.",
req_id,
len(meta.local_block_ids),
)
path = make_zmq_path(
"tcp", meta.remote_host, meta.remote_port + self.tp_rank
)
kv_pulls[path].append((req_id, meta.local_block_ids))
return kv_pulls
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
if self.kv_role != "kv_producer":
kv_pulls = self.group_kv_pull(metadata)
for path, req_blocks in kv_pulls.items():
asyncio.run_coroutine_threadsafe(
self.receive_kv(path, req_blocks), self.receiver_loop
)
if self.kv_role != "kv_consumer":
with self.reqs_need_send.lock:
for req_id, block_ids in metadata.reqs_to_send.items():
if block_ids:
# Already gone through request_finished()
send_meta = self.reqs_need_send.reqs[req_id]
send_meta.local_block_ids = block_ids
send_meta.ready.set()
send_meta.expire_time = (
time.perf_counter()
+ envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
)
else:
# From update_state_after_alloc(),
# but not reach request_finished() yet
self.reqs_need_send.reqs[req_id] = SendBlockMeta(
local_block_ids=[], ready=threading.Event()
)
def group_concurrent_contiguous(
src_indices: list[int], dst_indices: list[int]
) -> tuple[list[list[int]], list[list[int]]]:
"""Vectorised NumPy implementation."""
if len(src_indices) == 0:
return [], []
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
src_groups = np.split(src_indices, brk)
dst_groups = np.split(dst_indices, brk)
src_groups = [g.tolist() for g in src_groups]
dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups
def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
# This logic is now centralized
return (
envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
+ vllm_config.parallel_config.data_parallel_rank
* vllm_config.parallel_config.tensor_parallel_size
)
...@@ -452,3 +452,7 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -452,3 +452,7 @@ class MultiConnector(KVConnectorBase_V1):
per_engine_labelvalues, per_engine_labelvalues,
prom_metrics, prom_metrics,
) )
def reset_cache(self) -> bool:
results = [c.reset_cache() is not False for c in self._connectors]
return all(results)
...@@ -20,10 +20,10 @@ import torch ...@@ -20,10 +20,10 @@ import torch
import zmq import zmq
from vllm import envs from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp, CopyBlocksOp,
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -55,10 +55,26 @@ if TYPE_CHECKING: ...@@ -55,10 +55,26 @@ if TYPE_CHECKING:
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request from vllm.v1.request import Request
Transfer = tuple[int, float] # (xfer_handle, start_time) TransferHandle = int
EngineId = str EngineId = str
ReqId = str ReqId = str
#
# NIXL Connector Version
#
# Increment this version whenever there is an incompatible change to:
# - NixlAgentMetadata schema
# - kv_transfer_params schema or semantics
# - NIXL transfer protocol or wire format
# - KV cache memory layout or block organization
# - Any other change that breaks P/D interoperability
#
# Version History:
# 1: Initial version with compatibility checking
# 2: Add remote_request_id to kv_transfer_params
#
NIXL_CONNECTOR_VERSION: int = 2
GET_META_MSG = b"get_meta_msg" GET_META_MSG = b"get_meta_msg"
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -97,18 +113,95 @@ _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices()) ...@@ -97,18 +113,95 @@ _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
@dataclass @dataclass
class NixlAgentMetadata(KVConnectorHandshakeMetadata): class NixlAgentMetadata:
engine_id: str engine_id: str
agent_metadata: bytes agent_metadata: bytes
kv_caches_base_addr: list[int] kv_caches_base_addr: list[int]
device_id: int device_id: int
num_blocks: int num_blocks: int
block_lens: list[int] block_lens: list[int]
attn_backend_name: str
kv_cache_layout: str kv_cache_layout: str
block_size: int block_size: int
@dataclass
class NixlHandshakePayload(KVConnectorHandshakeMetadata):
"""
Wrapper for NIXL handshake sent over the wire.
Enables two-phase decoding for graceful compatibility checking:
1. Decode NixlHandshakePayload to get compatibility_hash
2. Compute local hash and compare
3. Only if hashes match, decode agent_metadata_bytes
This prevents decoder errors when NixlAgentMetadata schema is
incompatible, allowing graceful failure with clear error message.
"""
compatibility_hash: str
agent_metadata_bytes: bytes # NixlAgentMetadata encoded
def compute_nixl_compatibility_hash(
vllm_config: VllmConfig, attn_backend_name: str
) -> str:
"""
Compute compatibility hash for NIXL KV transfer.
Hash only the factors that affect whether two NIXL instances can
successfully transfer KV cache data.
Factors included:
- vLLM version and NIXL connector version
- Model architecture (name, dtype, KV heads, layers)
- KV cache format (dtype, sliding window)
- Attention backend
Note: Factors like tensor_parallel_size, block_size, and kv_cache_layout
are validated at runtime in _validate_remote_agent_handshake and are not
included in this hash to support heterogeneous deployments.
Note - the set of factors are likely to evolve significantly over
time to be more or less permissive.
Returns:
SHA-256 hex digest
"""
from vllm import __version__ as vllm_version
from vllm.config.utils import hash_factors
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
factors = {
# Version compatibility
"vllm_version": vllm_version,
"nixl_connector_version": NIXL_CONNECTOR_VERSION,
# Model architecture - affects KV cache shape
"model": model_config.model,
"dtype": str(model_config.dtype),
"num_kv_heads": model_config.get_total_num_kv_heads(),
"head_size": model_config.get_head_size(),
"num_hidden_layers": model_config.get_total_num_hidden_layers(),
# Attention backend and KV cache dtype affect memory layout
"attn_backend_name": attn_backend_name,
"cache_dtype": str(cache_config.cache_dtype),
}
compat_hash = hash_factors(factors)
logger.debug(
"NIXL compatibility hash: %s (model=%s, dtype=%s, num_kv_heads=%d, "
"cache_dtype=%s, attn_backend=%s)",
compat_hash,
factors["model"],
factors["dtype"],
factors["num_kv_heads"],
factors["cache_dtype"],
attn_backend_name,
)
return compat_hash
@dataclass @dataclass
class ReqMeta: class ReqMeta:
local_block_ids: list[int] local_block_ids: list[int]
...@@ -118,6 +211,7 @@ class ReqMeta: ...@@ -118,6 +211,7 @@ class ReqMeta:
remote_host: str remote_host: str
remote_port: int remote_port: int
remote_engine_id: str remote_engine_id: str
remote_request_id: str
tp_size: int tp_size: int
...@@ -144,6 +238,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): ...@@ -144,6 +238,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
local_physical_block_ids=local_block_ids, local_physical_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"], remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"], remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_request_id=kv_transfer_params["remote_request_id"],
remote_host=kv_transfer_params["remote_host"], remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"], remote_port=kv_transfer_params["remote_port"],
# P workers don't need to receive tp_size from proxy here. # P workers don't need to receive tp_size from proxy here.
...@@ -396,14 +491,14 @@ class NixlConnectorScheduler: ...@@ -396,14 +491,14 @@ class NixlConnectorScheduler:
encoded_data: dict[int, bytes] = {} encoded_data: dict[int, bytes] = {}
encoder = msgspec.msgpack.Encoder() encoder = msgspec.msgpack.Encoder()
for tp_rank, rank_metadata in metadata.items(): for tp_rank, rank_metadata in metadata.items():
if not isinstance(rank_metadata, NixlAgentMetadata): if not isinstance(rank_metadata, NixlHandshakePayload):
raise ValueError( raise ValueError(
"NixlConnectorScheduler expects NixlAgentMetadata for " "NixlConnectorScheduler expects NixlHandshakePayload for "
"handshake metadata." "handshake metadata."
) )
encoded_data[tp_rank] = encoder.encode(rank_metadata) encoded_data[tp_rank] = encoder.encode(rank_metadata)
logger.debug( logger.debug(
"Tp rank %d: encoded NixlAgentMetadata size: %s bytes", "Tp rank %d: encoded NixlHandshakePayload size: %s bytes",
tp_rank, tp_rank,
str(len(encoded_data[tp_rank])), str(len(encoded_data[tp_rank])),
) )
...@@ -530,7 +625,12 @@ class NixlConnectorScheduler: ...@@ -530,7 +625,12 @@ class NixlConnectorScheduler:
if params.get("remote_block_ids"): if params.get("remote_block_ids"):
if all( if all(
p in params p in params
for p in ("remote_engine_id", "remote_host", "remote_port") for p in (
"remote_engine_id",
"remote_request_id",
"remote_host",
"remote_port",
)
): ):
# If remote_blocks and num_external_tokens = 0, we have # If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call # a full prefix cache hit on the D worker. We need to call
...@@ -659,6 +759,7 @@ class NixlConnectorScheduler: ...@@ -659,6 +759,7 @@ class NixlConnectorScheduler:
do_remote_decode=False, do_remote_decode=False,
remote_block_ids=block_ids, remote_block_ids=block_ids,
remote_engine_id=self.engine_id, remote_engine_id=self.engine_id,
remote_request_id=request.request_id,
remote_host=self.side_channel_host, remote_host=self.side_channel_host,
remote_port=self.side_channel_port, remote_port=self.side_channel_port,
tp_size=self.vllm_config.parallel_config.tensor_parallel_size, tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
...@@ -668,128 +769,6 @@ class NixlConnectorScheduler: ...@@ -668,128 +769,6 @@ class NixlConnectorScheduler:
class NixlConnectorWorker: class NixlConnectorWorker:
"""Implementation of Worker side methods""" """Implementation of Worker side methods"""
@dataclass
class TpKVTopology:
"""
Helper class for tensor parallel and KV topology information for
mapping between local and remote TP workers.
"""
tp_rank: int
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below.
self._is_kv_layout_blocks_first = (
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (
self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
)
@property
def tp_size(self) -> int:
return self.remote_tp_size[self.engine_id]
@property
def block_size(self) -> int:
return self.remote_block_size[self.engine_id]
def tp_ratio(
self,
remote_tp_size: int,
) -> int:
"""
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
"""
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
def block_size_ratio(
self,
remote_block_size: int,
) -> float:
"""
Calculate the block size ratio between local and remote TP.
"""
assert self.block_size % remote_block_size == 0, (
f"Local block size {self.block_size} is not divisible "
f"by remote block size {remote_block_size} or vice versa."
)
return self.block_size // remote_block_size
def tp_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
) -> float:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
"""
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank(
self,
remote_tp_size: int,
) -> int:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio
def get_target_remote_rank_from_engine_id(
self,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size)
def __init__(self, vllm_config: VllmConfig, engine_id: str): def __init__(self, vllm_config: VllmConfig, engine_id: str):
if NixlWrapper is None: if NixlWrapper is None:
logger.error("NIXL is not available") logger.error("NIXL is not available")
...@@ -904,7 +883,7 @@ class NixlConnectorWorker: ...@@ -904,7 +883,7 @@ class NixlConnectorWorker:
# In progress transfers. # In progress transfers.
# [req_id -> list[handle]] # [req_id -> list[handle]]
self._recving_metadata: dict[ReqId, ReqMeta] = {} self._recving_metadata: dict[ReqId, ReqMeta] = {}
self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) self._recving_transfers = defaultdict[ReqId, list[TransferHandle]](list)
# Track the expiration time of requests that are waiting to be sent. # Track the expiration time of requests that are waiting to be sent.
self._reqs_to_send: dict[ReqId, float] = {} self._reqs_to_send: dict[ReqId, float] = {}
# Set of requests that have been part of a batch, regardless of status. # Set of requests that have been part of a batch, regardless of status.
...@@ -916,7 +895,7 @@ class NixlConnectorWorker: ...@@ -916,7 +895,7 @@ class NixlConnectorWorker:
self._failed_recv_reqs: set[ReqId] = set() self._failed_recv_reqs: set[ReqId] = set()
# Handshake metadata of this worker for NIXL transfers. # Handshake metadata of this worker for NIXL transfers.
self.xfer_handshake_metadata: NixlAgentMetadata | None = None self.xfer_handshake_metadata: NixlHandshakePayload | None = None
# Background thread for initializing new NIXL handshakes. # Background thread for initializing new NIXL handshakes.
self._handshake_initiation_executor = ThreadPoolExecutor( self._handshake_initiation_executor = ThreadPoolExecutor(
# NIXL is not guaranteed to be thread-safe, limit 1 worker. # NIXL is not guaranteed to be thread-safe, limit 1 worker.
...@@ -951,6 +930,13 @@ class NixlConnectorWorker: ...@@ -951,6 +930,13 @@ class NixlConnectorWorker:
logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout) logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name
)
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
# With heterogeneous TP, P must wait for all assigned D TP workers to # With heterogeneous TP, P must wait for all assigned D TP workers to
...@@ -958,7 +944,7 @@ class NixlConnectorWorker: ...@@ -958,7 +944,7 @@ class NixlConnectorWorker:
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats() self.xfer_stats = NixlKVConnectorStats()
self.kv_topo = self.TpKVTopology( self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
engine_id=self.engine_id, engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state remote_tp_size=self._tp_size, # shared state
...@@ -999,14 +985,58 @@ class NixlConnectorWorker: ...@@ -999,14 +985,58 @@ class NixlConnectorWorker:
# Set receive timeout to 5 seconds to avoid hanging on dead server # Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(msg) sock.send(msg)
metadata_bytes = sock.recv() handshake_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes) # Decode handshake payload to get compatibility hash
handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
try:
handshake_payload = handshake_decoder.decode(handshake_bytes)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
raise RuntimeError(
f"Failed to decode NixlHandshakePayload. This likely indicates "
f"an incompatibility between connector version. Error: {e}"
) from e
got_metadata_time = time.perf_counter() got_metadata_time = time.perf_counter()
logger.debug( logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time "NIXL handshake: get metadata took: %s", got_metadata_time - start_time
) )
# Check compatibility hash BEFORE decoding agent metadata
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
):
raise RuntimeError(
f"NIXL compatibility hash mismatch. "
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible configurations. "
f"This may be due to: different vLLM versions, models, dtypes, "
f"KV cache layouts, attention backends, etc. "
f"Both instances must use identical configurations."
f"Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
)
logger.info(
"NIXL compatibility check passed (hash: %s)",
handshake_payload.compatibility_hash,
)
# Decode agent metadata
metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
try:
metadata = metadata_decoder.decode(
handshake_payload.agent_metadata_bytes
)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
# This should not happen if hash matched
raise RuntimeError(
f"Failed to decode NixlAgentMetadata. Error: {e}"
) from e
# Ensure engine id matches. # Ensure engine id matches.
if metadata.engine_id != expected_engine_id: if metadata.engine_id != expected_engine_id:
raise RuntimeError( raise RuntimeError(
...@@ -1180,14 +1210,11 @@ class NixlConnectorWorker: ...@@ -1180,14 +1210,11 @@ class NixlConnectorWorker:
# Enable different block lengths for different layers when MLA is used. # Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]() self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms self.slot_size_per_layer = list[int]() # HD bytes in kv terms
self.device_id = self.tp_rank
for layer_name, cache_or_caches in xfer_buffers.items(): for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list: for cache in cache_list:
base_addr = cache.data_ptr() base_addr = cache.data_ptr()
if not self.use_host_buffer and current_platform.is_cuda_alike():
self.device_id = cache.device.index
if base_addr in seen_base_addresses: if base_addr in seen_base_addresses:
continue continue
...@@ -1230,8 +1257,7 @@ class NixlConnectorWorker: ...@@ -1230,8 +1257,7 @@ class NixlConnectorWorker:
"All kv cache tensors must have the same size" "All kv cache tensors must have the same size"
) )
# Need to make sure the device ID is non-negative for NIXL, # Need to make sure the device ID is non-negative for NIXL,
# Torch uses -1 to indicate CPU tensors while NIXL uses explicit # Torch uses -1 to indicate CPU tensors.
# memory type.
self.device_id = max(cache.get_device(), 0) self.device_id = max(cache.get_device(), 0)
caches_data.append( caches_data.append(
(base_addr, curr_tensor_size_bytes, self.device_id, "") (base_addr, curr_tensor_size_bytes, self.device_id, "")
...@@ -1297,19 +1323,24 @@ class NixlConnectorWorker: ...@@ -1297,19 +1323,24 @@ class NixlConnectorWorker:
assert len(self.block_window_per_layer) == self.num_layers assert len(self.block_window_per_layer) == self.num_layers
# After KV Caches registered, listen for new connections. # After KV Caches registered, listen for new connections.
self.xfer_handshake_metadata = NixlAgentMetadata( agent_metadata = NixlAgentMetadata(
engine_id=self.engine_id, engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(), agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
device_id=self.device_id, device_id=self.device_id,
num_blocks=self.num_blocks, num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer, block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout kv_cache_layout=self.kv_cache_layout
if not self.use_host_buffer if not self.use_host_buffer
else self.host_buffer_kv_cache_layout, else self.host_buffer_kv_cache_layout,
block_size=self.block_size, block_size=self.block_size,
) )
# Wrap metadata in payload with hash for defensive decoding
encoder = msgspec.msgpack.Encoder()
self.xfer_handshake_metadata = NixlHandshakePayload(
compatibility_hash=self.compat_hash,
agent_metadata_bytes=encoder.encode(agent_metadata),
)
def register_local_xfer_handler( def register_local_xfer_handler(
self, self,
...@@ -1524,8 +1555,6 @@ class NixlConnectorWorker: ...@@ -1524,8 +1555,6 @@ class NixlConnectorWorker:
remote_engine_id = nixl_agent_meta.engine_id remote_engine_id = nixl_agent_meta.engine_id
assert self._tp_size[remote_engine_id] == remote_tp_size assert self._tp_size[remote_engine_id] == remote_tp_size
# TODO We may eventually want to skip enforcing the same attn backend.
assert nixl_agent_meta.attn_backend_name == self.backend_name
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
...@@ -1818,9 +1847,7 @@ class NixlConnectorWorker: ...@@ -1818,9 +1847,7 @@ class NixlConnectorWorker:
self._reqs_to_send.pop(req_id, None) self._reqs_to_send.pop(req_id, None)
return notified_req_ids return notified_req_ids
def _pop_done_transfers( def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
self, transfers: dict[str, list[tuple[int, float]]]
) -> set[str]:
""" """
Pop completed xfers by checking for DONE state. Pop completed xfers by checking for DONE state.
Args: Args:
...@@ -1831,7 +1858,7 @@ class NixlConnectorWorker: ...@@ -1831,7 +1858,7 @@ class NixlConnectorWorker:
done_req_ids: set[str] = set() done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()): for req_id, handles in list(transfers.items()):
in_progress = False in_progress = False
for handle, xfer_start_time in handles: for handle in handles:
try: try:
xfer_state = self.nixl_wrapper.check_xfer_state(handle) xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE": if xfer_state == "DONE":
...@@ -1946,6 +1973,7 @@ class NixlConnectorWorker: ...@@ -1946,6 +1973,7 @@ class NixlConnectorWorker:
self._read_blocks( self._read_blocks(
request_id=req_id, request_id=req_id,
dst_engine_id=meta.remote_engine_id, dst_engine_id=meta.remote_engine_id,
remote_request_id=meta.remote_request_id,
local_block_ids=meta.local_physical_block_ids, local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote_block_ids, remote_block_ids=meta.remote_block_ids,
) )
...@@ -1956,6 +1984,7 @@ class NixlConnectorWorker: ...@@ -1956,6 +1984,7 @@ class NixlConnectorWorker:
remote_block_ids: list[int], remote_block_ids: list[int],
dst_engine_id: str, dst_engine_id: str,
request_id: str, request_id: str,
remote_request_id: str,
): ):
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1: if block_size_ratio > 1:
...@@ -1988,7 +2017,7 @@ class NixlConnectorWorker: ...@@ -1988,7 +2017,7 @@ class NixlConnectorWorker:
# Number of D TP workers that will read from dst P. Propagate tp_ratio # Number of D TP workers that will read from dst P. Propagate tp_ratio
# on notification so that dst worker can wait before freeing blocks. # on notification so that dst worker can wait before freeing blocks.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id) tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
notif_id = f"{request_id}:{tp_ratio}".encode() notif_id = f"{remote_request_id}:{tp_ratio}".encode()
# Full prefix cache hit: do not need to read remote blocks, # Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need. # just notify P worker that we have the blocks we need.
...@@ -2096,7 +2125,7 @@ class NixlConnectorWorker: ...@@ -2096,7 +2125,7 @@ class NixlConnectorWorker:
self.nixl_wrapper.transfer(handle) self.nixl_wrapper.transfer(handle)
# Use handle to check completion in future step(). # Use handle to check completion in future step().
self._recving_transfers[request_id].append((handle, time.perf_counter())) self._recving_transfers[request_id].append(handle)
except Exception: except Exception:
logger.exception( logger.exception(
"NIXL transfer setup/initiation failed for request %s. " "NIXL transfer setup/initiation failed for request %s. "
...@@ -2227,7 +2256,7 @@ class NixlConnectorWorker: ...@@ -2227,7 +2256,7 @@ class NixlConnectorWorker:
"""Shutdown the connector worker.""" """Shutdown the connector worker."""
self._handshake_initiation_executor.shutdown(wait=False) self._handshake_initiation_executor.shutdown(wait=False)
for handles in self._recving_transfers.values(): for handles in self._recving_transfers.values():
for handle, _ in handles: for handle in handles:
self.nixl_wrapper.release_xfer_handle(handle) self.nixl_wrapper.release_xfer_handle(handle)
self._recving_transfers.clear() self._recving_transfers.clear()
if self.src_xfer_side_handle: if self.src_xfer_side_handle:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.
This file also contains a new class `KVStoreBufferBase` that allows developers
to manage the KVCache buffer as a simple key-value storage buffer with basic
put/get operations.
These classes above are abstracted behind class `KVCacheBufferBase`.
"""
from abc import ABC, abstractmethod
import torch
class KVCacheBufferBase(ABC):
"""
Abstract base class for a KVCache buffer.
"""
@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
KVCache buffer when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
class KVLookupBufferBase(KVCacheBufferBase):
"""
Abstract base class for a KVCache lookup buffer.
This class provides an abstraction for a key-value (KV) cache lookup buffer.
The key of the lookup buffer:
- input_tokens: token IDs of the request
- roi: a binary mask on top of input_tokens.
- Purpose of roi: Since KV cache may only be available for a subset of
tokens in the input (for example, when vLLM is connected to an external
KV cache service), roi specifies the subset of tokens that the KV cache
is associated with.
- NOTE: roi can be further extended to describe which part of KV the
current process is holding (each process may only hold a part of KV
due to TP and PP). This is not implemented for now.
The value of the lookup buffer:
- key: the key tensor in the KV cache
- value: the value tensor in the KV cache
- hidden: the final hidden state generated by model forwarding. This allows
vLLM to bypass further model forwarding by transmitting the hidden state.
"""
@abstractmethod
def insert(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
) -> None:
"""Insert into the lookup buffer.
The functionality is similar to the following python statement
```
buffer[input_tokens, roi] = [key, value, hidden]
```
FIXME: in the future, we should only have two arguments, key and value,
where key is a tensor dict and value is a tensor dict.
FIXME: we should transmit both sampler outputs and the hidden states.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
key (torch.Tensor): The key tensor in the KV cache.
value (torch.Tensor): The value tensor in the KV cache.
hidden (torch.Tensor): The final hidden state tensor generated
during model forwarding to bypass model
forwarding.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def drop_select(
self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
) -> list[torch.Tensor | None]:
"""Select and *drop* KV cache entries from the lookup buffer.
The functionality is similar to the following python statements
```
ret = buffer.pop(input_tokens, roi)
return ret
```
If `input_tokens` and `roi` is `None`, it means selecting any of the
KV caches in the buffer, return, and remove it from the buffer, useful
when offloading KV cache to KV cache storage service.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
Returns:
list[Optional[torch.Tensor]]: A list of tensors. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
class KVStoreBufferBase(KVCacheBufferBase):
"""
Abstract base class for a KVCache storage buffer with key-value semantics.
This class provides a simple key-value storage buffer abstract with basic
put/get operations, which enables flexible KVCache transfer granular
control.
The functionality is similar to a distributed key-value store, where:
- Key: A unique string identifier for the cached entry
- Value:
- Tensor to be stored and retrieved
- None (indicating deletion or empty value)
"""
@abstractmethod
def put(
self,
key: str,
value: torch.Tensor | None,
) -> None:
"""Store a key-value pair in the buffer.
Args:
key (str): Unique identifier for a tensor, this tensor could be the
key cache tensor, value cache tensor, or hidden state tensor
generated during model forwarding.
value (Optional[torch.Tensor]): Tensor to be stored.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def get(
self,
key: str,
) -> torch.Tensor | None:
"""Retrieve a value from the buffer by key.
Args:
key (str): Unique identifier for a tensor, this tensor could be the
key cache tensor, value cache tensor, or hidden state tensor
generated during model forwarding.
Returns:
Optional[torch.Tensor]: Stored tensor if exists, None otherwise.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains a new class `MooncakeStore` that allows developers to
think of KV cache transfer operations as putting new KV cache entries
into a remote KVStore-based lookup buffer and getting existing KV caches
from this remote lookup buffer.
"""
import json
import os
from dataclasses import dataclass
import torch
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVStoreBufferBase
from vllm.logger import init_logger
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
logger = init_logger(__name__)
@dataclass
class MooncakeStoreConfig:
local_hostname: str
metadata_server: str
global_segment_size: int
local_buffer_size: int
protocol: str
device_name: str
master_server_address: str
@staticmethod
def from_file(file_path: str) -> "MooncakeStoreConfig":
"""Load the config from a JSON file."""
with open(file_path) as fin:
config = json.load(fin)
return MooncakeStoreConfig(
local_hostname=config.get("local_hostname"),
metadata_server=config.get("metadata_server"),
global_segment_size=config.get(
"global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
),
local_buffer_size=config.get(
"local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"),
)
@staticmethod
def load_from_env() -> "MooncakeStoreConfig":
"""Load config from a file specified in the environment variable."""
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
if config_file_path is None:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
)
return MooncakeStoreConfig.from_file(config_file_path)
class MooncakeStore(KVStoreBufferBase):
def __init__(
self,
config: VllmConfig,
):
try:
from mooncake.store import MooncakeDistributedStore
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector."
) from e
try:
self.store = MooncakeDistributedStore()
self.config = MooncakeStoreConfig.load_from_env()
logger.info("Mooncake Configuration loaded successfully.")
self.store.setup(
self.config.local_hostname,
self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
)
except ValueError as e:
logger.error("Configuration loading failed: %s", e)
raise
except Exception as exc:
logger.error("An error occurred while loading the configuration: %s", exc)
raise
def close(self):
# MooncakeDistributedStore will automatically call the destructor, so
# it is unnecessary to close it manually.
pass
def put(
self,
key: str,
value: torch.Tensor | None,
) -> None:
# A message queue needs to be introduced before making it asynchronous.
if value is not None:
self._put_impl(key, value)
def get(
self,
key: str,
) -> torch.Tensor | None:
# A message queue needs to be introduced before making it asynchronous.
value = self._get_impl(key)
return value
def _put_impl(
self,
key: str,
value: torch.Tensor,
) -> None:
"""Put KVCache to Mooncake Store"""
device_id = value.device.index if value.device.type == "cuda" else -1
device_tensor = torch.tensor(device_id, dtype=torch.int32)
value_bytes = safetensors_save({"tensor": value, "device_id": device_tensor})
try:
self.store.put(key, value_bytes)
except TypeError as err:
logger.error("Failed to put value into Mooncake Store: %s", err)
raise TypeError("Mooncake Store Put Type Error.") from err
def _get_impl(
self,
key: str,
) -> torch.Tensor | None:
"""Get KVCache from Mooncake Store"""
try:
data = self.store.get(key)
except TypeError as err:
logger.error("Failed to get value from Mooncake Store: %s", err)
raise TypeError("Mooncake Store Get Type Error.") from err
if data:
loaded_tensors = safetensors_load(data)
tensor = loaded_tensors["tensor"]
device_id_tensor = loaded_tensors["device_id"]
device_id = int(device_id_tensor.item())
device = (
torch.device("cuda", device_id)
if device_id >= 0
else torch.device("cpu")
)
return tensor.to(device)
return None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Implements a distributed key-value (KV) cache transfer mechanism.
Key Features:
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Use CPU signal pipe to avoid racing condition
- Handles buffer size constraints and provide backpressure mechanism to
stop the prefill instance when the decode instance is slow.
"""
import threading
from collections import deque
import torch
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVLookupBufferBase
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
logger = init_logger(__name__)
class SimpleBuffer(KVLookupBufferBase):
def __init__(
self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float
):
"""
signal_pipe: on CPU
NOTE: on-device recv will block all threads in the process, making the
KV cache producer unable to listen to new request while transmitting
KV cache. Luckily CPU recv only blocks the current thread so we use
CPU recv to listen to new request.
data_pipe: on device (e.g. GPU)
"""
self.buffer: deque[list[torch.Tensor]] = deque()
self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh
self.buffer_cv = threading.Condition()
self.signal_pipe = signal_pipe
self.data_pipe = data_pipe
self.request_handling_thread: threading.Thread | None = None
self.normal_signal = torch.tensor([0], device="cpu")
self.end_signal = None
def _matches(
self,
tokens_roi_sender: list[torch.Tensor],
tokens_roi_recver: list[torch.Tensor],
):
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
# tokens_roi_recver: tokens and roi of the consumer (query)
tokens_sender = tokens_roi_sender[0]
tokens_recver = tokens_roi_recver[0]
roi_sender = tokens_roi_sender[1]
roi_recver = tokens_roi_recver[1]
if tokens_recver is None:
# consumer sends an empty request
# semantics: DROP SELECT * LIMIT 1
# so any of the data in the buffer can be drop-selected
return True
# Assuming that roi is a binary mask on tokens
tokens_sender = tokens_sender[roi_sender]
tokens_recver = tokens_recver[roi_recver]
# simple common prefix matching
min_length = min(len(tokens_sender), len(tokens_recver))
if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]):
return min_length
return 0
def _send_tensor_and_dec_size(self, tensor: torch.Tensor | None) -> None:
assert tensor is not None, "Use self.data_pipe.send(None) instead"
self.buffer_size -= tensor.element_size() * tensor.numel()
if tensor.dtype == torch.bool:
tensor = tensor.float()
self.data_pipe.send_tensor(tensor)
def _get_element_size(self, data: list | torch.Tensor | None):
if isinstance(data, torch.Tensor):
return data.element_size() * data.numel()
if not data:
# cannot perform `not data` on a tensor
# so this check needs to go after the check above
return 0
raise AssertionError(f"Unknown data type {type(data)}")
def _add_to_buffer(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
):
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone()
if isinstance(key, torch.Tensor):
key = key.clone()
if isinstance(value, torch.Tensor):
value = value.clone()
if isinstance(hidden, torch.Tensor):
hidden = hidden.clone()
buffer_item = [input_tokens, roi, key, value, hidden]
data_size = sum([self._get_element_size(data) for data in buffer_item])
with self.buffer_cv:
if self.buffer_size + data_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size + data_size > self.buffer_size_threshold:
self.buffer_cv.wait()
self.buffer_size += data_size
self.buffer.append(buffer_item)
self.buffer_cv.notify()
def _is_end_signal(self, signal):
return signal is None
def drop_select_handler(self):
try:
while True:
signal = self.signal_pipe.recv_tensor()
if self._is_end_signal(signal):
logger.info("Received end signal!")
break
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
assert roi is not None, (
"Please provide the roi when sending drop-select request"
)
roi = roi > 0.5
tokens_roi_recver = [input_tokens, roi]
def is_buffer_available(
tokens_roi_recver: list[torch.Tensor],
) -> bool:
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
for _ in range(len(self.buffer)):
if self._matches(self.buffer[0], tokens_roi_recver) > 0:
return True
# rotate the element we just accessed to the end
self.buffer.rotate(-1)
return False
with self.buffer_cv:
while not is_buffer_available(tokens_roi_recver):
logger.debug("KV transfer buffer is not available. Waiting...")
self.buffer_cv.wait()
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)
self.buffer_cv.notify()
except RuntimeError as e:
if "Connection closed by peer" not in str(e):
raise e
logger.debug("Closing drop_select_handler")
def drop_select(
self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None
) -> list[torch.Tensor | None]:
assert self.request_handling_thread is None, (
"drop_select should be called by the KV cache consumer "
"(e.g. the decode vLLM instance)"
)
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone().float()
self.signal_pipe.send_tensor(self.normal_signal)
self.data_pipe.send_tensor(input_tokens)
self.data_pipe.send_tensor(roi)
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
if roi is not None:
# convert from float tensor to bool tensor
# as PyNccl does not support sending bool tensor
roi = roi > 0.5
key = self.data_pipe.recv_tensor()
value = self.data_pipe.recv_tensor()
hidden = self.data_pipe.recv_tensor()
return [input_tokens, roi, key, value, hidden]
def insert(
self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
) -> None:
self._add_to_buffer(input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request.
if self.request_handling_thread is None:
self.request_handling_thread = threading.Thread(
target=self.drop_select_handler
)
self.request_handling_thread.start()
def close(self):
if (
hasattr(self, "request_handling_thread")
and self.request_handling_thread is not None
):
self.request_handling_thread.join()
else:
# TODO: have a explicit close signal and have a explicit way to
# check if it's requester
self.signal_pipe.send_tensor(self.end_signal)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file defines an interface `KVPipeBase`
that provides an abstraction for sending and receiving tensors, or None, via
distributed communications.
All classes instantiated from this interface are assumed to be a FIFO pipe.
If your distributed communication platform already supports key-value lookup,
you can bypass this interface and directly start from `kv_lookup_buffer`.
"""
from abc import ABC, abstractmethod
import torch
class KVPipeBase(ABC):
"""
This class provides an interface for sending and receiving tensors, or
None, by distributed communications.
"""
@abstractmethod
def send_tensor(self, tensor: torch.Tensor | None) -> None:
"""Send a tensor, or None, via the pipe.
Need to support sending None -- important for error handling.
TODO: add a `key` argument so that we can use traditional
key-value database as the distributed communication mechanism behind
the pipe.
Args:
tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def recv_tensor(self) -> torch.Tensor | None:
"""Receive a tensor (can be None) from the pipeline.
Returns:
Optional[torch.Tensor]: The tensor received from the pipeline. Can
be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the pipeline and release resources.
This method is responsible for closing the communication pipeline
and releasing any resources associated with it.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment