Unverified Commit cc3993b0 authored by zhanqiuhu's avatar zhanqiuhu Committed by GitHub
Browse files

nixl refactor [2/N]: unify TpKVTopology + HeteroTPTransferConfig into TransferTopology (#39529)


Signed-off-by: default avatarZhanqiu Hu <zhu@redhat.com>
parent 50dd4cb4
......@@ -631,7 +631,7 @@ def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes():
mock_thread.return_value.is_alive.return_value = False
worker.use_mla = True
worker.kv_topo.is_mla = True
worker.transfer_topo.is_mla = True
# MLA cache tensor: shape[-2] is the block size.
mla_cache = torch.zeros((2, 16, 96), dtype=torch.float16)
......@@ -692,9 +692,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
# Override TP rank/size to simulate P TP=2
prefill_worker.tp_rank = P_TP_RANK
prefill_worker.tp_size = P_TP_SIZE
# Update shared dict so kv_topo sees correct TP size
prefill_worker._tp_size[prefill_worker.engine_id] = P_TP_SIZE
prefill_worker.kv_topo.tp_rank = P_TP_RANK
prefill_worker.transfer_topo.tp_rank = P_TP_RANK
prefill_worker.transfer_topo.tp_size = P_TP_SIZE
prefill_worker.kv_caches_base_addr = [0x1000]
prefill_worker.block_len_per_layer = [local_block_len]
......@@ -714,7 +714,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
send_meta.ready.set()
# Compute target D ranks using the production code path
target_d_ranks = prefill_worker.kv_topo.get_target_remote_ranks(d_tp_size)
target_d_ranks = prefill_worker.transfer_topo.handshake_target_ranks(d_tp_size)
mock_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_socket.send_multipart = AsyncMock()
......
......@@ -21,7 +21,7 @@ from vllm import LLM
from vllm.config import KVTransferConfig, set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import (
KVOutputAggregator,
TpKVTopology,
TransferTopology,
get_current_attn_backend,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl
......@@ -463,19 +463,20 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
test_shape = self.attn_backends[0].get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
self.kv_topo = TpKVTopology(
self.transfer_topo = TransferTopology(
tp_rank=self.tp_rank,
tp_size=self.world_size,
block_size=self.block_size,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
is_mamba=False,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backends=self.attn_backends,
tensor_shape=test_shape,
)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks
)
def _nixl_handshake(
......@@ -496,7 +497,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
# Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation.
remote_block_lens = list(self.block_len_per_layer)
tp_ratio = self.kv_topo.tp_ratio(remote_tp_size)
tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size)
if remote_tp_size > self.world_size:
# P TP > D TP case, block_len of remote is smaller
remote_block_lens = [
......@@ -731,8 +732,9 @@ class TestNixlHandshake:
assert set(remote_agents.keys()) == set(range(tp_ratio))
remote_engine_id = worker.REMOTE_ENGINE_ID
assert worker._tp_size[remote_engine_id] == remote_tp_size
assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
remote_info = worker.transfer_topo.get_engine_info(remote_engine_id)
assert remote_info.remote_tp_size == remote_tp_size
assert -tp_ratio == worker.transfer_topo.tp_ratio(remote_tp_size)
# ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks
assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
......@@ -796,7 +798,7 @@ class TestNixlHandshake:
(conn_p0.connector_worker, conn_p1.connector_worker)
):
worker.world_size = p_tp_size
worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size}
worker.transfer_topo.tp_size = p_tp_size
worker.tp_rank = rank
worker.use_mla = True
......@@ -2337,7 +2339,7 @@ def test_compatibility_hash_validation(
remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
decode_worker.transfer_topo.cross_layers_blocks,
)
prefill_block_size = config_overrides.get("block_size", 16)
......@@ -2424,12 +2426,13 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
test_shape = backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
decode_worker.kv_topo = TpKVTopology(
decode_worker.transfer_topo = TransferTopology(
tp_rank=decode_worker.tp_rank,
tp_size=decode_worker.world_size,
block_size=decode_worker.block_size,
engine_id=decode_worker.engine_id,
remote_tp_size=decode_worker._tp_size, # shared state
remote_block_size=decode_worker._block_size, # shared state
is_mla=decode_worker.use_mla,
is_mamba=False,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backends=[backend],
tensor_shape=test_shape,
......@@ -2438,7 +2441,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
decode_worker.compat_hash = compute_nixl_compatibility_hash(
decode_worker.vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
decode_worker.transfer_topo.cross_layers_blocks,
)
if error_scenario == "handshake_decode_error":
......
......@@ -152,13 +152,14 @@ def test_read_blocks_for_req_expands_remote_ids(
remote_engine_id = "remote-engine"
if has_mamba:
worker._mamba_phys_ratio = {remote_engine_id: remote_ratio}
worker._physical_blocks_per_logical = {remote_engine_id: remote_ratio}
# Mock kv_topo: empty remote ranks skips the transfer machinery entirely,
# isolating the block-ID expansion logic.
worker.kv_topo = MagicMock()
worker.kv_topo.get_target_remote_ranks_from_engine_id.return_value = []
worker.kv_topo.tp_ratio_from_engine_id.return_value = 1
# Mock transfer_topo: empty remote ranks skips the transfer machinery
# entirely, isolating the block-ID expansion logic.
worker.transfer_topo = MagicMock()
worker.transfer_topo.target_remote_ranks.return_value = []
worker.transfer_topo.get_engine_info.return_value = MagicMock(remote_tp_size=1)
worker.transfer_topo.tp_ratio.return_value = 1
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
......@@ -317,7 +318,7 @@ def test_get_block_descs_ids_hybrid_ssm():
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = 1
worker._mamba_phys_ratio = {engine_id: 1}
worker._physical_blocks_per_logical = {engine_id: 1}
worker.block_len_per_layer = [100]
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker.num_descs = 2 * num_blocks
......@@ -355,7 +356,7 @@ def test_get_block_descs_ids_kernel_block_mismatch():
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = ratio
worker._mamba_phys_ratio = {engine_id: ratio}
worker._physical_blocks_per_logical = {engine_id: ratio}
worker.block_len_per_layer = [100]
worker.num_descs = 2 * num_blocks # 800
......@@ -532,15 +533,15 @@ def test_has_mamba_init(
((9216, 524288), 4096, 131),
],
)
def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio):
"""Verify that compute_mamba_phys_ratio is TP-dependent.
def test_compute_physical_blocks_per_logical(ssm_sizes, block_len, expected_ratio):
"""Verify that compute_physical_blocks_per_logical is TP-dependent.
With dimension-sharded Mamba state, the ratio differs across TP sizes
(e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why
_mamba_phys_ratio must be stored per-engine.
_physical_blocks_per_logical must be stored per-engine.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
compute_mamba_phys_ratio,
compute_physical_blocks_per_logical,
)
assert compute_mamba_phys_ratio(ssm_sizes, block_len) == expected_ratio
assert compute_physical_blocks_per_logical(ssm_sizes, block_len) == expected_ratio
......@@ -5,7 +5,7 @@ KV cache helper for store.
"""
from collections.abc import Iterator
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast
import torch
......@@ -319,31 +319,139 @@ def yield_req_data(
)
@dataclass
class TpKVTopology:
def get_current_attn_backends(
    vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> list[type[AttentionBackend]]:
    """Collect the distinct attention backends used by the selected layers.

    Args:
        vllm_config: The current vLLM configuration.
        layer_names: Layer names to restrict the lookup to. ``None`` means
            every attention layer is considered.

    Returns:
        One backend class per distinct ``full_cls_name()``, ordered by first
        appearance. When no layers are found, a single-element list holding
        the backend picked by the attention selector from the model and
        cache configuration.
    """
    attn_layers = get_layers_from_vllm_config(
        vllm_config, cast(type[Any], AttentionLayerBase), layer_names
    )

    if not attn_layers:
        # Fallback for tests, when static_forward_context is empty.
        logger.debug(
            "No layers found in the vLLM config. Falling back to default attention backend."
        )
        # Imported lazily, only on the fallback path.
        from vllm.v1.attention.selector import get_attn_backend

        model_config = vllm_config.model_config
        return [
            get_attn_backend(
                head_size=model_config.get_head_size(),
                dtype=model_config.dtype,
                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
                use_mla=model_config.use_mla,
            )
        ]

    # Keyed by fully-qualified class name so that layers sharing a backend
    # contribute a single entry.
    backend_by_name: dict[str, type[AttentionBackend]] = {}
    for attn_layer in attn_layers.values():
        backend_cls = attn_layer.get_attn_backend()
        backend_by_name[backend_cls.full_cls_name()] = backend_cls
    return list(backend_by_name.values())
def get_current_attn_backend(
    vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> type[AttentionBackend]:
    """Return the first backend reported by ``get_current_attn_backends``.

    Both arguments are forwarded unchanged to ``get_current_attn_backends``.
    """
    backends = get_current_attn_backends(vllm_config, layer_names)
    return backends[0]
# ---- Per-engine transfer info ----
@dataclass(frozen=True)
class EngineTransferInfo:
"""Common per-remote-engine transfer state, computed at handshake.
Stored per ``engine_id`` inside ``TransferTopology._engines``.
"""
Helper class for tensor parallel and KV topology information for
mapping between local and remote TP workers.
remote_tp_size: int
remote_block_len: int
"""Block length (bytes)"""
remote_block_size: int
"""Tokens per block."""
remote_physical_blocks_per_logical: int
"""Physical blocks per logical block."""
@dataclass(frozen=True)
class MambaEngineTransferInfo(EngineTransferInfo):
    """Extends ``EngineTransferInfo`` with Mamba-hybrid transfer geometry.

    For hybrid SSM+Attention models, FA and Mamba layers may require
    different numbers of reads from different remote ranks. This
    dataclass captures that per-engine transfer plan.

    The dataclass is frozen, so instances cannot be mutated after
    construction. Every field below is required (no defaults) and is
    declared in addition to the fields inherited from
    ``EngineTransferInfo``.
    """

    # NOTE(review): "FA" presumably stands for the FlashAttention-style
    # attention KV cache, as opposed to Mamba conv/SSM state; confirm.
    remote_fa_source_ranks: tuple[int, ...]
    """Remote ranks carrying unique FA heads for this local rank."""
    remote_all_source_ranks: tuple[int, ...]
    """All remote ranks this local rank reads from (FA + Mamba)."""
    remote_num_fa_reads: int
    """Number of distinct remote ranks needed for FA data."""
    remote_num_mamba_reads: int
    """Number of distinct remote ranks needed for Mamba data."""
    remote_fa_descriptor_bytes: int
    """Byte size of one FA K (or V) descriptor entry."""
    is_remote_replicated: bool
    """Whether the remote engine has replicated KV heads
    (remote_tp_size > total_num_kv_heads)."""
    remote_physical_heads: int
    """Physical KV heads stored per remote rank."""
# ---- Transfer topology ----
@dataclass
class TransferTopology:
"""Single source of truth for local TP identity and per-engine remote info."""
tp_rank: int
remote_tp_size: dict[EngineId, int]
tp_size: int
block_size: int
engine_id: EngineId
is_mla: bool
is_mamba: bool
total_num_kv_heads: int
attn_backends: list[type[AttentionBackend]]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
tensor_shape: torch.Size | None = None
is_mamba: bool = False
def __post_init__(self):
self.local_physical_heads = max(1, self.total_num_kv_heads // self.tp_size)
self._engines: dict[EngineId, EngineTransferInfo] = {}
self._fa_source_sets: dict[EngineId, frozenset[int]] = {}
self._fa_source_indices: dict[EngineId, dict[int, int]] = {}
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
# or num_blocks.
attn_backend = self.attn_backends[0]
if not self.is_mamba:
_MOCK_BLOCK_SIZE = 16
kv_cache_shape: tuple[int, ...] = attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1
num_blocks=1,
block_size=_MOCK_BLOCK_SIZE,
num_kv_heads=1,
head_size=1,
)
logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
......@@ -358,11 +466,9 @@ class TpKVTopology:
self._cross_layers_blocks = (
len(self.tensor_shape) == len(kv_cache_shape) + 1
)
self.tensor_shape: torch.Size
if self._cross_layers_blocks:
logger.debug("Using cross-layer KV cache")
# prepend layers dimension
_MOCK_NUM_LAYERS = 80
kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
try:
......@@ -372,15 +478,81 @@ class TpKVTopology:
except (AttributeError, NotImplementedError):
assert self.tensor_shape is not None
kv_cache_stride_order = tuple(range(len(self.tensor_shape)))
# In case of cross layers permute kv_cache_shape according to
# stride_order to retrieve physical position of block_size
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
# ============================================================
# Engine registration
# ============================================================
def register_remote_engine(
    self,
    remote_engine_id: EngineId,
    remote_tp_size: int,
    remote_block_size: int,
    remote_block_len: int,
    remote_physical_blocks_per_logical: int,
    *,
    local_block_len: int = 0,
) -> EngineTransferInfo:
    """Record a remote engine's transfer parameters and return its info.

    Only remote engines belong here; the local engine's identity
    (tp_size, block_size, etc.) comes from the ``__init__`` params.

    Registration happens once per engine id: when ``remote_engine_id`` is
    already known, the previously stored info is returned as-is and the
    other arguments of this call are ignored.

    For Mamba models (``self.is_mamba``) the info is produced by
    ``_build_mamba_info`` and the FA source-rank lookup caches
    (``_fa_source_sets`` / ``_fa_source_indices``) are filled in too.

    Args:
        remote_engine_id: Id of the remote engine; must differ from
            ``self.engine_id``.
        remote_tp_size: Tensor-parallel size of the remote engine.
        remote_block_size: Tokens per block on the remote engine.
        remote_block_len: Block length (bytes) on the remote engine.
        remote_physical_blocks_per_logical: Physical blocks per logical
            block on the remote engine.
        local_block_len: Local representative block_len (bytes).
            Required for Mamba models to compute ``fa_descriptor_bytes``;
            not used on the non-Mamba path.

    Returns:
        The ``EngineTransferInfo`` stored for ``remote_engine_id``.

    Raises:
        AssertionError: If ``remote_engine_id`` is the local engine id, or
            if the Mamba builder returns something other than a
            ``MambaEngineTransferInfo``.
    """
    assert remote_engine_id != self.engine_id, (
        f"Cannot register local engine {self.engine_id} as remote. "
        f"Local identity is set via __init__ params."
    )

    registry = self._engines
    if remote_engine_id in registry:
        return registry[remote_engine_id]

    engine_info: EngineTransferInfo
    if not self.is_mamba:
        engine_info = EngineTransferInfo(
            remote_tp_size=remote_tp_size,
            remote_block_len=remote_block_len,
            remote_block_size=remote_block_size,
            remote_physical_blocks_per_logical=remote_physical_blocks_per_logical,
        )
    else:
        engine_info = self._build_mamba_info(
            remote_tp_size=remote_tp_size,
            remote_block_size=remote_block_size,
            remote_block_len=remote_block_len,
            remote_physical_blocks_per_logical=remote_physical_blocks_per_logical,
            local_block_len=local_block_len,
        )
        assert isinstance(engine_info, MambaEngineTransferInfo)
        # Pre-build both lookups: membership test and rank -> slot index.
        fa_ranks = engine_info.remote_fa_source_ranks
        self._fa_source_sets[remote_engine_id] = frozenset(fa_ranks)
        self._fa_source_indices[remote_engine_id] = {
            rank: slot for slot, rank in enumerate(fa_ranks)
        }

    registry[remote_engine_id] = engine_info
    return engine_info
def get_engine_info(self, remote_engine_id: EngineId) -> EngineTransferInfo:
    """Look up the transfer info stored for ``remote_engine_id``.

    Raises:
        KeyError: If ``self._engines`` has no entry for that engine id.
    """
    engine_info = self._engines[remote_engine_id]
    return engine_info
# ============================================================
# Layout properties
# ============================================================
@property
def is_kv_layout_blocks_first(self) -> bool:
    """Whether the KV cache layout is blocks-first.

    Read-only view of the private ``_is_kv_layout_blocks_first`` flag.
    """
    blocks_first = self._is_kv_layout_blocks_first
    return blocks_first
@property
def cross_layers_blocks(self) -> bool:
    """Whether a cross-layer KV cache is in use.

    Read-only view of the private ``_cross_layers_blocks`` flag.
    """
    uses_cross_layers = self._cross_layers_blocks
    return uses_cross_layers
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
......@@ -388,29 +560,16 @@ class TpKVTopology:
self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first
)
@property
def tp_size(self) -> int:
return self.remote_tp_size[self.engine_id]
# ============================================================
# Common methods
# ============================================================
@property
def block_size(self) -> int:
return self.remote_block_size[self.engine_id]
def tp_ratio(self, remote_tp_size: int) -> int:
"""Calculate the tensor parallel ratio between local and remote TP.
@property
def cross_layers_blocks(self) -> bool:
return self._cross_layers_blocks
def tp_ratio(
self,
remote_tp_size: int,
) -> int:
"""
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
Positive when local_tp >= remote_tp (local workers read from the
same remote worker in groups of size ``tp_ratio``). Negative when
remote_tp > local_tp (ratio is flipped).
"""
if self.tp_size >= remote_tp_size:
assert self.tp_size % remote_tp_size == 0, (
......@@ -418,78 +577,65 @@ class TpKVTopology:
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
assert remote_tp_size % self.tp_size == 0, (
f"Remote tensor parallel size {remote_tp_size} is not divisible "
f"by local tensor parallel size {self.tp_size}."
)
# P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size
return -(remote_tp_size // self.tp_size)
def block_size_ratio(
self,
remote_block_size: int,
) -> int:
"""
Calculate the block size ratio between local and remote TP.
"""
def block_size_ratio(self, remote_block_size: int) -> int:
    """Return the ratio of the local block size to ``remote_block_size``.

    The result is ``self.block_size // remote_block_size``.

    Raises:
        AssertionError: If the local block size is not an exact multiple
            of ``remote_block_size``.
    """
    quotient, leftover = divmod(self.block_size, remote_block_size)
    assert leftover == 0, (
        f"Local block size {self.block_size} is not divisible "
        f"by remote block size {remote_block_size} or vice versa."
    )
    return quotient
def tp_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
) -> int:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
def is_kv_replicated(self, remote_engine_id: EngineId) -> bool:
"""Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
When they are equal, each TP rank still owns one distinct KV head,
so this is not considered replication.
"""
tp_size = self.remote_tp_size[engine_id]
return tp_size > self.total_num_kv_heads
return self._engines[remote_engine_id].remote_tp_size > self.total_num_kv_heads
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
    """Whether the KV cache counts as replicated for ``remote_engine_id``.

    True for MLA, otherwise the result of
    ``is_kv_replicated(remote_engine_id)``.
    """
    mla_flag = self.is_mla
    if mla_flag:
        # MLA is always replicated as the hidden dim can't be split.
        return mla_flag
    return self.is_kv_replicated(remote_engine_id)
def get_target_remote_ranks(
self,
remote_tp_size: int,
) -> list[int]:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
@property
def local_replicates_kv_cache(self) -> bool:
    """Whether the local engine's KV cache is replicated.

    True for MLA, or when the local TP size exceeds the total number of
    KV heads.
    """
    mla_flag = self.is_mla
    if mla_flag:
        return mla_flag
    return self.tp_size > self.total_num_kv_heads
def handshake_target_ranks(self, remote_tp_size: int) -> list[int]:
    """Compute the remote TP ranks to handshake with, before registration.

    Depends only on ``self.tp_rank`` and ``self.tp_ratio(remote_tp_size)``,
    so the remote engine does not need to be registered yet.

    Returns:
        A single rank (``tp_rank // ratio``) when the ratio is positive.
        Otherwise the ``|ratio|`` consecutive ranks starting at
        ``tp_rank * |ratio|``.
    """
    ratio = self.tp_ratio(remote_tp_size)
    if ratio <= 0:
        # Non-positive ratio: this rank fans out to |ratio| remote ranks.
        fan_out = -ratio
        first_rank = self.tp_rank * fan_out
        return list(range(first_rank, first_rank + fan_out))
    return [self.tp_rank // ratio]
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
def target_remote_ranks(self, remote_engine_id: EngineId) -> list[int]:
"""Get the remote TP rank(s) that the current local TP rank will
read from. When remote tp_size > local tp_size, reads from
multiple remote ranks.
def get_target_remote_ranks_from_engine_id(
self,
remote_engine_id: EngineId,
) -> list[int]:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_ranks(remote_tp_size)
For Mamba models, returns the precomputed ``all_source_ranks``
(FA + Mamba union).
"""
info = self._engines[remote_engine_id]
if isinstance(info, MambaEngineTransferInfo):
return list(info.remote_all_source_ranks)
tp_ratio = self.tp_ratio(info.remote_tp_size)
if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
# remote TP > local TP: read from |tp_ratio| remote workers
abs_ratio = -tp_ratio
return [self.tp_rank * abs_ratio + i for i in range(abs_ratio)]
def get_transfer_cache_regions(
self, cache: torch.Tensor, layer_spec: "KVCacheSpec"
......@@ -498,331 +644,139 @@ class TpKVTopology:
also accounting for hybrid SSM models specificities.
"""
if isinstance(layer_spec, MambaSpec):
# Register the whole kv cache shared tensor, including SSM/Conv. This is
# similar to FI with the difference that SSM/Conv have different sizes
# Register the whole kv cache shared tensor, including
# SSM/Conv.
conv, ssm = cache
return [conv]
# Check may be hacky but it's matching `_update_hybrid_attention_mamba_layout`.
# Check may be hacky but it's matching
# `_update_hybrid_attention_mamba_layout`.
if self.is_mamba and cache.shape[0] == 2:
# When MAMBA is present, all backends are blocks first, so that blocks
# can be shared between attention layers and mamba layers. Runner
# `_update_hybrid_attention_mamba_layout` already adjusted strides
# for FlashAttn-like backends so its num_blocks first.
# Swap [2<>num_blocks] dims to get required layout for hybrid SSM.
# When MAMBA is present, all backends are blocks first, so
# that blocks can be shared between attention layers and mamba
# layers. Runner already adjusted strides for FlashAttn-like
# backends so its num_blocks first.
# Swap [2<>num_blocks] dims for hybrid SSM layout.
cache = cache.transpose(0, 1)
# Regular case: backends like FA register K/V in separate regions
return cache if self.split_k_and_v else [cache]
# ============================================================
# Mamba-specific methods
# ============================================================
# ---- Mamba-HMA hetero-TP transfer config ----
#
# Key insight: with hetero-TP (P_TP > D_TP), FA KV cache may be
# replicated across P ranks (when P_TP > num_kv_heads), but Mamba
# conv/SSM state is almost always uniquely sharded per P rank. So the
# number of P ranks D must read from can differ between FA and Mamba,
# and they must be handled separately.
def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range:
    """Return the physical KV head indices held in ``rank``'s KV cache tensor.

    Sharded case (``tp_size <= num_heads``): each rank holds
    ``num_heads // tp_size`` contiguous heads.

    Replicated case (``tp_size > num_heads``): each rank holds exactly one
    head. Heads are distributed **contiguously** (matching vLLM's GQA
    weight partitioning), so consecutive ranks share a head before moving
    on to the next one.

    Raises:
        AssertionError: In the sharded case, if ``num_heads`` is not a
            multiple of ``tp_size``.
    """
    if tp_size > num_heads:
        head = rank * num_heads // tp_size
        return range(head, head + 1)

    assert num_heads % tp_size == 0
    heads_per_rank = num_heads // tp_size
    first_head = rank * heads_per_rank
    return range(first_head, first_head + heads_per_rank)
def _range_overlap(a: range, b: range) -> range:
    """Return the intersection of two ranges, using only start and stop.

    When the ranges do not overlap, the result is an empty range whose
    start and stop are both the larger of the two starts.
    """
    lower = a.start if a.start >= b.start else b.start
    upper = a.stop if a.stop <= b.stop else b.stop
    if upper < lower:
        # Clamp so the result is empty rather than reversed.
        upper = lower
    return range(lower, upper)
def should_skip_fa(self, remote_engine_id: EngineId, remote_rank: int) -> bool:
    """Whether FA groups are skipped for this remote rank (mamba-only).

    True when ``remote_rank`` is absent from the FA source-rank set
    recorded for ``remote_engine_id``.

    Raises:
        KeyError: If no FA source set is recorded for ``remote_engine_id``.
    """
    fa_sources = self._fa_source_sets[remote_engine_id]
    is_fa_source = remote_rank in fa_sources
    return not is_fa_source
@dataclass
class HeteroTPTransferConfig:
"""Precomputed transfer plan for one (D rank, P engine) pair.
Currently only instantiated for Mamba-HMA (hybrid SSM+Attention) models
where FA and mamba require different splitting factors. Could be extended
to other model types that need non-uniform hetero-TP transfer sizing.
All descriptor sizes are computed here. The guarantee is:
local_entry_size == remote_entry_size (for NIXL)
Attributes that start with ``fa_`` concern FlashAttention KV cache.
Attributes that start with ``mamba_`` concern Mamba conv/SSM state.
"""
# ---- Input parameters (from handshake) ----
tp_ratio: int
K: int # total_num_kv_heads (before TP sharding)
d_tp: int # D engine's tensor_parallel_size
p_tp: int # P engine's tensor_parallel_size
d_rank: int # this D worker's TP rank
use_mla: bool
# Per-layer block lengths (bytes, K+V combined for blocks_first).
# Uniform across layers for current models.
d_block_len: int # D's block_len_per_layer (representative)
p_block_len: int # P's block_len_per_layer (from handshake)
is_blocks_first: bool # kv_topo.is_kv_layout_blocks_first
# ---- Derived: computed in __post_init__ ----
#
# Physical heads per rank (what the KV tensor actually stores)
d_physical_heads: int = field(init=False)
p_physical_heads: int = field(init=False)
# How many distinct P ranks D needs for FA data
physical_fa_num_reads: int = field(init=False)
# Which P ranks contribute unique FA heads (ordered by head index)
fa_read_targets: list[int] = field(init=False)
# All P ranks needed for mamba (always abs_tp for tp_ratio < 0)
mamba_num_reads: int = field(init=False)
# All P ranks this D rank communicates with (FA ∪ mamba)
transfer_targets: list[int] = field(init=False)
# FA descriptor entry size (K or V side, for blocks_first layout)
# Guaranteed: fa_entry_size is the SAME for local handle AND remote desc.
fa_entry_size: int = field(init=False)
# Replication flags
is_d_replicated: bool = field(init=False)
is_p_replicated: bool = field(init=False)
# Pre-built set for fast lookup
_fa_target_set: frozenset[int] = field(init=False, repr=False)
# Map: P rank → index in fa_read_targets (for head slot offset)
_fa_target_index: dict[int, int] = field(init=False, repr=False)
def __post_init__(self) -> None:
K = self.K
self.is_d_replicated = self.d_tp > K
self.is_p_replicated = self.p_tp > K
self.d_physical_heads = max(1, K // self.d_tp)
self.p_physical_heads = max(1, K // self.p_tp)
abs_tp = -self.tp_ratio if self.tp_ratio < 0 else 1
# ---- Mamba range (computed first so FA can prefer ranks in it) ----
mamba_range: range | None = None
if self.tp_ratio < 0:
mamba_range = range(self.d_rank * abs_tp, (self.d_rank + 1) * abs_tp)
# ---- FA read targets ----
if self.use_mla or self.tp_ratio >= 0:
self.physical_fa_num_reads = 1
self.fa_read_targets = (
[0]
if self.use_mla
# Must match kv_topo.get_target_remote_ranks (d_rank // tp_ratio).
else [
self.d_rank // self.tp_ratio if self.tp_ratio > 0 else self.d_rank
]
)
else:
d_needs = _physical_head_range(self.d_tp, K, self.d_rank)
# When mamba range exists, prefer P ranks within it so that
# FA targets are a subset of mamba transfer_targets (avoids
# orphaned FA targets outside the transfer loop).
search_range = mamba_range if mamba_range is not None else range(self.p_tp)
seen: set[tuple[int, int]] = set()
targets: list[int] = []
for p in search_range:
p_has = _physical_head_range(self.p_tp, K, p)
ov = _range_overlap(d_needs, p_has)
if len(ov) > 0:
key = (ov.start, ov.stop)
if key not in seen:
seen.add(key)
targets.append(p)
if not targets:
# Fallback: search globally (should not happen in practice)
for p in range(self.p_tp):
p_has = _physical_head_range(self.p_tp, K, p)
ov = _range_overlap(d_needs, p_has)
if len(ov) > 0:
key = (ov.start, ov.stop)
if key not in seen:
seen.add(key)
targets.append(p)
self.fa_read_targets = targets
self.physical_fa_num_reads = len(targets)
self._fa_target_set = frozenset(self.fa_read_targets)
self._fa_target_index = {r: i for i, r in enumerate(self.fa_read_targets)}
# ---- Mamba targets ----
if mamba_range is not None and abs_tp > self.physical_fa_num_reads:
self.mamba_num_reads = abs_tp
self.transfer_targets = list(mamba_range)
else:
self.mamba_num_reads = self.physical_fa_num_reads
self.transfer_targets = list(self.fa_read_targets)
# ---- FA entry size ----
# For blocks_first: block_len_per_layer includes K+V; // 2 gives K (or V).
# Use min(D, P) because D indexes into P when tp_ratio > 0,
# and P is the natural unit when tp_ratio < 0.
effective_block_len = min(self.d_block_len, self.p_block_len)
if self.is_blocks_first:
self.fa_entry_size = effective_block_len // 2
else:
self.fa_entry_size = effective_block_len
self._validate()
def _validate(self) -> None:
"""Cross-check internal consistency."""
if self.is_d_replicated and self.is_p_replicated and self.tp_ratio > 0:
logger.info(
"Both-replicated hetero-TP: D_TP=%d > P_TP=%d > K=%d. "
"Using d_rank // tp_ratio routing with relative head offset.",
self.d_tp,
self.p_tp,
self.K,
)
# FA targets must be a subset of transfer_targets
tt_set = set(self.transfer_targets)
for t in self.fa_read_targets:
if t not in tt_set:
logger.error(
"FA target P rank %d is NOT in transfer_targets %s. "
"This will cause missed FA reads!",
t,
self.transfer_targets,
)
# For tp_ratio < 0 with blocks_first: D_K_half / reads should == P_K_half
if (
self.is_blocks_first
and self.tp_ratio < 0
and self.physical_fa_num_reads > 0
):
d_k_half = self.d_block_len // 2
p_k_half = self.p_block_len // 2
expected_local = d_k_half // self.physical_fa_num_reads
if expected_local != p_k_half:
logger.warning(
"FA size mismatch: D_K_half=%d / reads=%d = %d, "
"but P_K_half=%d. This may indicate a head count or "
"Mamba-HMA inflation inconsistency.",
d_k_half,
self.physical_fa_num_reads,
expected_local,
p_k_half,
)
def fa_head_slot(self, remote_engine_id: EngineId, remote_rank: int) -> int:
"""Index into local FA block for this remote rank's head data.
# ---- Query methods ----
def should_skip_fa(self, p_rank: int) -> bool:
"""Whether to skip FA groups for this P rank (mamba-only transfer)."""
return p_rank not in self._fa_target_set
def fa_head_slot(self, p_rank: int) -> int:
"""Index into D's FA block for this P rank's head data.
For P ranks in fa_read_targets, returns 0, 1, ..., reads-1.
For P ranks NOT in fa_read_targets (replicated duplicates),
returns the slot of the matching FA target with the same head.
For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1.
For ranks NOT in ``fa_source_ranks`` (replicated duplicates),
returns the slot of the matching source rank with the same head.
"""
if p_rank in self._fa_target_index:
return self._fa_target_index[p_rank]
# Duplicate head: find which fa_target has the same physical head
p_head = _physical_head_range(self.p_tp, self.K, p_rank)
for target in self.fa_read_targets:
t_head = _physical_head_range(self.p_tp, self.K, target)
if _range_overlap(p_head, t_head):
return self._fa_target_index[target]
return 0 # fallback
def fa_rank_offset(self, remote_kv_block_len: int) -> int:
"""Byte offset into P's FA block for this D rank.
When D is replicated (D_TP > K), multiple D ranks share a head.
Computes offset *relative to the target P rank's first head*
so it works regardless of how many heads P has.
When neither side replicates, falls back to tp_rank % tp_ratio.
Returns 0 when D does not index into P's block.
fa_index = self._fa_source_indices[remote_engine_id]
if remote_rank in fa_index:
return fa_index[remote_rank]
mamba_info = self._engines[remote_engine_id]
assert isinstance(mamba_info, MambaEngineTransferInfo)
K = self.total_num_kv_heads
remote_tp = mamba_info.remote_tp_size
r_head = self._physical_head_range(remote_tp, K, remote_rank)
for target in mamba_info.remote_fa_source_ranks:
t_head = self._physical_head_range(remote_tp, K, target)
if self._range_overlap(r_head, t_head):
return fa_index[target]
return 0
def fa_rank_offset(
self, remote_engine_id: EngineId, remote_kv_block_len: int
) -> int:
"""Byte offset into remote FA block for this local rank.
When local TP is replicated (local_tp > K), multiple local ranks
share a head. Computes offset *relative to the target remote
rank's first head* so it works regardless of how many heads the
remote has. Returns 0 when local does not index into remote.
"""
if self.use_mla or self.tp_ratio <= 0:
mamba_info = self._engines[remote_engine_id]
assert isinstance(mamba_info, MambaEngineTransferInfo)
tp_ratio = self.tp_ratio(mamba_info.remote_tp_size)
if self.is_mla or tp_ratio <= 0:
return 0
if self.is_d_replicated:
d_head = self.d_rank * self.K // self.d_tp
p_rank = self.fa_read_targets[0]
p_start = p_rank * self.K // self.p_tp
return (d_head - p_start) * remote_kv_block_len
return self.d_rank % self.tp_ratio * remote_kv_block_len
@property
def needs_split_handles(self) -> bool:
"""Whether per-P-rank split handles are needed.
K = self.total_num_kv_heads
is_local_replicated = self.tp_size > K
if is_local_replicated:
local_head = self.tp_rank * K // self.tp_size
p_rank = mamba_info.remote_fa_source_ranks[0]
p_start = p_rank * K // mamba_info.remote_tp_size
return (local_head - p_start) * remote_kv_block_len
return self.tp_rank % tp_ratio * remote_kv_block_len
def needs_split_handles(self, remote_engine_id: EngineId) -> bool:
"""Whether per-remote-rank split handles are needed.
True when FA and mamba have different read counts, requiring
different splitting factors in the local handle.
"""
return self.tp_ratio < 0 and not self.use_mla and len(self.transfer_targets) > 1
mamba_info = self._engines[remote_engine_id]
assert isinstance(mamba_info, MambaEngineTransferInfo)
tp_ratio = self.tp_ratio(mamba_info.remote_tp_size)
return (
tp_ratio < 0
and not self.is_mla
and len(mamba_info.remote_all_source_ranks) > 1
)
def compute_split_handle_data(
self,
remote_engine_id: EngineId,
src_blocks_data: list[tuple[int, int, int]],
num_fa_descs: int,
abs_tp: int,
) -> list[list[tuple[int, int, int]]]:
"""Compute per-P-rank (addr, len, tp) triples for Mamba-HMA split handles.
"""Per-remote-rank (addr, len, dev) triples for Mamba-HMA split
handles.
FA descriptors (indices < num_fa_descs) are sliced by
``physical_fa_num_reads``; mamba descriptors are sliced uniformly
``remote_num_fa_reads``; mamba descriptors are sliced uniformly
by ``abs_tp``.
Returns one list of triples per transfer target.
"""
mamba_info = self._engines[remote_engine_id]
assert isinstance(mamba_info, MambaEngineTransferInfo)
all_handle_data: list[list[tuple[int, int, int]]] = []
for p_idx, p_rank in enumerate(self.transfer_targets):
for p_idx, p_rank in enumerate(mamba_info.remote_all_source_ranks):
handle_data: list[tuple[int, int, int]] = []
skip_fa = self.should_skip_fa(p_rank)
fa_slot = self.fa_head_slot(p_rank) if not skip_fa else 0
for j, (addr, local_len, tp) in enumerate(src_blocks_data):
skip_fa = self.should_skip_fa(remote_engine_id, p_rank)
fa_slot = self.fa_head_slot(remote_engine_id, p_rank) if not skip_fa else 0
for j, (addr, local_len, dev) in enumerate(src_blocks_data):
if j < num_fa_descs:
assert self.physical_fa_num_reads >= 1
fa_chunk = local_len // self.physical_fa_num_reads
handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, tp))
assert mamba_info.remote_num_fa_reads >= 1
fa_chunk = local_len // mamba_info.remote_num_fa_reads
handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, dev))
else:
mamba_chunk = local_len // abs_tp
handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, tp))
handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, dev))
all_handle_data.append(handle_data)
return all_handle_data
def filter_block_ids_for_rank(
self,
remote_engine_id: EngineId,
remote_rank: int,
local_ids: BlockIds,
remote_ids: BlockIds,
is_mamba_group: list[bool],
) -> tuple[BlockIds, BlockIds]:
"""Zero out FA groups for P ranks outside fa_read_targets.
"""Zero out FA groups for remote ranks outside ``fa_source_ranks``.
Returns (filtered_local_ids, filtered_remote_ids). When the
remote rank carries FA data for this D rank, returns the inputs
unchanged.
remote rank carries FA data for this local rank, returns the
inputs unchanged.
"""
if not self.should_skip_fa(remote_rank):
if not self.should_skip_fa(remote_engine_id, remote_rank):
return local_ids, remote_ids
num_groups = len(local_ids)
filtered_local: list[list[int]] = [
......@@ -833,108 +787,184 @@ class HeteroTPTransferConfig:
]
return filtered_local, filtered_remote
def describe(self) -> str:
"""One-line summary for logging."""
return (
f"HeteroTPTransferConfig("
f"tp_ratio={self.tp_ratio}, K={self.K}, "
f"d_tp={self.d_tp}, p_tp={self.p_tp}, d_rank={self.d_rank}, "
f"physical_fa_reads={self.physical_fa_num_reads}, "
f"mamba_reads={self.mamba_num_reads}, "
f"fa_targets={self.fa_read_targets}, "
f"transfer_targets={self.transfer_targets}, "
f"fa_entry_size={self.fa_entry_size}, "
f"d_block_len={self.d_block_len}, p_block_len={self.p_block_len})"
def describe(self, remote_engine_id: EngineId) -> str:
    """Build a one-line, log-friendly summary of the transfer plan.

    Args:
        remote_engine_id: Key into ``self._engines`` selecting the remote
            engine whose registered transfer info is summarized.

    Returns:
        ``"TransferTopology(...)"`` with the common TP/block fields, or
        ``"TransferTopology.mamba(...)"`` with the extra FA/mamba read
        counts and source ranks when the registered info is a
        ``MambaEngineTransferInfo``.

    Raises:
        KeyError: If ``remote_engine_id`` is not present in
            ``self._engines``.
    """
    info = self._engines[remote_engine_id]
    # Fields shared by both the plain and the mamba summaries.
    base = (
        f"tp_ratio={self.tp_ratio(info.remote_tp_size)}, "
        f"K={self.total_num_kv_heads}, "
        f"local_tp={self.tp_size}, "
        f"remote_tp={info.remote_tp_size}, "
        f"local_rank={self.tp_rank}, "
        f"remote_block_len={info.remote_block_len}"
    )
    if not isinstance(info, MambaEngineTransferInfo):
        return f"TransferTopology({base})"
    # Mamba engines additionally report the FA/mamba read plan.
    return (
        f"TransferTopology.mamba({base}, "
        f"fa_reads={info.remote_num_fa_reads}, "
        f"mamba_reads={info.remote_num_mamba_reads}, "
        f"fa_sources={list(info.remote_fa_source_ranks)}, "
        f"all_sources={list(info.remote_all_source_ranks)}, "
        f"fa_desc_bytes={info.remote_fa_descriptor_bytes})"
    )
# ============================================================
# Private helpers
# ============================================================
# Mamba-HMA hetero-TP transfer config:
# With hetero-TP (P_TP > D_TP), FA KV cache may be replicated across
# P ranks (when P_TP > num_kv_heads), but Mamba conv/SSM state is
# almost always uniquely sharded per P rank. So the number of P
# ranks D must read from can differ between FA and Mamba, and they
# must be handled separately.
@staticmethod
def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range:
    """Return the physical KV head indices stored in one rank's KV cache.

    Two regimes are handled:

    * ``tp_size <= num_heads`` (sharded): each rank holds
      ``num_heads // tp_size`` contiguous heads.
    * ``tp_size > num_heads`` (replicated): each rank holds exactly one
      physical head. Heads are distributed **contiguously** (matching
      vLLM's GQA weight partitioning): consecutive ranks share a head
      before moving to the next one.

    Args:
        tp_size: Tensor-parallel size of the engine being described.
        num_heads: Total number of KV heads.
        rank: TP rank whose head range is requested.

    Returns:
        Half-open ``range`` of head indices held by ``rank``.

    Raises:
        AssertionError: If ``tp_size <= num_heads`` and ``num_heads`` is
            not divisible by ``tp_size``.
    """
    if tp_size > num_heads:
        # Replicated: floor-map the rank onto its single shared head.
        head = rank * num_heads // tp_size
        return range(head, head + 1)

    # Raised explicitly instead of via ``assert`` so the check is not
    # stripped under ``python -O``; uneven sharding would otherwise
    # silently floor-divide and drop trailing heads.
    if num_heads % tp_size != 0:
        raise AssertionError(
            f"num_heads ({num_heads}) must be divisible by tp_size ({tp_size})"
        )
    per_rank = num_heads // tp_size
    return range(rank * per_rank, (rank + 1) * per_rank)
@staticmethod
def _range_overlap(a: range, b: range) -> range:
    """Return the intersection of two ranges as a new ``range``.

    Only ``start`` and ``stop`` are consulted; ``step`` is ignored, so both
    inputs are presumably step-1 ranges (TODO confirm against callers).
    Disjoint inputs produce an empty range.
    """
    lo = a.start if a.start >= b.start else b.start
    hi = a.stop if a.stop <= b.stop else b.stop
    # Clamp the upper bound so a disjoint pair yields range(lo, lo).
    upper = hi if hi >= lo else lo
    return range(lo, upper)
def get_current_attn_backends(
vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> list[type[AttentionBackend]]:
"""Get all distinct attention backends for the given layers.
# ============================================================
# Private: build Mamba transfer info
# ============================================================
Args:
vllm_config: The current vLLM configuration.
layer_names: Optional list of layer names to scope the lookup.
When None, all attention layers are considered.
def _build_mamba_info(
self,
remote_tp_size: int,
remote_block_size: int,
remote_block_len: int,
remote_physical_blocks_per_logical: int,
local_block_len: int,
) -> MambaEngineTransferInfo:
"""Compute Mamba transfer plan."""
K = self.total_num_kv_heads
local_tp = self.tp_size
local_rank = self.tp_rank
is_remote_replicated = remote_tp_size > K
remote_physical_heads = max(1, K // remote_tp_size)
if local_tp >= remote_tp_size:
assert local_tp % remote_tp_size == 0
tp_ratio = local_tp // remote_tp_size
else:
assert remote_tp_size % local_tp == 0
tp_ratio = -(remote_tp_size // local_tp)
Returns:
Deduplicated list of attention backend classes.
"""
layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
if layers:
seen: dict[str, type[AttentionBackend]] = {}
for layer in layers.values():
backend = layer.get_attn_backend()
seen[backend.full_cls_name()] = backend
return list(seen.values())
abs_tp = -tp_ratio if tp_ratio < 0 else 1
# Fallback for tests, when static_forward_context is empty.
logger.debug(
"No layers found in the vLLM config. Falling back to default attention backend."
)
from vllm.v1.attention.selector import get_attn_backend
mamba_range: range | None = None
if tp_ratio < 0:
mamba_range = range(local_rank * abs_tp, (local_rank + 1) * abs_tp)
return [
get_attn_backend(
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
use_mla=vllm_config.model_config.use_mla,
)
]
# ---- FA read targets ----
if self.is_mla or tp_ratio >= 0:
num_fa_reads = 1
fa_source_ranks: list[int] = (
[0]
if self.is_mla
else [local_rank // tp_ratio if tp_ratio > 0 else local_rank]
)
else:
local_needs = self._physical_head_range(local_tp, K, local_rank)
search_range = (
mamba_range if mamba_range is not None else range(remote_tp_size)
)
seen: set[tuple[int, int]] = set()
fa_source_ranks = []
for p in search_range:
p_has = self._physical_head_range(remote_tp_size, K, p)
ov = self._range_overlap(local_needs, p_has)
if len(ov) > 0:
key = (ov.start, ov.stop)
if key not in seen:
seen.add(key)
fa_source_ranks.append(p)
if not fa_source_ranks:
for p in range(remote_tp_size):
p_has = self._physical_head_range(remote_tp_size, K, p)
ov = self._range_overlap(local_needs, p_has)
if len(ov) > 0:
key = (ov.start, ov.stop)
if key not in seen:
seen.add(key)
fa_source_ranks.append(p)
num_fa_reads = len(fa_source_ranks)
# ---- All source ranks (mamba + FA) ----
if mamba_range is not None and abs_tp > num_fa_reads:
num_mamba_reads = abs_tp
all_source_ranks = list(mamba_range)
else:
num_mamba_reads = num_fa_reads
all_source_ranks = list(fa_source_ranks)
def get_current_attn_backend(
vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> type[AttentionBackend]:
"""Get the first attention backend for the given layers."""
return get_current_attn_backends(vllm_config, layer_names)[0]
# ---- FA descriptor bytes ----
effective_block_len = min(local_block_len, remote_block_len)
if self.is_kv_layout_blocks_first:
fa_descriptor_bytes = effective_block_len // 2
else:
fa_descriptor_bytes = effective_block_len
# ---- Validation ----
is_local_replicated = local_tp > K
if is_local_replicated and is_remote_replicated and tp_ratio > 0:
logger.info(
"Both-replicated hetero-TP: local_tp=%d > remote_tp=%d > K=%d.",
local_tp,
remote_tp_size,
K,
)
tt_set = set(all_source_ranks)
for t in fa_source_ranks:
if t not in tt_set:
logger.error(
"FA source rank %d NOT in all_source_ranks %s.",
t,
all_source_ranks,
)
if self.is_kv_layout_blocks_first and tp_ratio < 0 and num_fa_reads > 0:
local_k_half = local_block_len // 2
remote_k_half = remote_block_len // 2
expected = local_k_half // num_fa_reads
if expected != remote_k_half:
logger.warning(
"FA size mismatch: local_k_half=%d / reads=%d = %d, "
"but remote_k_half=%d.",
local_k_half,
num_fa_reads,
expected,
remote_k_half,
)
# TODO (ZhanqiuHu): Consolidate TpKVTopology and HeteroTPTransferConfig
# into a single engine-agnostic TransferTopology class.
# 6 of 9 HeteroTPTransferConfig init fields duplicate TpKVTopology data.
#
# @dataclass
# class EngineTransferInfo:
# """Per-remote-engine transfer state, computed at handshake."""
# p_tp: int
# tp_ratio: int
# p_block_len: int
# block_size: int
# # Mamba-specific (None for non-mamba models)
# fa_read_targets: list[int] | None = None
# transfer_targets: list[int] | None = None
# physical_fa_num_reads: int | None = None
# mamba_num_reads: int | None = None
# fa_entry_size: int | None = None
#
# class TransferTopology:
# """Single source of truth for TP topology + transfer sizing."""
# # Shared (set once at init, replaces duplicate fields)
# tp_rank: int # == TpKVTopology.tp_rank == HeteroTP.d_rank
# tp_size: int # == TpKVTopology.tp_size == HeteroTP.d_tp
# total_num_kv_heads: int # == HeteroTP.K
# is_mla: bool # == HeteroTP.use_mla
# is_mamba: bool
# is_blocks_first: bool # == HeteroTP.is_blocks_first
# d_block_len: int
#
# # Per-engine (populated via register_engine() at handshake)
# _engines: dict[EngineId, EngineTransferInfo]
#
# def register_engine(self, engine_id, p_tp, p_block_len, ...): ...
#
# # General (from TpKVTopology)
# def tp_ratio(self, engine_id) -> int: ...
# def target_remote_ranks(self, engine_id) -> list[int]: ...
# def is_kv_replicated(self, engine_id) -> bool: ...
#
# # Mamba-specific (from HeteroTPTransferConfig, gated by is_mamba)
# def fa_rank_offset(self, engine_id, block_len) -> int: ...
# def physical_fa_num_reads(self, engine_id) -> int: ...
# def transfer_targets(self, engine_id) -> list[int]: ...
# def should_skip_fa(self, engine_id, p_rank) -> bool: ...
# def filter_block_ids_for_rank(self, engine_id, ...) -> ...: ...
return MambaEngineTransferInfo(
remote_tp_size=remote_tp_size,
remote_block_len=remote_block_len,
remote_block_size=remote_block_size,
remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical),
remote_fa_source_ranks=tuple(fa_source_ranks),
remote_all_source_ranks=tuple(all_source_ranks),
remote_num_fa_reads=num_fa_reads,
remote_num_mamba_reads=num_mamba_reads,
remote_fa_descriptor_bytes=fa_descriptor_bytes,
is_remote_replicated=is_remote_replicated,
remote_physical_heads=remote_physical_heads,
)
......@@ -21,7 +21,7 @@ from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
TpKVTopology,
TransferTopology,
get_current_attn_backend,
get_current_attn_backends,
)
......@@ -764,13 +764,13 @@ class MooncakeConnectorWorker:
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
self.kv_topo = TpKVTopology(
self.transfer_topo = TransferTopology(
tp_rank=self.tp_rank,
tp_size=self.tp_size,
block_size=self.block_size,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
is_mamba=False,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backends=[backend],
)
......@@ -911,7 +911,7 @@ class MooncakeConnectorWorker:
self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata
):
pending_reqs: dict[ReqId, SendBlockMeta] = {}
remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size)
remote_tp_ranks = self.transfer_topo.handshake_target_ranks(meta.remote_tp_size)
if meta.remote_tp_rank not in remote_tp_ranks:
# This D worker does not pair with the P worker.
msg = (
......@@ -1256,7 +1256,7 @@ class MooncakeConnectorWorker:
seen_base_addresses = []
self.block_len_per_layer = []
split_k_and_v = self.kv_topo.split_k_and_v
split_k_and_v = self.transfer_topo.split_k_and_v
tensor_size_bytes = None
for layer_name, cache_or_caches in kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
......@@ -1495,8 +1495,8 @@ class MooncakeConnectorWorker:
remote_engine_id: EngineId,
pull_metas: dict[ReqId, PullReqMeta],
):
remote_tp_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
remote_engine_id
remote_tp_ranks = self.transfer_topo.handshake_target_ranks(
self._tp_size[remote_engine_id]
)
count = len(remote_tp_ranks)
logger.debug(
......@@ -1587,7 +1587,7 @@ class MooncakeConnectorWorker:
)
def _producer_cache_is_replicated(self) -> bool:
return self.kv_topo.replicates_kv_cache(self.engine_id)
return self.transfer_topo.local_replicates_kv_cache
def _get_transfer_regions(
self, base_addrs: list[int], block_lens: list[int]
......@@ -1595,7 +1595,7 @@ class MooncakeConnectorWorker:
return _expand_transfer_regions(
base_addrs=base_addrs,
block_lens=block_lens,
is_kv_layout_blocks_first=self.kv_topo.is_kv_layout_blocks_first,
is_kv_layout_blocks_first=self.transfer_topo.is_kv_layout_blocks_first,
)
def _get_sender_transfer_plan(
......
......@@ -21,8 +21,8 @@ from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.utils import (
BlockIds,
EngineId,
HeteroTPTransferConfig,
TpKVTopology,
MambaEngineTransferInfo,
TransferTopology,
get_current_attn_backends,
kv_postprocess_blksize_and_layout_on_receive,
kv_postprocess_blksize_on_receive,
......@@ -49,7 +49,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import (
)
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
MambaConvSplitInfo,
compute_mamba_phys_ratio,
compute_physical_blocks_per_logical,
derive_mamba_conv_split,
)
from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config
......@@ -269,14 +269,12 @@ class NixlConnectorWorker:
self._registered_descs: list[Any] = []
# ---- Mamba-HMA per-engine state (only used when self._has_mamba) ----
# Per-engine transfer config (source of truth for FA/mamba sizing).
self._transfer_configs: dict[str, HeteroTPTransferConfig] = {}
# NOTE (ZhanqiuHu): _mamba_phys_ratio MUST be per-engine.
# compute_mamba_phys_ratio = ceil((conv_bytes + ssm_bytes) / block_len)
# NOTE (ZhanqiuHu): _physical_blocks_per_logical MUST be per-engine.
# physical_blocks_per_logical = ceil((conv_bytes + ssm_bytes) / block_len)
# where conv/ssm bytes are per-TP-rank (dimension-sharded). With
# heterogeneous TP the per-rank sizes differ, so the ratio differs:
# e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261.
self._mamba_phys_ratio: dict[EngineId, int] = {}
self._physical_blocks_per_logical: dict[EngineId, int] = {}
# In progress transfers.
# [req_id -> list[handle]]
......@@ -322,10 +320,8 @@ class NixlConnectorWorker:
# lazy initialized in register_kv_caches
self.compat_hash: str | None = None
self.kv_topo: TpKVTopology | None = None
self.transfer_topo: TransferTopology | None = None
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
# With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks.
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
......@@ -355,7 +351,6 @@ class NixlConnectorWorker:
self.block_size // kernel_block_size
)
self.block_size = kernel_block_size
self._block_size[self.engine_id] = kernel_block_size
self.num_blocks *= self._physical_blocks_per_logical_kv_block
def _nixl_handshake(
......@@ -384,8 +379,8 @@ class NixlConnectorWorker:
# Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that with homogeneous TP,
# this happens to be the same single rank_i.
assert self.kv_topo is not None
p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
assert self.transfer_topo is not None
p_remote_ranks = self.transfer_topo.handshake_target_ranks(remote_tp_size)
remote_rank_to_agent_name = {}
path = make_zmq_path("tcp", host, port)
......@@ -649,11 +644,11 @@ class NixlConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
self.kv_topo = TpKVTopology(
self.transfer_topo = TransferTopology(
tp_rank=self.tp_rank,
tp_size=self.world_size,
block_size=self.block_size,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backends=self.attn_backends,
......@@ -664,7 +659,7 @@ class NixlConnectorWorker:
is_mamba=self._has_mamba,
)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks
)
if self.use_host_buffer:
......@@ -716,7 +711,7 @@ class NixlConnectorWorker:
if isinstance(layer_spec, UniformTypeKVCacheSpecs):
# MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs
layer_spec = layer_spec.kv_cache_specs[layer_name]
cache_list = self.kv_topo.get_transfer_cache_regions(
cache_list = self.transfer_topo.get_transfer_cache_regions(
cache_or_caches, layer_spec
)
# `layer_spec.page_size_bytes` only accounts for logical page_size, that is
......@@ -729,7 +724,7 @@ class NixlConnectorWorker:
)
# For when registering multiple tensors eg K/V in separate regions.
physical_page_size = physical_page_size // len(cache_list)
if self.kv_topo._cross_layers_blocks:
if self.transfer_topo._cross_layers_blocks:
# When cross-layers blocks are used, multiply by number of layers
physical_page_size = physical_page_size * len(
self.kv_cache_config.kv_cache_tensors
......@@ -793,7 +788,7 @@ class NixlConnectorWorker:
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data)
if self.kv_topo.is_kv_layout_blocks_first:
if self.transfer_topo.is_kv_layout_blocks_first:
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
# registerMem allowing faster descs queries. In order to be able to
......@@ -817,7 +812,7 @@ class NixlConnectorWorker:
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self._has_mamba:
self._mamba_phys_ratio[self.engine_id] = (
self._physical_blocks_per_logical[self.engine_id] = (
self._physical_blocks_per_logical_kv_block
)
logger.info(
......@@ -876,11 +871,13 @@ class NixlConnectorWorker:
conv_offsets = self._conv_decomp.local_conv_offsets
conv_size, ssm_size = self._mamba_ssm_size
num_blocks = self._logical_num_blocks * block_size_ratio
phys_ratio = self._physical_blocks_per_logical_kv_block
physical_per_logical = self._physical_blocks_per_logical_kv_block
result: list[tuple[int, int, int]] = []
for i, base_addr in enumerate(base_addresses):
page_stride = self.block_len_per_layer[i] // block_size_ratio * phys_ratio
page_stride = (
self.block_len_per_layer[i] // block_size_ratio * physical_per_logical
)
for off, sz in conv_offsets:
for blk in range(num_blocks):
result.append(
......@@ -900,14 +897,14 @@ class NixlConnectorWorker:
def _build_fa_remote_for_mamba(
self,
nixl_agent_meta: NixlAgentMetadata,
transfer_cfg: HeteroTPTransferConfig,
block_size_ratio: int,
kv_topo: TpKVTopology,
transfer_topo: TransferTopology,
remote_engine_id: EngineId,
) -> list[tuple[int, int, int]]:
"""Build remote FA descriptors for mamba models.
Uses transfer_cfg for GQA-aware FA divisor and head-based rank offset
instead of the standard uniform tp_ratio split.
Uses TransferTopology for GQA-aware FA divisor and head-based rank
offset instead of the standard uniform tp_ratio split.
"""
assert block_size_ratio == 1, (
"Mamba 3-read transfer with block_size_ratio != 1 is not tested. "
......@@ -915,7 +912,9 @@ class NixlConnectorWorker:
)
# TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA
# hetero-TP logic stabilizes.
tp_ratio = transfer_cfg.tp_ratio
mamba_info = transfer_topo.get_engine_info(remote_engine_id)
assert isinstance(mamba_info, MambaEngineTransferInfo)
tp_ratio = transfer_topo.tp_ratio(mamba_info.remote_tp_size)
result: list[tuple[int, int, int]] = []
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
local_block_len = self.get_backend_aware_kv_block_len(
......@@ -926,9 +925,11 @@ class NixlConnectorWorker:
local_block_len = remote_kv_block_len
if tp_ratio < 0 and not self.use_mla:
local_block_len = local_block_len // transfer_cfg.physical_fa_num_reads
local_block_len = local_block_len // mamba_info.remote_num_fa_reads
rank_offset = transfer_cfg.fa_rank_offset(remote_kv_block_len)
rank_offset = transfer_topo.fa_rank_offset(
remote_engine_id, remote_kv_block_len
)
num_blocks = nixl_agent_meta.num_blocks
page_size = nixl_agent_meta.block_lens[i]
......@@ -937,12 +938,12 @@ class NixlConnectorWorker:
addr = base_addr + block_offset + rank_offset
result.append((addr, local_block_len, nixl_agent_meta.device_id))
if kv_topo.is_kv_layout_blocks_first:
if transfer_topo.is_kv_layout_blocks_first:
second_split = self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=False, mamba_view=False
)
if tp_ratio < 0 and not self.use_mla:
second_split = second_split // transfer_cfg.physical_fa_num_reads
second_split = second_split // mamba_info.remote_num_fa_reads
for block_id in range(num_blocks):
block_offset = block_id * page_size
addr = base_addr + block_offset + rank_offset
......@@ -981,15 +982,17 @@ class NixlConnectorWorker:
conv_offsets = [(0, xb_p), (xb_p, bb_p), (xb_p + bb_p, bb_p)]
ssm_read_size = nixl_agent_meta.ssm_sizes[1]
remote_ratio = self._mamba_phys_ratio[nixl_agent_meta.engine_id]
num_blocks = nixl_agent_meta.num_blocks // remote_ratio
remote_physical_per_logical = self._physical_blocks_per_logical[
nixl_agent_meta.engine_id
]
num_blocks = nixl_agent_meta.num_blocks // remote_physical_per_logical
device_id = nixl_agent_meta.device_id
result: list[tuple[int, int, int]] = []
# NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case
# block lengths vary across layers (e.g. MLA).
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
page_stride = nixl_agent_meta.block_lens[i] * remote_ratio
page_stride = nixl_agent_meta.block_lens[i] * remote_physical_per_logical
for off, sz in conv_offsets:
for blk in range(num_blocks):
result.append((base_addr + blk * page_stride + off, sz, device_id))
......@@ -1019,8 +1022,8 @@ class NixlConnectorWorker:
register another local_xfer_handler using remote block len to ensure
data copy correctness.
"""
assert self.kv_topo is not None
kv_topo = self.kv_topo
assert self.transfer_topo is not None
transfer_topo = self.transfer_topo
block_size_ratio = self.block_size // block_size
blocks_data: list[tuple[int, int, int]] = []
......@@ -1051,7 +1054,7 @@ class NixlConnectorWorker:
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.device_id))
if kv_topo.is_kv_layout_blocks_first:
if transfer_topo.is_kv_layout_blocks_first:
second_split = self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=False, mamba_view=mamba
)
......@@ -1153,11 +1156,29 @@ class NixlConnectorWorker:
)
return self._remote_agents[engine_id][remote_tp_rank]
### Register remote agent metadata
if engine_id not in self._tp_size:
self._tp_size[engine_id] = remote_tp_size
if engine_id not in self._block_size:
self._block_size[engine_id] = nixl_agent_meta.block_size
### Register remote engine in TransferTopology (idempotent).
assert self.transfer_topo is not None
transfer_topo = self.transfer_topo
physical_blocks_per_logical = (
compute_physical_blocks_per_logical(
nixl_agent_meta.ssm_sizes,
nixl_agent_meta.block_lens[0],
)
if self._has_mamba
else 1
)
transfer_topo.register_remote_engine(
remote_engine_id=engine_id,
remote_tp_size=remote_tp_size,
remote_block_size=nixl_agent_meta.block_size,
remote_block_len=nixl_agent_meta.block_lens[0],
remote_physical_blocks_per_logical=physical_blocks_per_logical,
local_block_len=self.block_len_per_layer[0],
)
if self._has_mamba and engine_id not in self._physical_blocks_per_logical:
self._physical_blocks_per_logical[engine_id] = physical_blocks_per_logical
logger.info("Transfer plan: %s", transfer_topo.describe(engine_id))
remote_agent_name = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata
......@@ -1170,16 +1191,10 @@ class NixlConnectorWorker:
# remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
# local origin:| 0| 1| 8| 12|
# local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
assert self.kv_topo is not None
kv_topo = self.kv_topo
block_size_ratio = kv_topo.block_size_ratio_from_engine_id(engine_id)
block_size_ratio = transfer_topo.block_size_ratio(nixl_agent_meta.block_size)
if engine_id not in self.dst_num_blocks:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
if self._has_mamba:
self._mamba_phys_ratio[engine_id] = compute_mamba_phys_ratio(
nixl_agent_meta.ssm_sizes, nixl_agent_meta.block_lens[0]
)
# Keep track of remote agent kv caches base addresses.
self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
......@@ -1189,28 +1204,13 @@ class NixlConnectorWorker:
# This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
# this is the ratio between the two sizes.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)
tp_ratio = transfer_topo.tp_ratio(remote_tp_size)
# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote = (
not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0
not transfer_topo.replicates_kv_cache(engine_id) and tp_ratio > 0
)
# Create transfer config (single source of truth for descriptor sizes).
if self._has_mamba and engine_id not in self._transfer_configs:
self._transfer_configs[engine_id] = HeteroTPTransferConfig(
tp_ratio=tp_ratio,
K=kv_topo.total_num_kv_heads,
d_tp=self.world_size,
p_tp=remote_tp_size,
d_rank=self.tp_rank,
use_mla=self.use_mla,
d_block_len=self.block_len_per_layer[0],
p_block_len=nixl_agent_meta.block_lens[0],
is_blocks_first=kv_topo.is_kv_layout_blocks_first,
)
logger.info("Created %s", self._transfer_configs[engine_id].describe())
logger.debug(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
engine_id,
......@@ -1231,12 +1231,10 @@ class NixlConnectorWorker:
self.src_xfer_handles_by_tp_ratio[tp_ratio] = []
if self._has_mamba:
transfer_cfg = self._transfer_configs.get(engine_id)
assert transfer_cfg is not None
if transfer_cfg.needs_split_handles:
if transfer_topo.needs_split_handles(engine_id):
# Mamba-HMA: FA and Mamba use different split factors.
for handle_data in transfer_cfg.compute_split_handle_data(
self.src_blocks_data, self.num_descs, abs_tp
for handle_data in transfer_topo.compute_split_handle_data(
engine_id, self.src_blocks_data, self.num_descs, abs_tp
):
descs = self.nixl_wrapper.get_xfer_descs(
handle_data, self.nixl_memory_type
......@@ -1247,12 +1245,8 @@ class NixlConnectorWorker:
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
logger.info(
"Mamba-HMA split handles: targets=%s, fa_reads=%s, "
"fa_entry=%s, mamba_reads=%s, num_descs=%s",
transfer_cfg.transfer_targets,
transfer_cfg.physical_fa_num_reads,
transfer_cfg.fa_entry_size,
transfer_cfg.mamba_num_reads,
"Mamba-HMA split handles: %s, num_descs=%s",
transfer_topo.describe(engine_id),
self.num_descs,
)
else:
......@@ -1321,7 +1315,7 @@ class NixlConnectorWorker:
(addr, local_block_len, nixl_agent_meta.device_id)
)
if kv_topo.is_kv_layout_blocks_first:
if transfer_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting.
second_split = self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=False, mamba_view=mamba
......@@ -1360,14 +1354,12 @@ class NixlConnectorWorker:
engine_id,
remote_tp_rank,
)
transfer_cfg = self._transfer_configs.get(engine_id)
assert transfer_cfg is not None
blocks_data.extend(
self._build_fa_remote_for_mamba(
nixl_agent_meta,
transfer_cfg,
block_size_ratio,
kv_topo,
transfer_topo,
engine_id,
)
)
blocks_data.extend(
......@@ -1403,18 +1395,19 @@ class NixlConnectorWorker:
"""
remote_engine_id = nixl_agent_meta.engine_id
assert self._tp_size[remote_engine_id] == remote_tp_size
assert self.kv_topo is not None
assert self.transfer_topo is not None
remote_info = self.transfer_topo.get_engine_info(remote_engine_id)
assert remote_info.remote_tp_size == remote_tp_size
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
remote_engine_id
tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size)
block_size_ratio = self.transfer_topo.block_size_ratio(
nixl_agent_meta.block_size
)
# tp_size > num_kv_heads (replicated KV) with P_TP > D_TP not supported for non-mamba.
# Mamba models can have replicated FA KV with tp_ratio < 0.
if not self._has_mamba:
assert not (
tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id)
tp_ratio < 0 and self.transfer_topo.is_kv_replicated(remote_engine_id)
)
if self._is_hma_required:
......@@ -1467,7 +1460,7 @@ class NixlConnectorWorker:
if (
abs(tp_ratio) != 1
and not self.use_mla
and not self.kv_topo.is_kv_replicated(remote_engine_id)
and not self.transfer_topo.is_kv_replicated(remote_engine_id)
and kv_cache_layout != "HND"
and not self.enable_permute_local_kv
):
......@@ -1478,7 +1471,7 @@ class NixlConnectorWorker:
# Block len can only vary across layers when using MLA.
remote_block_len = nixl_agent_meta.block_lens[0]
if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
if self.use_mla or self.transfer_topo.is_kv_replicated(remote_engine_id):
# With replicated KV cache, only the number of blocks can differ.
# TODO (ZhanqiuHu): For mamba models, validate FA and mamba
# block_lens separately.
......@@ -1594,7 +1587,7 @@ class NixlConnectorWorker:
if len(self.device_kv_caches) == 0:
return
assert block_size_ratio >= 1, "Only nP < nD supported currently."
assert self.kv_topo is not None
assert self.transfer_topo is not None
if self.enable_permute_local_kv and block_size_ratio > 1:
logger.debug(
"Post-processing device kv cache on receive by converting "
......@@ -1614,7 +1607,7 @@ class NixlConnectorWorker:
block_size_ratio,
)
split_k_and_v = self.kv_topo.split_k_and_v
split_k_and_v = self.transfer_topo.split_k_and_v
for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
......@@ -1661,7 +1654,7 @@ class NixlConnectorWorker:
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
assert self.kv_topo is not None
assert self.transfer_topo is not None
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
......@@ -1689,8 +1682,9 @@ class NixlConnectorWorker:
self.sync_recved_kv_to_device(req_id, meta)
# post processing for heteroblocksize
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
meta.remote.engine_id
remote_info = self.transfer_topo.get_engine_info(meta.remote.engine_id)
block_size_ratio = self.transfer_topo.block_size_ratio(
remote_info.remote_block_size
)
if not self.use_mla and (
block_size_ratio > 1 or self.enable_permute_local_kv
......@@ -1741,7 +1735,7 @@ class NixlConnectorWorker:
are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling.
"""
assert self.kv_topo is not None
assert self.transfer_topo is not None
notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
......@@ -1760,7 +1754,7 @@ class NixlConnectorWorker:
# NOTE: `tp_ratio` is the opposite when swapping local<>remote
n_consumers = int(tp_size)
tp_ratio = self.kv_topo.tp_ratio(n_consumers)
tp_ratio = self.transfer_topo.tp_ratio(n_consumers)
# Number of reads *per producer* to wait for.
# When remote D TP > local P TP we expect `tp_ratio` reads.
......@@ -1901,17 +1895,17 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None and self.kv_topo is not None
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
meta.remote.engine_id
)
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id)
assert meta.remote is not None and self.transfer_topo is not None
engine_id = meta.remote.engine_id
remote_ranks = self.transfer_topo.target_remote_ranks(engine_id)
remote_info = self.transfer_topo.get_engine_info(engine_id)
tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size)
if self._has_mamba:
# Expand remote logical → kernel block IDs.
meta.remote.block_ids = self._logical_to_remote_kernel_block_ids(
meta.remote.block_ids,
self._mamba_phys_ratio[meta.remote.engine_id],
self._physical_blocks_per_logical[meta.remote.engine_id],
)
else:
meta.remote.block_ids = self._logical_to_kernel_block_ids(
......@@ -1924,7 +1918,7 @@ class NixlConnectorWorker:
# the first remote rank (cache is duplicated).
break
remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id]
remote_block_size = remote_info.remote_block_size
logger.debug(
"Remote agent %s available, calling _read_blocks"
" on remote rank %s with remote block size %s for req %s",
......@@ -1955,9 +1949,8 @@ class NixlConnectorWorker:
remote_ids: BlockIds = meta.remote.block_ids
if self._has_mamba:
# Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets.
transfer_cfg = self._transfer_configs.get(meta.remote.engine_id)
assert transfer_cfg is not None
local_ids, remote_ids = transfer_cfg.filter_block_ids_for_rank(
local_ids, remote_ids = self.transfer_topo.filter_block_ids_for_rank(
engine_id,
remote_rank,
local_ids,
remote_ids,
......@@ -1999,8 +1992,11 @@ class NixlConnectorWorker:
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
assert self.kv_topo is not None
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
assert self.transfer_topo is not None
remote_info = self.transfer_topo.get_engine_info(dst_engine_id)
block_size_ratio = self.transfer_topo.block_size_ratio(
remote_info.remote_block_size
)
if block_size_ratio > 1:
# TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups.
assert not self._is_hma_required
......@@ -2190,8 +2186,8 @@ class NixlConnectorWorker:
# This is like having two "low-level views" of the same storage.
# `num_fa_descs` offset must be computed per-engine since P and D can
# have different num_blocks (and thus different FA descs counts).
ratio = self._mamba_phys_ratio[engine_id]
logical_blocks = num_blocks // ratio
physical_per_logical = self._physical_blocks_per_logical[engine_id]
logical_blocks = num_blocks // physical_per_logical
num_fa_descs = self.num_regions * num_blocks
# 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm).
mamba_region_ids = np.arange(len(self.block_len_per_layer) * 4)[:, None]
......@@ -2234,21 +2230,22 @@ class NixlConnectorWorker:
]
def _logical_to_remote_kernel_block_ids(
self, block_ids: BlockIds, remote_ratio: int
self, block_ids: BlockIds, remote_physical_per_logical: int
) -> BlockIds:
"""Map logical block IDs to physical kernel block IDs on the remote.
Args:
block_ids: per-group lists of logical block IDs.
remote_ratio: remote engine's physical blocks per logical block.
remote_physical_per_logical: remote engine's physical blocks
per logical block.
Returns:
Same structure with FA groups expanded (each logical block L
becomes kernel blocks [L*remote_ratio .. L*remote_ratio +
local_ratio - 1]). Mamba groups are passed through unchanged.
becomes kernel blocks [L*ratio .. L*ratio + local_ratio - 1]).
Mamba groups are passed through unchanged.
"""
local_ratio = self._physical_blocks_per_logical_kv_block
if remote_ratio == 1:
if remote_physical_per_logical == 1:
return block_ids
local_arange = np.arange(local_ratio).reshape(1, -1)
group_specs = self.kv_cache_config.kv_cache_groups
......@@ -2256,7 +2253,7 @@ class NixlConnectorWorker:
for i, group in enumerate(block_ids):
if not isinstance(group_specs[i].kv_cache_spec, MambaSpec):
arr = np.array(group).reshape(-1, 1)
expanded = (arr * remote_ratio + local_arange).flatten()
expanded = (arr * remote_physical_per_logical + local_arange).flatten()
result.append(expanded.tolist())
else:
# Mamba blocks are 1:1 logical-to-physical (no expansion).
......@@ -2296,8 +2293,8 @@ class NixlConnectorWorker:
+-------------------+ +--------------------+
|1st_split-2nd_split| |1st_split-2nd_split |
"""
assert self.kv_topo is not None
if self.kv_topo.is_kv_layout_blocks_first:
assert self.transfer_topo is not None
if self.transfer_topo.is_kv_layout_blocks_first:
# For indexing only half (either just the K or V part).
if mamba_view:
# NOTE (NickLucche) Mamba Opt: this is already skipping the padding so
......
......@@ -151,7 +151,9 @@ def derive_mamba_conv_split(
)
def compute_mamba_phys_ratio(ssm_sizes: tuple[int, ...], block_len: int) -> int:
def compute_physical_blocks_per_logical(
ssm_sizes: tuple[int, ...], block_len: int
) -> int:
"""Derive _physical_blocks_per_logical_kv_block from remote metadata.
The remote engine's ratio is not sent directly in the handshake, so we
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment