"docs/vscode:/vscode.git/clone" did not exist on "40218a82bad0bc772d75551a8009799f1a001db7"
Unverified commit cc3993b0, authored by zhanqiuhu, committed by GitHub
Browse files

nixl refactor [2/N]: unify TpKVTopology + HeteroTPTransferConfig into TransferTopology (#39529)


Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
parent 50dd4cb4
...@@ -631,7 +631,7 @@ def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes(): ...@@ -631,7 +631,7 @@ def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes():
mock_thread.return_value.is_alive.return_value = False mock_thread.return_value.is_alive.return_value = False
worker.use_mla = True worker.use_mla = True
worker.kv_topo.is_mla = True worker.transfer_topo.is_mla = True
# MLA cache tensor: shape[-2] is the block size. # MLA cache tensor: shape[-2] is the block size.
mla_cache = torch.zeros((2, 16, 96), dtype=torch.float16) mla_cache = torch.zeros((2, 16, 96), dtype=torch.float16)
...@@ -692,9 +692,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): ...@@ -692,9 +692,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
# Override TP rank/size to simulate P TP=2 # Override TP rank/size to simulate P TP=2
prefill_worker.tp_rank = P_TP_RANK prefill_worker.tp_rank = P_TP_RANK
prefill_worker.tp_size = P_TP_SIZE prefill_worker.tp_size = P_TP_SIZE
# Update shared dict so kv_topo sees correct TP size
prefill_worker._tp_size[prefill_worker.engine_id] = P_TP_SIZE prefill_worker._tp_size[prefill_worker.engine_id] = P_TP_SIZE
prefill_worker.kv_topo.tp_rank = P_TP_RANK prefill_worker.transfer_topo.tp_rank = P_TP_RANK
prefill_worker.transfer_topo.tp_size = P_TP_SIZE
prefill_worker.kv_caches_base_addr = [0x1000] prefill_worker.kv_caches_base_addr = [0x1000]
prefill_worker.block_len_per_layer = [local_block_len] prefill_worker.block_len_per_layer = [local_block_len]
...@@ -714,7 +714,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): ...@@ -714,7 +714,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
send_meta.ready.set() send_meta.ready.set()
# Compute target D ranks using the production code path # Compute target D ranks using the production code path
target_d_ranks = prefill_worker.kv_topo.get_target_remote_ranks(d_tp_size) target_d_ranks = prefill_worker.transfer_topo.handshake_target_ranks(d_tp_size)
mock_socket = AsyncMock(spec=zmq.asyncio.Socket) mock_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_socket.send_multipart = AsyncMock() mock_socket.send_multipart = AsyncMock()
......
...@@ -21,7 +21,7 @@ from vllm import LLM ...@@ -21,7 +21,7 @@ from vllm import LLM
from vllm.config import KVTransferConfig, set_current_vllm_config from vllm.config import KVTransferConfig, set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.distributed.kv_transfer.kv_connector.utils import (
KVOutputAggregator, KVOutputAggregator,
TpKVTopology, TransferTopology,
get_current_attn_backend, get_current_attn_backend,
) )
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl from vllm.distributed.kv_transfer.kv_connector.v1 import nixl
...@@ -463,19 +463,20 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -463,19 +463,20 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
test_shape = self.attn_backends[0].get_kv_cache_shape( test_shape = self.attn_backends[0].get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
) )
self.kv_topo = TpKVTopology( self.transfer_topo = TransferTopology(
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
tp_size=self.world_size,
block_size=self.block_size,
engine_id=self.engine_id, engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla, is_mla=self.use_mla,
is_mamba=False,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(), total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backends=self.attn_backends, attn_backends=self.attn_backends,
tensor_shape=test_shape, tensor_shape=test_shape,
) )
self.compat_hash = compute_nixl_compatibility_hash( self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks
) )
def _nixl_handshake( def _nixl_handshake(
...@@ -496,7 +497,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -496,7 +497,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
# Adjust remote block length metadata to satisfy heterogeneous TP # Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation. # invariants enforced during handshake validation.
remote_block_lens = list(self.block_len_per_layer) remote_block_lens = list(self.block_len_per_layer)
tp_ratio = self.kv_topo.tp_ratio(remote_tp_size) tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size)
if remote_tp_size > self.world_size: if remote_tp_size > self.world_size:
# P TP > D TP case, block_len of remote is smaller # P TP > D TP case, block_len of remote is smaller
remote_block_lens = [ remote_block_lens = [
...@@ -731,8 +732,9 @@ class TestNixlHandshake: ...@@ -731,8 +732,9 @@ class TestNixlHandshake:
assert set(remote_agents.keys()) == set(range(tp_ratio)) assert set(remote_agents.keys()) == set(range(tp_ratio))
remote_engine_id = worker.REMOTE_ENGINE_ID remote_engine_id = worker.REMOTE_ENGINE_ID
assert worker._tp_size[remote_engine_id] == remote_tp_size remote_info = worker.transfer_topo.get_engine_info(remote_engine_id)
assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id) assert remote_info.remote_tp_size == remote_tp_size
assert -tp_ratio == worker.transfer_topo.tp_ratio(remote_tp_size)
# ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks # ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks
assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
...@@ -796,7 +798,7 @@ class TestNixlHandshake: ...@@ -796,7 +798,7 @@ class TestNixlHandshake:
(conn_p0.connector_worker, conn_p1.connector_worker) (conn_p0.connector_worker, conn_p1.connector_worker)
): ):
worker.world_size = p_tp_size worker.world_size = p_tp_size
worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size} worker.transfer_topo.tp_size = p_tp_size
worker.tp_rank = rank worker.tp_rank = rank
worker.use_mla = True worker.use_mla = True
...@@ -2337,7 +2339,7 @@ def test_compatibility_hash_validation( ...@@ -2337,7 +2339,7 @@ def test_compatibility_hash_validation(
remote_hash = compute_nixl_compatibility_hash( remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config, remote_vllm_config,
decode_worker.backend_name, decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks, decode_worker.transfer_topo.cross_layers_blocks,
) )
prefill_block_size = config_overrides.get("block_size", 16) prefill_block_size = config_overrides.get("block_size", 16)
...@@ -2424,12 +2426,13 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) ...@@ -2424,12 +2426,13 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
test_shape = backend.get_kv_cache_shape( test_shape = backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
) )
decode_worker.kv_topo = TpKVTopology( decode_worker.transfer_topo = TransferTopology(
tp_rank=decode_worker.tp_rank, tp_rank=decode_worker.tp_rank,
tp_size=decode_worker.world_size,
block_size=decode_worker.block_size,
engine_id=decode_worker.engine_id, engine_id=decode_worker.engine_id,
remote_tp_size=decode_worker._tp_size, # shared state
remote_block_size=decode_worker._block_size, # shared state
is_mla=decode_worker.use_mla, is_mla=decode_worker.use_mla,
is_mamba=False,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(), total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backends=[backend], attn_backends=[backend],
tensor_shape=test_shape, tensor_shape=test_shape,
...@@ -2438,7 +2441,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) ...@@ -2438,7 +2441,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
decode_worker.compat_hash = compute_nixl_compatibility_hash( decode_worker.compat_hash = compute_nixl_compatibility_hash(
decode_worker.vllm_config, decode_worker.vllm_config,
decode_worker.backend_name, decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks, decode_worker.transfer_topo.cross_layers_blocks,
) )
if error_scenario == "handshake_decode_error": if error_scenario == "handshake_decode_error":
......
...@@ -152,13 +152,14 @@ def test_read_blocks_for_req_expands_remote_ids( ...@@ -152,13 +152,14 @@ def test_read_blocks_for_req_expands_remote_ids(
remote_engine_id = "remote-engine" remote_engine_id = "remote-engine"
if has_mamba: if has_mamba:
worker._mamba_phys_ratio = {remote_engine_id: remote_ratio} worker._physical_blocks_per_logical = {remote_engine_id: remote_ratio}
# Mock kv_topo: empty remote ranks skips the transfer machinery entirely, # Mock transfer_topo: empty remote ranks skips the transfer machinery
# isolating the block-ID expansion logic. # entirely, isolating the block-ID expansion logic.
worker.kv_topo = MagicMock() worker.transfer_topo = MagicMock()
worker.kv_topo.get_target_remote_ranks_from_engine_id.return_value = [] worker.transfer_topo.target_remote_ranks.return_value = []
worker.kv_topo.tp_ratio_from_engine_id.return_value = 1 worker.transfer_topo.get_engine_info.return_value = MagicMock(remote_tp_size=1)
worker.transfer_topo.tp_ratio.return_value = 1
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv( metadata.add_new_req_to_recv(
...@@ -317,7 +318,7 @@ def test_get_block_descs_ids_hybrid_ssm(): ...@@ -317,7 +318,7 @@ def test_get_block_descs_ids_hybrid_ssm():
worker._has_mamba = True worker._has_mamba = True
worker._is_mamba_group = [False, True] worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = 1 worker._physical_blocks_per_logical_kv_block = 1
worker._mamba_phys_ratio = {engine_id: 1} worker._physical_blocks_per_logical = {engine_id: 1}
worker.block_len_per_layer = [100] worker.block_len_per_layer = [100]
# num_descs = num_regions * num_blocks (no blocks_first doubling) # num_descs = num_regions * num_blocks (no blocks_first doubling)
worker.num_descs = 2 * num_blocks worker.num_descs = 2 * num_blocks
...@@ -355,7 +356,7 @@ def test_get_block_descs_ids_kernel_block_mismatch(): ...@@ -355,7 +356,7 @@ def test_get_block_descs_ids_kernel_block_mismatch():
worker._has_mamba = True worker._has_mamba = True
worker._is_mamba_group = [False, True] worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = ratio worker._physical_blocks_per_logical_kv_block = ratio
worker._mamba_phys_ratio = {engine_id: ratio} worker._physical_blocks_per_logical = {engine_id: ratio}
worker.block_len_per_layer = [100] worker.block_len_per_layer = [100]
worker.num_descs = 2 * num_blocks # 800 worker.num_descs = 2 * num_blocks # 800
...@@ -532,15 +533,15 @@ def test_has_mamba_init( ...@@ -532,15 +533,15 @@ def test_has_mamba_init(
((9216, 524288), 4096, 131), ((9216, 524288), 4096, 131),
], ],
) )
def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio): def test_compute_physical_blocks_per_logical(ssm_sizes, block_len, expected_ratio):
"""Verify that compute_mamba_phys_ratio is TP-dependent. """Verify that compute_physical_blocks_per_logical is TP-dependent.
With dimension-sharded Mamba state, the ratio differs across TP sizes With dimension-sharded Mamba state, the ratio differs across TP sizes
(e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why (e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why
_mamba_phys_ratio must be stored per-engine. _physical_blocks_per_logical must be stored per-engine.
""" """
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
compute_mamba_phys_ratio, compute_physical_blocks_per_logical,
) )
assert compute_mamba_phys_ratio(ssm_sizes, block_len) == expected_ratio assert compute_physical_blocks_per_logical(ssm_sizes, block_len) == expected_ratio
...@@ -5,7 +5,7 @@ KV cache helper for store. ...@@ -5,7 +5,7 @@ KV cache helper for store.
""" """
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal, cast
import torch import torch
...@@ -319,31 +319,139 @@ def yield_req_data( ...@@ -319,31 +319,139 @@ def yield_req_data(
) )
@dataclass def get_current_attn_backends(
class TpKVTopology: vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> list[type[AttentionBackend]]:
"""Get all distinct attention backends for the given layers.
Args:
vllm_config: The current vLLM configuration.
layer_names: Optional list of layer names to scope the lookup.
When None, all attention layers are considered.
Returns:
Deduplicated list of attention backend classes.
"""
layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
if layers:
seen: dict[str, type[AttentionBackend]] = {}
for layer in layers.values():
backend = layer.get_attn_backend()
seen[backend.full_cls_name()] = backend
return list(seen.values())
# Fallback for tests, when static_forward_context is empty.
logger.debug(
"No layers found in the vLLM config. Falling back to default attention backend."
)
from vllm.v1.attention.selector import get_attn_backend
return [
get_attn_backend(
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
use_mla=vllm_config.model_config.use_mla,
)
]
def get_current_attn_backend(
    vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> type[AttentionBackend]:
    """Return the first entry produced by ``get_current_attn_backends``.

    Args:
        vllm_config: Forwarded unchanged to ``get_current_attn_backends``.
        layer_names: Forwarded unchanged to ``get_current_attn_backends``.

    Returns:
        Element 0 of the list returned by ``get_current_attn_backends``.
    """
    backends = get_current_attn_backends(vllm_config, layer_names)
    return backends[0]
# ---- Per-engine transfer info ----
@dataclass(frozen=True)
class EngineTransferInfo:
"""Common per-remote-engine transfer state, computed at handshake.
Stored per ``engine_id`` inside ``TransferTopology._engines``.
""" """
Helper class for tensor parallel and KV topology information for
mapping between local and remote TP workers. remote_tp_size: int
remote_block_len: int
"""Block length (bytes)"""
remote_block_size: int
"""Tokens per block."""
remote_physical_blocks_per_logical: int
"""Physical blocks per logical block."""
@dataclass(frozen=True)
class MambaEngineTransferInfo(EngineTransferInfo):
"""Extends ``EngineTransferInfo`` with Mamba-hybrid transfer geometry.
For hybrid SSM+Attention models, FA and Mamba layers may require
different numbers of reads from different remote ranks. This
dataclass captures that per-engine transfer plan.
""" """
remote_fa_source_ranks: tuple[int, ...]
"""Remote ranks carrying unique FA heads for this local rank."""
remote_all_source_ranks: tuple[int, ...]
"""All remote ranks this local rank reads from (FA + Mamba)."""
remote_num_fa_reads: int
"""Number of distinct remote ranks needed for FA data."""
remote_num_mamba_reads: int
"""Number of distinct remote ranks needed for Mamba data."""
remote_fa_descriptor_bytes: int
"""Byte size of one FA K (or V) descriptor entry."""
is_remote_replicated: bool
"""Whether the remote engine has replicated KV heads
(remote_tp_size > total_num_kv_heads)."""
remote_physical_heads: int
"""Physical KV heads stored per remote rank."""
# ---- Transfer topology ----
@dataclass
class TransferTopology:
"""Single source of truth for local TP identity and per-engine remote info."""
tp_rank: int tp_rank: int
remote_tp_size: dict[EngineId, int] tp_size: int
block_size: int
engine_id: EngineId
is_mla: bool is_mla: bool
is_mamba: bool
total_num_kv_heads: int total_num_kv_heads: int
attn_backends: list[type[AttentionBackend]] attn_backends: list[type[AttentionBackend]]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
tensor_shape: torch.Size | None = None tensor_shape: torch.Size | None = None
is_mamba: bool = False
def __post_init__(self): def __post_init__(self):
self.local_physical_heads = max(1, self.total_num_kv_heads // self.tp_size)
self._engines: dict[EngineId, EngineTransferInfo] = {}
self._fa_source_sets: dict[EngineId, frozenset[int]] = {}
self._fa_source_indices: dict[EngineId, dict[int, int]] = {}
# Figure out whether the first dimension of the cache is K/V # Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly. # or num_blocks.
attn_backend = self.attn_backends[0] attn_backend = self.attn_backends[0]
if not self.is_mamba: if not self.is_mamba:
_MOCK_BLOCK_SIZE = 16 _MOCK_BLOCK_SIZE = 16
kv_cache_shape: tuple[int, ...] = attn_backend.get_kv_cache_shape( kv_cache_shape: tuple[int, ...] = attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1 num_blocks=1,
block_size=_MOCK_BLOCK_SIZE,
num_kv_heads=1,
head_size=1,
) )
logger.debug("Test kv_cache_shape: %s", kv_cache_shape) logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
...@@ -358,11 +466,9 @@ class TpKVTopology: ...@@ -358,11 +466,9 @@ class TpKVTopology:
self._cross_layers_blocks = ( self._cross_layers_blocks = (
len(self.tensor_shape) == len(kv_cache_shape) + 1 len(self.tensor_shape) == len(kv_cache_shape) + 1
) )
self.tensor_shape: torch.Size
if self._cross_layers_blocks: if self._cross_layers_blocks:
logger.debug("Using cross-layer KV cache") logger.debug("Using cross-layer KV cache")
# prepend layers dimension
_MOCK_NUM_LAYERS = 80 _MOCK_NUM_LAYERS = 80
kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
try: try:
...@@ -372,15 +478,81 @@ class TpKVTopology: ...@@ -372,15 +478,81 @@ class TpKVTopology:
except (AttributeError, NotImplementedError): except (AttributeError, NotImplementedError):
assert self.tensor_shape is not None assert self.tensor_shape is not None
kv_cache_stride_order = tuple(range(len(self.tensor_shape))) kv_cache_stride_order = tuple(range(len(self.tensor_shape)))
# In case of cross layers permute kv_cache_shape according to
# stride_order to retrieve physical position of block_size
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
# ============================================================
# Engine registration
# ============================================================
def register_remote_engine(
    self,
    remote_engine_id: EngineId,
    remote_tp_size: int,
    remote_block_size: int,
    remote_block_len: int,
    remote_physical_blocks_per_logical: int,
    *,
    local_block_len: int = 0,
) -> EngineTransferInfo:
    """Record transfer info for a remote engine and return it.

    The local engine must not be registered here; its identity comes
    from the constructor parameters. If ``remote_engine_id`` is already
    present in ``self._engines``, the stored entry is returned as-is and
    the arguments of this call are ignored.

    When ``self.is_mamba`` is set, the info is produced by
    ``self._build_mamba_info`` and the FA source ranks are additionally
    cached as a frozenset and as a rank -> position mapping.

    Args:
        remote_engine_id: Key under which the info is stored.
        remote_tp_size: Stored on the resulting info object.
        remote_block_size: Stored on the resulting info object.
        remote_block_len: Stored on the resulting info object.
        remote_physical_blocks_per_logical: Stored on the resulting
            info object.
        local_block_len: Local representative block_len (bytes). Only
            forwarded to ``_build_mamba_info``; unused otherwise.

    Returns:
        The ``EngineTransferInfo`` now associated with the engine.
    """
    assert remote_engine_id != self.engine_id, (
        f"Cannot register local engine {self.engine_id} as remote. "
        f"Local identity is set via __init__ params."
    )

    registry = self._engines
    # Re-registration is a no-op: hand back whatever was stored first.
    if remote_engine_id in registry:
        return registry[remote_engine_id]

    if not self.is_mamba:
        plain_info = EngineTransferInfo(
            remote_tp_size=remote_tp_size,
            remote_block_len=remote_block_len,
            remote_block_size=remote_block_size,
            remote_physical_blocks_per_logical=remote_physical_blocks_per_logical,
        )
        registry[remote_engine_id] = plain_info
        return plain_info

    mamba_info = self._build_mamba_info(
        remote_tp_size=remote_tp_size,
        remote_block_size=remote_block_size,
        remote_block_len=remote_block_len,
        remote_physical_blocks_per_logical=remote_physical_blocks_per_logical,
        local_block_len=local_block_len,
    )
    assert isinstance(mamba_info, MambaEngineTransferInfo)

    # Lookup caches keyed by engine: membership set and rank -> position.
    fa_ranks = mamba_info.remote_fa_source_ranks
    self._fa_source_sets[remote_engine_id] = frozenset(fa_ranks)
    self._fa_source_indices[remote_engine_id] = {
        rank: position for position, rank in enumerate(fa_ranks)
    }

    registry[remote_engine_id] = mamba_info
    return mamba_info
def get_engine_info(self, remote_engine_id: EngineId) -> EngineTransferInfo:
    """Look up the transfer info stored for ``remote_engine_id``.

    Raises:
        KeyError: If ``remote_engine_id`` has no entry in ``self._engines``.
    """
    registry = self._engines
    return registry[remote_engine_id]
# ============================================================
# Layout properties
# ============================================================
@property @property
def is_kv_layout_blocks_first(self) -> bool: def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first return self._is_kv_layout_blocks_first
@property
def cross_layers_blocks(self) -> bool:
return self._cross_layers_blocks
@property @property
def split_k_and_v(self) -> bool: def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present). # Whether to register regions for K and V separately (when present).
...@@ -388,29 +560,16 @@ class TpKVTopology: ...@@ -388,29 +560,16 @@ class TpKVTopology:
self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first
) )
@property # ============================================================
def tp_size(self) -> int: # Common methods
return self.remote_tp_size[self.engine_id] # ============================================================
@property def tp_ratio(self, remote_tp_size: int) -> int:
def block_size(self) -> int: """Calculate the tensor parallel ratio between local and remote TP.
return self.remote_block_size[self.engine_id]
@property Positive when local_tp >= remote_tp (local workers read from the
def cross_layers_blocks(self) -> bool: same remote worker in groups of size ``tp_ratio``). Negative when
return self._cross_layers_blocks remote_tp > local_tp (ratio is flipped).
def tp_ratio(
self,
remote_tp_size: int,
) -> int:
"""
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
""" """
if self.tp_size >= remote_tp_size: if self.tp_size >= remote_tp_size:
assert self.tp_size % remote_tp_size == 0, ( assert self.tp_size % remote_tp_size == 0, (
...@@ -418,78 +577,65 @@ class TpKVTopology: ...@@ -418,78 +577,65 @@ class TpKVTopology:
f"by remote tensor parallel size {remote_tp_size}." f"by remote tensor parallel size {remote_tp_size}."
) )
return self.tp_size // remote_tp_size return self.tp_size // remote_tp_size
assert remote_tp_size % self.tp_size == 0, ( assert remote_tp_size % self.tp_size == 0, (
f"Remote tensor parallel size {remote_tp_size} is not divisible " f"Remote tensor parallel size {remote_tp_size} is not divisible "
f"by local tensor parallel size {self.tp_size}." f"by local tensor parallel size {self.tp_size}."
) )
# P TP > D TP case, return the ratio as negative return -(remote_tp_size // self.tp_size)
return -remote_tp_size // self.tp_size
def block_size_ratio( def block_size_ratio(self, remote_block_size: int) -> int:
self, """Calculate the block size ratio between local and remote."""
remote_block_size: int,
) -> int:
"""
Calculate the block size ratio between local and remote TP.
"""
assert self.block_size % remote_block_size == 0, ( assert self.block_size % remote_block_size == 0, (
f"Local block size {self.block_size} is not divisible " f"Local block size {self.block_size} is not divisible "
f"by remote block size {remote_block_size} or vice versa." f"by remote block size {remote_block_size} or vice versa."
) )
return self.block_size // remote_block_size return self.block_size // remote_block_size
def tp_ratio_from_engine_id( def is_kv_replicated(self, remote_engine_id: EngineId) -> bool:
self, """Whether the KV cache is replicated across TP workers due to the
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
) -> int:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads. number of TP workers being greater than the number of KV heads.
When they are equal, each TP rank still owns one distinct KV head,
so this is not considered replication.
""" """
tp_size = self.remote_tp_size[engine_id] return self._engines[remote_engine_id].remote_tp_size > self.total_num_kv_heads
return tp_size > self.total_num_kv_heads
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split. # MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id) return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_ranks( @property
self, def local_replicates_kv_cache(self) -> bool:
remote_tp_size: int, """Whether the local engine's KV cache is replicated."""
) -> list[int]: return self.is_mla or self.tp_size > self.total_num_kv_heads
"""
Get the remote TP rank (on P) that the current local TP rank def handshake_target_ranks(self, remote_tp_size: int) -> list[int]:
(on D) will read from. When remote tp_size > local tp_size, we """Pre-registration: compute which remote TP ranks to handshake with.
read from multiple remote ranks.
Pure math based on local/remote TP sizes — does not require
the remote engine to be registered yet.
""" """
tp_ratio = self.tp_ratio(remote_tp_size) tp_ratio = self.tp_ratio(remote_tp_size)
if tp_ratio > 0: if tp_ratio > 0:
return [self.tp_rank // tp_ratio] return [self.tp_rank // tp_ratio]
abs_ratio = -tp_ratio
return [self.tp_rank * abs_ratio + i for i in range(abs_ratio)]
# P TP > D TP case, D reads from |tp_ratio| remote workers. def target_remote_ranks(self, remote_engine_id: EngineId) -> list[int]:
tp_ratio = -tp_ratio """Get the remote TP rank(s) that the current local TP rank will
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)] read from. When remote tp_size > local tp_size, reads from
multiple remote ranks.
def get_target_remote_ranks_from_engine_id( For Mamba models, returns the precomputed ``all_source_ranks``
self, (FA + Mamba union).
remote_engine_id: EngineId, """
) -> list[int]: info = self._engines[remote_engine_id]
remote_tp_size = self.remote_tp_size[remote_engine_id] if isinstance(info, MambaEngineTransferInfo):
return self.get_target_remote_ranks(remote_tp_size) return list(info.remote_all_source_ranks)
tp_ratio = self.tp_ratio(info.remote_tp_size)
if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
# remote TP > local TP: read from |tp_ratio| remote workers
abs_ratio = -tp_ratio
return [self.tp_rank * abs_ratio + i for i in range(abs_ratio)]
def get_transfer_cache_regions( def get_transfer_cache_regions(
self, cache: torch.Tensor, layer_spec: "KVCacheSpec" self, cache: torch.Tensor, layer_spec: "KVCacheSpec"
...@@ -498,331 +644,139 @@ class TpKVTopology: ...@@ -498,331 +644,139 @@ class TpKVTopology:
also accounting for hybrid SSM models specificities. also accounting for hybrid SSM models specificities.
""" """
if isinstance(layer_spec, MambaSpec): if isinstance(layer_spec, MambaSpec):
# Register the whole kv cache shared tensor, including SSM/Conv. This is # Register the whole kv cache shared tensor, including
# similar to FI with the difference that SSM/Conv have different sizes # SSM/Conv.
conv, ssm = cache conv, ssm = cache
return [conv] return [conv]
# Check may be hacky but it's matching `_update_hybrid_attention_mamba_layout`. # Check may be hacky but it's matching
# `_update_hybrid_attention_mamba_layout`.
if self.is_mamba and cache.shape[0] == 2: if self.is_mamba and cache.shape[0] == 2:
# When MAMBA is present, all backends are blocks first, so that blocks # When MAMBA is present, all backends are blocks first, so
# can be shared between attention layers and mamba layers. Runner # that blocks can be shared between attention layers and mamba
# `_update_hybrid_attention_mamba_layout` already adjusted strides # layers. Runner already adjusted strides for FlashAttn-like
# for FlashAttn-like backends so its num_blocks first. # backends so its num_blocks first.
# Swap [2<>num_blocks] dims to get required layout for hybrid SSM. # Swap [2<>num_blocks] dims for hybrid SSM layout.
cache = cache.transpose(0, 1) cache = cache.transpose(0, 1)
# Regular case: backends like FA register K/V in separate regions # Regular case: backends like FA register K/V in separate regions
return cache if self.split_k_and_v else [cache] return cache if self.split_k_and_v else [cache]
# ============================================================
# Mamba-specific methods
# ============================================================
# ---- Mamba-HMA hetero-TP transfer config ---- def should_skip_fa(self, remote_engine_id: EngineId, remote_rank: int) -> bool:
# """Whether to skip FA groups for this remote rank (mamba-only)."""
# Key insight: with hetero-TP (P_TP > D_TP), FA KV cache may be return remote_rank not in self._fa_source_sets[remote_engine_id]
# replicated across P ranks (when P_TP > num_kv_heads), but Mamba
# conv/SSM state is almost always uniquely sharded per P rank. So the
# number of P ranks D must read from can differ between FA and Mamba,
# and they must be handled separately.
def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range:
"""Physical KV head range stored in a rank's KV cache tensor.
When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank.
When ``tp_size > num_heads``: 1 physical head per rank. Heads are
distributed **contiguously** (matching vLLM's GQA weight partitioning):
consecutive ranks share a head before moving to the next one.
"""
if tp_size <= num_heads:
assert num_heads % tp_size == 0
per_rank = num_heads // tp_size
return range(rank * per_rank, (rank + 1) * per_rank)
else:
h = rank * num_heads // tp_size
return range(h, h + 1)
def _range_overlap(a: range, b: range) -> range:
start = max(a.start, b.start)
stop = min(a.stop, b.stop)
return range(start, max(start, stop))
@dataclass def fa_head_slot(self, remote_engine_id: EngineId, remote_rank: int) -> int:
class HeteroTPTransferConfig: """Index into local FA block for this remote rank's head data.
"""Precomputed transfer plan for one (D rank, P engine) pair.
Currently only instantiated for Mamba-HMA (hybrid SSM+Attention) models
where FA and mamba require different splitting factors. Could be extended
to other model types that need non-uniform hetero-TP transfer sizing.
All descriptor sizes are computed here. The guarantee is:
local_entry_size == remote_entry_size (for NIXL)
Attributes that start with ``fa_`` concern FlashAttention KV cache.
Attributes that start with ``mamba_`` concern Mamba conv/SSM state.
"""
# ---- Input parameters (from handshake) ----
tp_ratio: int
K: int # total_num_kv_heads (before TP sharding)
d_tp: int # D engine's tensor_parallel_size
p_tp: int # P engine's tensor_parallel_size
d_rank: int # this D worker's TP rank
use_mla: bool
# Per-layer block lengths (bytes, K+V combined for blocks_first).
# Uniform across layers for current models.
d_block_len: int # D's block_len_per_layer (representative)
p_block_len: int # P's block_len_per_layer (from handshake)
is_blocks_first: bool # kv_topo.is_kv_layout_blocks_first
# ---- Derived: computed in __post_init__ ----
#
# Physical heads per rank (what the KV tensor actually stores)
d_physical_heads: int = field(init=False)
p_physical_heads: int = field(init=False)
# How many distinct P ranks D needs for FA data
physical_fa_num_reads: int = field(init=False)
# Which P ranks contribute unique FA heads (ordered by head index)
fa_read_targets: list[int] = field(init=False)
# All P ranks needed for mamba (always abs_tp for tp_ratio < 0)
mamba_num_reads: int = field(init=False)
# All P ranks this D rank communicates with (FA ∪ mamba)
transfer_targets: list[int] = field(init=False)
# FA descriptor entry size (K or V side, for blocks_first layout)
# Guaranteed: fa_entry_size is the SAME for local handle AND remote desc.
fa_entry_size: int = field(init=False)
# Replication flags
is_d_replicated: bool = field(init=False)
is_p_replicated: bool = field(init=False)
# Pre-built set for fast lookup
_fa_target_set: frozenset[int] = field(init=False, repr=False)
# Map: P rank → index in fa_read_targets (for head slot offset)
_fa_target_index: dict[int, int] = field(init=False, repr=False)
def __post_init__(self) -> None:
    """Derive the transfer plan from the handshake inputs.

    Fills in, in this order: the replication flags and physical head counts,
    the FA read targets, the mamba / overall transfer targets, and the FA
    descriptor entry size. Finishes by calling ``_validate`` to log any
    internal inconsistency.
    """
    K = self.K
    self.is_d_replicated = self.d_tp > K
    self.is_p_replicated = self.p_tp > K
    self.d_physical_heads = max(1, K // self.d_tp)
    self.p_physical_heads = max(1, K // self.p_tp)

    abs_tp = -self.tp_ratio if self.tp_ratio < 0 else 1

    # ---- Mamba range (computed first so FA can prefer ranks in it) ----
    mamba_range: range | None = None
    if self.tp_ratio < 0:
        mamba_range = range(self.d_rank * abs_tp, (self.d_rank + 1) * abs_tp)

    # ---- FA read targets ----
    if self.use_mla or self.tp_ratio >= 0:
        self.physical_fa_num_reads = 1
        if self.use_mla:
            self.fa_read_targets = [0]
        elif self.tp_ratio > 0:
            # Must match kv_topo.get_target_remote_ranks (d_rank // tp_ratio).
            self.fa_read_targets = [self.d_rank // self.tp_ratio]
        else:
            self.fa_read_targets = [self.d_rank]
    else:
        d_needs = _physical_head_range(self.d_tp, K, self.d_rank)
        # (start, stop) of every head overlap already assigned to a target,
        # shared by both scan passes so a head span is read only once.
        seen: set[tuple[int, int]] = set()
        targets: list[int] = []

        def _collect_targets(candidates: range) -> None:
            """Append each P rank that contributes a not-yet-seen overlap."""
            for p in candidates:
                p_has = _physical_head_range(self.p_tp, K, p)
                ov = _range_overlap(d_needs, p_has)
                if len(ov) == 0:
                    continue
                key = (ov.start, ov.stop)
                if key not in seen:
                    seen.add(key)
                    targets.append(p)

        # When mamba range exists, prefer P ranks within it so that
        # FA targets are a subset of mamba transfer_targets (avoids
        # orphaned FA targets outside the transfer loop).
        _collect_targets(
            mamba_range if mamba_range is not None else range(self.p_tp)
        )
        if not targets:
            # Fallback: search globally (should not happen in practice)
            _collect_targets(range(self.p_tp))
        self.fa_read_targets = targets
        self.physical_fa_num_reads = len(targets)

    self._fa_target_set = frozenset(self.fa_read_targets)
    self._fa_target_index = {r: i for i, r in enumerate(self.fa_read_targets)}

    # ---- Mamba targets ----
    if mamba_range is not None and abs_tp > self.physical_fa_num_reads:
        self.mamba_num_reads = abs_tp
        self.transfer_targets = list(mamba_range)
    else:
        self.mamba_num_reads = self.physical_fa_num_reads
        self.transfer_targets = list(self.fa_read_targets)

    # ---- FA entry size ----
    # For blocks_first: block_len_per_layer includes K+V; // 2 gives K (or V).
    # Use min(D, P) because D indexes into P when tp_ratio > 0,
    # and P is the natural unit when tp_ratio < 0.
    effective_block_len = min(self.d_block_len, self.p_block_len)
    if self.is_blocks_first:
        self.fa_entry_size = effective_block_len // 2
    else:
        self.fa_entry_size = effective_block_len

    self._validate()
def _validate(self) -> None:
    """Log diagnostics about the computed plan; never raises.

    Three independent checks are made: an informational note for the
    both-sides-replicated case, an error for every FA read target that is
    missing from ``transfer_targets``, and a warning when the per-read K
    half of D's block does not equal P's K half (blocks-first layout with
    ``tp_ratio < 0`` only).
    """
    both_replicated = self.is_d_replicated and self.is_p_replicated
    if both_replicated and self.tp_ratio > 0:
        logger.info(
            "Both-replicated hetero-TP: D_TP=%d > P_TP=%d > K=%d. "
            "Using d_rank // tp_ratio routing with relative head offset.",
            self.d_tp,
            self.p_tp,
            self.K,
        )

    # FA targets must be a subset of transfer_targets
    reachable = set(self.transfer_targets)
    for fa_rank in self.fa_read_targets:
        if fa_rank in reachable:
            continue
        logger.error(
            "FA target P rank %d is NOT in transfer_targets %s. "
            "This will cause missed FA reads!",
            fa_rank,
            self.transfer_targets,
        )

    # For tp_ratio < 0 with blocks_first: D_K_half / reads should == P_K_half
    size_check_applies = (
        self.is_blocks_first
        and self.tp_ratio < 0
        and self.physical_fa_num_reads > 0
    )
    if not size_check_applies:
        return

    d_k_half = self.d_block_len // 2
    p_k_half = self.p_block_len // 2
    expected_local = d_k_half // self.physical_fa_num_reads
    if expected_local == p_k_half:
        return
    logger.warning(
        "FA size mismatch: D_K_half=%d / reads=%d = %d, "
        "but P_K_half=%d. This may indicate a head count or "
        "Mamba-HMA inflation inconsistency.",
        d_k_half,
        self.physical_fa_num_reads,
        expected_local,
        p_k_half,
    )
# ---- Query methods ---- For remote ranks in ``fa_source_ranks``, returns 0, 1, …, reads-1.
For ranks NOT in ``fa_source_ranks`` (replicated duplicates),
def should_skip_fa(self, p_rank: int) -> bool: returns the slot of the matching source rank with the same head.
"""Whether to skip FA groups for this P rank (mamba-only transfer)."""
return p_rank not in self._fa_target_set
def fa_head_slot(self, p_rank: int) -> int:
"""Index into D's FA block for this P rank's head data.
For P ranks in fa_read_targets, returns 0, 1, ..., reads-1.
For P ranks NOT in fa_read_targets (replicated duplicates),
returns the slot of the matching FA target with the same head.
""" """
if p_rank in self._fa_target_index: fa_index = self._fa_source_indices[remote_engine_id]
return self._fa_target_index[p_rank] if remote_rank in fa_index:
# Duplicate head: find which fa_target has the same physical head return fa_index[remote_rank]
p_head = _physical_head_range(self.p_tp, self.K, p_rank) mamba_info = self._engines[remote_engine_id]
for target in self.fa_read_targets: assert isinstance(mamba_info, MambaEngineTransferInfo)
t_head = _physical_head_range(self.p_tp, self.K, target) K = self.total_num_kv_heads
if _range_overlap(p_head, t_head): remote_tp = mamba_info.remote_tp_size
return self._fa_target_index[target] r_head = self._physical_head_range(remote_tp, K, remote_rank)
return 0 # fallback for target in mamba_info.remote_fa_source_ranks:
t_head = self._physical_head_range(remote_tp, K, target)
def fa_rank_offset(self, remote_kv_block_len: int) -> int: if self._range_overlap(r_head, t_head):
"""Byte offset into P's FA block for this D rank. return fa_index[target]
return 0
When D is replicated (D_TP > K), multiple D ranks share a head.
Computes offset *relative to the target P rank's first head* def fa_rank_offset(
so it works regardless of how many heads P has. self, remote_engine_id: EngineId, remote_kv_block_len: int
When neither side replicates, falls back to tp_rank % tp_ratio. ) -> int:
Returns 0 when D does not index into P's block. """Byte offset into remote FA block for this local rank.
When local TP is replicated (local_tp > K), multiple local ranks
share a head. Computes offset *relative to the target remote
rank's first head* so it works regardless of how many heads the
remote has. Returns 0 when local does not index into remote.
""" """
if self.use_mla or self.tp_ratio <= 0: mamba_info = self._engines[remote_engine_id]
assert isinstance(mamba_info, MambaEngineTransferInfo)
tp_ratio = self.tp_ratio(mamba_info.remote_tp_size)
if self.is_mla or tp_ratio <= 0:
return 0 return 0
if self.is_d_replicated: K = self.total_num_kv_heads
d_head = self.d_rank * self.K // self.d_tp is_local_replicated = self.tp_size > K
p_rank = self.fa_read_targets[0] if is_local_replicated:
p_start = p_rank * self.K // self.p_tp local_head = self.tp_rank * K // self.tp_size
return (d_head - p_start) * remote_kv_block_len p_rank = mamba_info.remote_fa_source_ranks[0]
return self.d_rank % self.tp_ratio * remote_kv_block_len p_start = p_rank * K // mamba_info.remote_tp_size
return (local_head - p_start) * remote_kv_block_len
@property return self.tp_rank % tp_ratio * remote_kv_block_len
def needs_split_handles(self) -> bool:
"""Whether per-P-rank split handles are needed. def needs_split_handles(self, remote_engine_id: EngineId) -> bool:
"""Whether per-remote-rank split handles are needed.
True when FA and mamba have different read counts, requiring True when FA and mamba have different read counts, requiring
different splitting factors in the local handle. different splitting factors in the local handle.
""" """
return self.tp_ratio < 0 and not self.use_mla and len(self.transfer_targets) > 1 mamba_info = self._engines[remote_engine_id]
assert isinstance(mamba_info, MambaEngineTransferInfo)
tp_ratio = self.tp_ratio(mamba_info.remote_tp_size)
return (
tp_ratio < 0
and not self.is_mla
and len(mamba_info.remote_all_source_ranks) > 1
)
def compute_split_handle_data( def compute_split_handle_data(
self, self,
remote_engine_id: EngineId,
src_blocks_data: list[tuple[int, int, int]], src_blocks_data: list[tuple[int, int, int]],
num_fa_descs: int, num_fa_descs: int,
abs_tp: int, abs_tp: int,
) -> list[list[tuple[int, int, int]]]: ) -> list[list[tuple[int, int, int]]]:
"""Compute per-P-rank (addr, len, tp) triples for Mamba-HMA split handles. """Per-remote-rank (addr, len, dev) triples for Mamba-HMA split
handles.
FA descriptors (indices < num_fa_descs) are sliced by FA descriptors (indices < num_fa_descs) are sliced by
``physical_fa_num_reads``; mamba descriptors are sliced uniformly ``remote_num_fa_reads``; mamba descriptors are sliced uniformly
by ``abs_tp``. by ``abs_tp``.
Returns one list of triples per transfer target.
""" """
mamba_info = self._engines[remote_engine_id]
assert isinstance(mamba_info, MambaEngineTransferInfo)
all_handle_data: list[list[tuple[int, int, int]]] = [] all_handle_data: list[list[tuple[int, int, int]]] = []
for p_idx, p_rank in enumerate(self.transfer_targets): for p_idx, p_rank in enumerate(mamba_info.remote_all_source_ranks):
handle_data: list[tuple[int, int, int]] = [] handle_data: list[tuple[int, int, int]] = []
skip_fa = self.should_skip_fa(p_rank) skip_fa = self.should_skip_fa(remote_engine_id, p_rank)
fa_slot = self.fa_head_slot(p_rank) if not skip_fa else 0 fa_slot = self.fa_head_slot(remote_engine_id, p_rank) if not skip_fa else 0
for j, (addr, local_len, dev) in enumerate(src_blocks_data):
for j, (addr, local_len, tp) in enumerate(src_blocks_data):
if j < num_fa_descs: if j < num_fa_descs:
assert self.physical_fa_num_reads >= 1 assert mamba_info.remote_num_fa_reads >= 1
fa_chunk = local_len // self.physical_fa_num_reads fa_chunk = local_len // mamba_info.remote_num_fa_reads
handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, tp)) handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, dev))
else: else:
mamba_chunk = local_len // abs_tp mamba_chunk = local_len // abs_tp
handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, tp)) handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, dev))
all_handle_data.append(handle_data) all_handle_data.append(handle_data)
return all_handle_data return all_handle_data
def filter_block_ids_for_rank( def filter_block_ids_for_rank(
self, self,
remote_engine_id: EngineId,
remote_rank: int, remote_rank: int,
local_ids: BlockIds, local_ids: BlockIds,
remote_ids: BlockIds, remote_ids: BlockIds,
is_mamba_group: list[bool], is_mamba_group: list[bool],
) -> tuple[BlockIds, BlockIds]: ) -> tuple[BlockIds, BlockIds]:
"""Zero out FA groups for P ranks outside fa_read_targets. """Zero out FA groups for remote ranks outside ``fa_source_ranks``.
Returns (filtered_local_ids, filtered_remote_ids). When the Returns (filtered_local_ids, filtered_remote_ids). When the
remote rank carries FA data for this D rank, returns the inputs remote rank carries FA data for this local rank, returns the
unchanged. inputs unchanged.
""" """
if not self.should_skip_fa(remote_rank): if not self.should_skip_fa(remote_engine_id, remote_rank):
return local_ids, remote_ids return local_ids, remote_ids
num_groups = len(local_ids) num_groups = len(local_ids)
filtered_local: list[list[int]] = [ filtered_local: list[list[int]] = [
...@@ -833,108 +787,184 @@ class HeteroTPTransferConfig: ...@@ -833,108 +787,184 @@ class HeteroTPTransferConfig:
] ]
return filtered_local, filtered_remote return filtered_local, filtered_remote
def describe(self) -> str: def describe(self, remote_engine_id: EngineId) -> str:
"""One-line summary for logging.""" """One-line summary of transfer config for logging."""
return ( info = self._engines[remote_engine_id]
f"HeteroTPTransferConfig(" base = (
f"tp_ratio={self.tp_ratio}, K={self.K}, " f"tp_ratio={self.tp_ratio(info.remote_tp_size)}, "
f"d_tp={self.d_tp}, p_tp={self.p_tp}, d_rank={self.d_rank}, " f"K={self.total_num_kv_heads}, "
f"physical_fa_reads={self.physical_fa_num_reads}, " f"local_tp={self.tp_size}, "
f"mamba_reads={self.mamba_num_reads}, " f"remote_tp={info.remote_tp_size}, "
f"fa_targets={self.fa_read_targets}, " f"local_rank={self.tp_rank}, "
f"transfer_targets={self.transfer_targets}, " f"remote_block_len={info.remote_block_len}"
f"fa_entry_size={self.fa_entry_size}, "
f"d_block_len={self.d_block_len}, p_block_len={self.p_block_len})"
) )
if isinstance(info, MambaEngineTransferInfo):
return (
f"TransferTopology.mamba({base}, "
f"fa_reads={info.remote_num_fa_reads}, "
f"mamba_reads={info.remote_num_mamba_reads}, "
f"fa_sources={list(info.remote_fa_source_ranks)}, "
f"all_sources={list(info.remote_all_source_ranks)}, "
f"fa_desc_bytes={info.remote_fa_descriptor_bytes})"
)
return f"TransferTopology({base})"
# ============================================================
# Private helpers
# ============================================================
# Mamba-HMA hetero-TP transfer config:
# With hetero-TP (P_TP > D_TP), FA KV cache may be replicated across
# P ranks (when P_TP > num_kv_heads), but Mamba conv/SSM state is
# almost always uniquely sharded per P rank. So the number of P
# ranks D must read from can differ between FA and Mamba, and they
# must be handled separately.
@staticmethod
def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range:
    """Physical KV head range stored in a rank's KV cache tensor.

    With ``tp_size > num_heads`` each rank holds a single physical head.
    Heads are distributed **contiguously** (matching vLLM's GQA weight
    partitioning): consecutive ranks share a head before moving on to the
    next one.

    With ``tp_size <= num_heads`` the heads are sharded, ``K/TP`` contiguous
    heads per rank; an uneven split trips the assertion.
    """
    if tp_size > num_heads:
        shared_head = (rank * num_heads) // tp_size
        return range(shared_head, shared_head + 1)

    heads_per_rank, leftover = divmod(num_heads, tp_size)
    assert leftover == 0
    start = rank * heads_per_rank
    return range(start, start + heads_per_rank)
@staticmethod
def _range_overlap(a: range, b: range) -> range:
    """Intersect two step-1 ranges; a disjoint pair gives an empty range.

    The result always starts at the larger of the two ``start`` values,
    including when it is empty.
    """
    begin = max(a.start, b.start)
    end = min(a.stop, b.stop)
    if end <= begin:
        # No common elements: collapse to the empty range anchored at begin.
        return range(begin, begin)
    return range(begin, end)
def get_current_attn_backends( # ============================================================
vllm_config: VllmConfig, layer_names: list[str] | None = None # Private: build Mamba transfer info
) -> list[type[AttentionBackend]]: # ============================================================
"""Get all distinct attention backends for the given layers.
Args: def _build_mamba_info(
vllm_config: The current vLLM configuration. self,
layer_names: Optional list of layer names to scope the lookup. remote_tp_size: int,
When None, all attention layers are considered. remote_block_size: int,
remote_block_len: int,
remote_physical_blocks_per_logical: int,
local_block_len: int,
) -> MambaEngineTransferInfo:
"""Compute Mamba transfer plan."""
K = self.total_num_kv_heads
local_tp = self.tp_size
local_rank = self.tp_rank
is_remote_replicated = remote_tp_size > K
remote_physical_heads = max(1, K // remote_tp_size)
if local_tp >= remote_tp_size:
assert local_tp % remote_tp_size == 0
tp_ratio = local_tp // remote_tp_size
else:
assert remote_tp_size % local_tp == 0
tp_ratio = -(remote_tp_size // local_tp)
Returns: abs_tp = -tp_ratio if tp_ratio < 0 else 1
Deduplicated list of attention backend classes.
"""
layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
if layers:
seen: dict[str, type[AttentionBackend]] = {}
for layer in layers.values():
backend = layer.get_attn_backend()
seen[backend.full_cls_name()] = backend
return list(seen.values())
# Fallback for tests, when static_forward_context is empty. mamba_range: range | None = None
logger.debug( if tp_ratio < 0:
"No layers found in the vLLM config. Falling back to default attention backend." mamba_range = range(local_rank * abs_tp, (local_rank + 1) * abs_tp)
)
from vllm.v1.attention.selector import get_attn_backend
return [ # ---- FA read targets ----
get_attn_backend( if self.is_mla or tp_ratio >= 0:
head_size=vllm_config.model_config.get_head_size(), num_fa_reads = 1
dtype=vllm_config.model_config.dtype, fa_source_ranks: list[int] = (
kv_cache_dtype=vllm_config.cache_config.cache_dtype, [0]
use_mla=vllm_config.model_config.use_mla, if self.is_mla
) else [local_rank // tp_ratio if tp_ratio > 0 else local_rank]
] )
else:
local_needs = self._physical_head_range(local_tp, K, local_rank)
search_range = (
mamba_range if mamba_range is not None else range(remote_tp_size)
)
seen: set[tuple[int, int]] = set()
fa_source_ranks = []
for p in search_range:
p_has = self._physical_head_range(remote_tp_size, K, p)
ov = self._range_overlap(local_needs, p_has)
if len(ov) > 0:
key = (ov.start, ov.stop)
if key not in seen:
seen.add(key)
fa_source_ranks.append(p)
if not fa_source_ranks:
for p in range(remote_tp_size):
p_has = self._physical_head_range(remote_tp_size, K, p)
ov = self._range_overlap(local_needs, p_has)
if len(ov) > 0:
key = (ov.start, ov.stop)
if key not in seen:
seen.add(key)
fa_source_ranks.append(p)
num_fa_reads = len(fa_source_ranks)
# ---- All source ranks (mamba + FA) ----
if mamba_range is not None and abs_tp > num_fa_reads:
num_mamba_reads = abs_tp
all_source_ranks = list(mamba_range)
else:
num_mamba_reads = num_fa_reads
all_source_ranks = list(fa_source_ranks)
def get_current_attn_backend( # ---- FA descriptor bytes ----
vllm_config: VllmConfig, layer_names: list[str] | None = None effective_block_len = min(local_block_len, remote_block_len)
) -> type[AttentionBackend]: if self.is_kv_layout_blocks_first:
"""Get the first attention backend for the given layers.""" fa_descriptor_bytes = effective_block_len // 2
return get_current_attn_backends(vllm_config, layer_names)[0] else:
fa_descriptor_bytes = effective_block_len
# ---- Validation ----
is_local_replicated = local_tp > K
if is_local_replicated and is_remote_replicated and tp_ratio > 0:
logger.info(
"Both-replicated hetero-TP: local_tp=%d > remote_tp=%d > K=%d.",
local_tp,
remote_tp_size,
K,
)
tt_set = set(all_source_ranks)
for t in fa_source_ranks:
if t not in tt_set:
logger.error(
"FA source rank %d NOT in all_source_ranks %s.",
t,
all_source_ranks,
)
if self.is_kv_layout_blocks_first and tp_ratio < 0 and num_fa_reads > 0:
local_k_half = local_block_len // 2
remote_k_half = remote_block_len // 2
expected = local_k_half // num_fa_reads
if expected != remote_k_half:
logger.warning(
"FA size mismatch: local_k_half=%d / reads=%d = %d, "
"but remote_k_half=%d.",
local_k_half,
num_fa_reads,
expected,
remote_k_half,
)
# TODO (ZhanqiuHu): Consolidate TpKVTopology and HeteroTPTransferConfig return MambaEngineTransferInfo(
# into a single engine-agnostic TransferTopology class. remote_tp_size=remote_tp_size,
# 6 of 9 HeteroTPTransferConfig init fields duplicate TpKVTopology data. remote_block_len=remote_block_len,
# remote_block_size=remote_block_size,
# @dataclass remote_physical_blocks_per_logical=(remote_physical_blocks_per_logical),
# class EngineTransferInfo: remote_fa_source_ranks=tuple(fa_source_ranks),
# """Per-remote-engine transfer state, computed at handshake.""" remote_all_source_ranks=tuple(all_source_ranks),
# p_tp: int remote_num_fa_reads=num_fa_reads,
# tp_ratio: int remote_num_mamba_reads=num_mamba_reads,
# p_block_len: int remote_fa_descriptor_bytes=fa_descriptor_bytes,
# block_size: int is_remote_replicated=is_remote_replicated,
# # Mamba-specific (None for non-mamba models) remote_physical_heads=remote_physical_heads,
# fa_read_targets: list[int] | None = None )
# transfer_targets: list[int] | None = None
# physical_fa_num_reads: int | None = None
# mamba_num_reads: int | None = None
# fa_entry_size: int | None = None
#
# class TransferTopology:
# """Single source of truth for TP topology + transfer sizing."""
# # Shared (set once at init, replaces duplicate fields)
# tp_rank: int # == TpKVTopology.tp_rank == HeteroTP.d_rank
# tp_size: int # == TpKVTopology.tp_size == HeteroTP.d_tp
# total_num_kv_heads: int # == HeteroTP.K
# is_mla: bool # == HeteroTP.use_mla
# is_mamba: bool
# is_blocks_first: bool # == HeteroTP.is_blocks_first
# d_block_len: int
#
# # Per-engine (populated via register_engine() at handshake)
# _engines: dict[EngineId, EngineTransferInfo]
#
# def register_engine(self, engine_id, p_tp, p_block_len, ...): ...
#
# # General (from TpKVTopology)
# def tp_ratio(self, engine_id) -> int: ...
# def target_remote_ranks(self, engine_id) -> list[int]: ...
# def is_kv_replicated(self, engine_id) -> bool: ...
#
# # Mamba-specific (from HeteroTPTransferConfig, gated by is_mamba)
# def fa_rank_offset(self, engine_id, block_len) -> int: ...
# def physical_fa_num_reads(self, engine_id) -> int: ...
# def transfer_targets(self, engine_id) -> list[int]: ...
# def should_skip_fa(self, engine_id, p_rank) -> bool: ...
# def filter_block_ids_for_rank(self, engine_id, ...) -> ...: ...
...@@ -21,7 +21,7 @@ from vllm import envs ...@@ -21,7 +21,7 @@ from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId, EngineId,
TpKVTopology, TransferTopology,
get_current_attn_backend, get_current_attn_backend,
get_current_attn_backends, get_current_attn_backends,
) )
...@@ -764,13 +764,13 @@ class MooncakeConnectorWorker: ...@@ -764,13 +764,13 @@ class MooncakeConnectorWorker:
logger.debug("Detected kv cache layout %s", self.kv_cache_layout) logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_size} self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} self.transfer_topo = TransferTopology(
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
tp_size=self.tp_size,
block_size=self.block_size,
engine_id=self.engine_id, engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla, is_mla=self.use_mla,
is_mamba=False,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(), total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backends=[backend], attn_backends=[backend],
) )
...@@ -911,7 +911,7 @@ class MooncakeConnectorWorker: ...@@ -911,7 +911,7 @@ class MooncakeConnectorWorker:
self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata
): ):
pending_reqs: dict[ReqId, SendBlockMeta] = {} pending_reqs: dict[ReqId, SendBlockMeta] = {}
remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size) remote_tp_ranks = self.transfer_topo.handshake_target_ranks(meta.remote_tp_size)
if meta.remote_tp_rank not in remote_tp_ranks: if meta.remote_tp_rank not in remote_tp_ranks:
# This D worker does not pair with the P worker. # This D worker does not pair with the P worker.
msg = ( msg = (
...@@ -1256,7 +1256,7 @@ class MooncakeConnectorWorker: ...@@ -1256,7 +1256,7 @@ class MooncakeConnectorWorker:
seen_base_addresses = [] seen_base_addresses = []
self.block_len_per_layer = [] self.block_len_per_layer = []
split_k_and_v = self.kv_topo.split_k_and_v split_k_and_v = self.transfer_topo.split_k_and_v
tensor_size_bytes = None tensor_size_bytes = None
for layer_name, cache_or_caches in kv_caches.items(): for layer_name, cache_or_caches in kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
...@@ -1495,8 +1495,8 @@ class MooncakeConnectorWorker: ...@@ -1495,8 +1495,8 @@ class MooncakeConnectorWorker:
remote_engine_id: EngineId, remote_engine_id: EngineId,
pull_metas: dict[ReqId, PullReqMeta], pull_metas: dict[ReqId, PullReqMeta],
): ):
remote_tp_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( remote_tp_ranks = self.transfer_topo.handshake_target_ranks(
remote_engine_id self._tp_size[remote_engine_id]
) )
count = len(remote_tp_ranks) count = len(remote_tp_ranks)
logger.debug( logger.debug(
...@@ -1587,7 +1587,7 @@ class MooncakeConnectorWorker: ...@@ -1587,7 +1587,7 @@ class MooncakeConnectorWorker:
) )
def _producer_cache_is_replicated(self) -> bool: def _producer_cache_is_replicated(self) -> bool:
return self.kv_topo.replicates_kv_cache(self.engine_id) return self.transfer_topo.local_replicates_kv_cache
def _get_transfer_regions( def _get_transfer_regions(
self, base_addrs: list[int], block_lens: list[int] self, base_addrs: list[int], block_lens: list[int]
...@@ -1595,7 +1595,7 @@ class MooncakeConnectorWorker: ...@@ -1595,7 +1595,7 @@ class MooncakeConnectorWorker:
return _expand_transfer_regions( return _expand_transfer_regions(
base_addrs=base_addrs, base_addrs=base_addrs,
block_lens=block_lens, block_lens=block_lens,
is_kv_layout_blocks_first=self.kv_topo.is_kv_layout_blocks_first, is_kv_layout_blocks_first=self.transfer_topo.is_kv_layout_blocks_first,
) )
def _get_sender_transfer_plan( def _get_sender_transfer_plan(
......
...@@ -21,8 +21,8 @@ from vllm import envs ...@@ -21,8 +21,8 @@ from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.distributed.kv_transfer.kv_connector.utils import (
BlockIds, BlockIds,
EngineId, EngineId,
HeteroTPTransferConfig, MambaEngineTransferInfo,
TpKVTopology, TransferTopology,
get_current_attn_backends, get_current_attn_backends,
kv_postprocess_blksize_and_layout_on_receive, kv_postprocess_blksize_and_layout_on_receive,
kv_postprocess_blksize_on_receive, kv_postprocess_blksize_on_receive,
...@@ -49,7 +49,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import ( ...@@ -49,7 +49,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import (
) )
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
MambaConvSplitInfo, MambaConvSplitInfo,
compute_mamba_phys_ratio, compute_physical_blocks_per_logical,
derive_mamba_conv_split, derive_mamba_conv_split,
) )
from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config from vllm.distributed.nixl_utils import NixlWrapper, nixl_agent_config
...@@ -269,14 +269,12 @@ class NixlConnectorWorker: ...@@ -269,14 +269,12 @@ class NixlConnectorWorker:
self._registered_descs: list[Any] = [] self._registered_descs: list[Any] = []
# ---- Mamba-HMA per-engine state (only used when self._has_mamba) ---- # ---- Mamba-HMA per-engine state (only used when self._has_mamba) ----
# Per-engine transfer config (source of truth for FA/mamba sizing). # NOTE (ZhanqiuHu): _physical_blocks_per_logical MUST be per-engine.
self._transfer_configs: dict[str, HeteroTPTransferConfig] = {} # physical_blocks_per_logical = ceil((conv_bytes + ssm_bytes) / block_len)
# NOTE (ZhanqiuHu): _mamba_phys_ratio MUST be per-engine.
# compute_mamba_phys_ratio = ceil((conv_bytes + ssm_bytes) / block_len)
# where conv/ssm bytes are per-TP-rank (dimension-sharded). With # where conv/ssm bytes are per-TP-rank (dimension-sharded). With
# heterogeneous TP the per-rank sizes differ, so the ratio differs: # heterogeneous TP the per-rank sizes differ, so the ratio differs:
# e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261. # e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261.
self._mamba_phys_ratio: dict[EngineId, int] = {} self._physical_blocks_per_logical: dict[EngineId, int] = {}
# In progress transfers. # In progress transfers.
# [req_id -> list[handle]] # [req_id -> list[handle]]
...@@ -322,10 +320,8 @@ class NixlConnectorWorker: ...@@ -322,10 +320,8 @@ class NixlConnectorWorker:
# lazy initialized in register_kv_caches # lazy initialized in register_kv_caches
self.compat_hash: str | None = None self.compat_hash: str | None = None
self.kv_topo: TpKVTopology | None = None self.transfer_topo: TransferTopology | None = None
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
# With heterogeneous TP, P must wait for all assigned D TP workers to # With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks. # finish reading before safely freeing the blocks.
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
...@@ -355,7 +351,6 @@ class NixlConnectorWorker: ...@@ -355,7 +351,6 @@ class NixlConnectorWorker:
self.block_size // kernel_block_size self.block_size // kernel_block_size
) )
self.block_size = kernel_block_size self.block_size = kernel_block_size
self._block_size[self.engine_id] = kernel_block_size
self.num_blocks *= self._physical_blocks_per_logical_kv_block self.num_blocks *= self._physical_blocks_per_logical_kv_block
def _nixl_handshake( def _nixl_handshake(
...@@ -384,8 +379,8 @@ class NixlConnectorWorker: ...@@ -384,8 +379,8 @@ class NixlConnectorWorker:
# Regardless, only handshake with the remote TP rank(s) that current # Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that With homogeneous TP, # local rank will read from. Note that With homogeneous TP,
# this happens to be the same single rank_i. # this happens to be the same single rank_i.
assert self.kv_topo is not None assert self.transfer_topo is not None
p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size) p_remote_ranks = self.transfer_topo.handshake_target_ranks(remote_tp_size)
remote_rank_to_agent_name = {} remote_rank_to_agent_name = {}
path = make_zmq_path("tcp", host, port) path = make_zmq_path("tcp", host, port)
...@@ -649,11 +644,11 @@ class NixlConnectorWorker: ...@@ -649,11 +644,11 @@ class NixlConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl.""" """Register the KV Cache data in nixl."""
self.kv_topo = TpKVTopology( self.transfer_topo = TransferTopology(
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
tp_size=self.world_size,
block_size=self.block_size,
engine_id=self.engine_id, engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla, is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(), total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backends=self.attn_backends, attn_backends=self.attn_backends,
...@@ -664,7 +659,7 @@ class NixlConnectorWorker: ...@@ -664,7 +659,7 @@ class NixlConnectorWorker:
is_mamba=self._has_mamba, is_mamba=self._has_mamba,
) )
self.compat_hash = compute_nixl_compatibility_hash( self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks self.vllm_config, self.backend_name, self.transfer_topo.cross_layers_blocks
) )
if self.use_host_buffer: if self.use_host_buffer:
...@@ -716,7 +711,7 @@ class NixlConnectorWorker: ...@@ -716,7 +711,7 @@ class NixlConnectorWorker:
if isinstance(layer_spec, UniformTypeKVCacheSpecs): if isinstance(layer_spec, UniformTypeKVCacheSpecs):
# MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs # MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs
layer_spec = layer_spec.kv_cache_specs[layer_name] layer_spec = layer_spec.kv_cache_specs[layer_name]
cache_list = self.kv_topo.get_transfer_cache_regions( cache_list = self.transfer_topo.get_transfer_cache_regions(
cache_or_caches, layer_spec cache_or_caches, layer_spec
) )
# `layer_spec.page_size_bytes` only accounts for logical page_size, that is # `layer_spec.page_size_bytes` only accounts for logical page_size, that is
...@@ -729,7 +724,7 @@ class NixlConnectorWorker: ...@@ -729,7 +724,7 @@ class NixlConnectorWorker:
) )
# Used when registering multiple tensors, e.g. K/V in separate regions. # Used when registering multiple tensors, e.g. K/V in separate regions.
physical_page_size = physical_page_size // len(cache_list) physical_page_size = physical_page_size // len(cache_list)
if self.kv_topo._cross_layers_blocks: if self.transfer_topo._cross_layers_blocks:
# When cross-layers blocks are used, multiply by number of layers # When cross-layers blocks are used, multiply by number of layers
physical_page_size = physical_page_size * len( physical_page_size = physical_page_size * len(
self.kv_cache_config.kv_cache_tensors self.kv_cache_config.kv_cache_tensors
...@@ -793,7 +788,7 @@ class NixlConnectorWorker: ...@@ -793,7 +788,7 @@ class NixlConnectorWorker:
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data) self.num_regions = len(caches_data)
if self.kv_topo.is_kv_layout_blocks_first: if self.transfer_topo.is_kv_layout_blocks_first:
# NOTE (NickLucche) When FlashInfer is used, memory is registered # NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in # with joint KV for each block. This minimizes the overhead in
# registerMem allowing faster descs queries. In order to be able to # registerMem allowing faster descs queries. In order to be able to
...@@ -817,7 +812,7 @@ class NixlConnectorWorker: ...@@ -817,7 +812,7 @@ class NixlConnectorWorker:
self.dst_num_blocks[self.engine_id] = self.num_blocks self.dst_num_blocks[self.engine_id] = self.num_blocks
if self._has_mamba: if self._has_mamba:
self._mamba_phys_ratio[self.engine_id] = ( self._physical_blocks_per_logical[self.engine_id] = (
self._physical_blocks_per_logical_kv_block self._physical_blocks_per_logical_kv_block
) )
logger.info( logger.info(
...@@ -876,11 +871,13 @@ class NixlConnectorWorker: ...@@ -876,11 +871,13 @@ class NixlConnectorWorker:
conv_offsets = self._conv_decomp.local_conv_offsets conv_offsets = self._conv_decomp.local_conv_offsets
conv_size, ssm_size = self._mamba_ssm_size conv_size, ssm_size = self._mamba_ssm_size
num_blocks = self._logical_num_blocks * block_size_ratio num_blocks = self._logical_num_blocks * block_size_ratio
phys_ratio = self._physical_blocks_per_logical_kv_block physical_per_logical = self._physical_blocks_per_logical_kv_block
result: list[tuple[int, int, int]] = [] result: list[tuple[int, int, int]] = []
for i, base_addr in enumerate(base_addresses): for i, base_addr in enumerate(base_addresses):
page_stride = self.block_len_per_layer[i] // block_size_ratio * phys_ratio page_stride = (
self.block_len_per_layer[i] // block_size_ratio * physical_per_logical
)
for off, sz in conv_offsets: for off, sz in conv_offsets:
for blk in range(num_blocks): for blk in range(num_blocks):
result.append( result.append(
...@@ -900,14 +897,14 @@ class NixlConnectorWorker: ...@@ -900,14 +897,14 @@ class NixlConnectorWorker:
def _build_fa_remote_for_mamba( def _build_fa_remote_for_mamba(
self, self,
nixl_agent_meta: NixlAgentMetadata, nixl_agent_meta: NixlAgentMetadata,
transfer_cfg: HeteroTPTransferConfig,
block_size_ratio: int, block_size_ratio: int,
kv_topo: TpKVTopology, transfer_topo: TransferTopology,
remote_engine_id: EngineId,
) -> list[tuple[int, int, int]]: ) -> list[tuple[int, int, int]]:
"""Build remote FA descriptors for mamba models. """Build remote FA descriptors for mamba models.
Uses transfer_cfg for GQA-aware FA divisor and head-based rank offset Uses TransferTopology for GQA-aware FA divisor and head-based rank
instead of the standard uniform tp_ratio split. offset instead of the standard uniform tp_ratio split.
""" """
assert block_size_ratio == 1, ( assert block_size_ratio == 1, (
"Mamba 3-read transfer with block_size_ratio != 1 is not tested. " "Mamba 3-read transfer with block_size_ratio != 1 is not tested. "
...@@ -915,7 +912,9 @@ class NixlConnectorWorker: ...@@ -915,7 +912,9 @@ class NixlConnectorWorker:
) )
# TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA # TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA
# hetero-TP logic stabilizes. # hetero-TP logic stabilizes.
tp_ratio = transfer_cfg.tp_ratio mamba_info = transfer_topo.get_engine_info(remote_engine_id)
assert isinstance(mamba_info, MambaEngineTransferInfo)
tp_ratio = transfer_topo.tp_ratio(mamba_info.remote_tp_size)
result: list[tuple[int, int, int]] = [] result: list[tuple[int, int, int]] = []
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
local_block_len = self.get_backend_aware_kv_block_len( local_block_len = self.get_backend_aware_kv_block_len(
...@@ -926,9 +925,11 @@ class NixlConnectorWorker: ...@@ -926,9 +925,11 @@ class NixlConnectorWorker:
local_block_len = remote_kv_block_len local_block_len = remote_kv_block_len
if tp_ratio < 0 and not self.use_mla: if tp_ratio < 0 and not self.use_mla:
local_block_len = local_block_len // transfer_cfg.physical_fa_num_reads local_block_len = local_block_len // mamba_info.remote_num_fa_reads
rank_offset = transfer_cfg.fa_rank_offset(remote_kv_block_len) rank_offset = transfer_topo.fa_rank_offset(
remote_engine_id, remote_kv_block_len
)
num_blocks = nixl_agent_meta.num_blocks num_blocks = nixl_agent_meta.num_blocks
page_size = nixl_agent_meta.block_lens[i] page_size = nixl_agent_meta.block_lens[i]
...@@ -937,12 +938,12 @@ class NixlConnectorWorker: ...@@ -937,12 +938,12 @@ class NixlConnectorWorker:
addr = base_addr + block_offset + rank_offset addr = base_addr + block_offset + rank_offset
result.append((addr, local_block_len, nixl_agent_meta.device_id)) result.append((addr, local_block_len, nixl_agent_meta.device_id))
if kv_topo.is_kv_layout_blocks_first: if transfer_topo.is_kv_layout_blocks_first:
second_split = self.get_backend_aware_kv_block_len( second_split = self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=False, mamba_view=False layer_idx=i, first_split=False, mamba_view=False
) )
if tp_ratio < 0 and not self.use_mla: if tp_ratio < 0 and not self.use_mla:
second_split = second_split // transfer_cfg.physical_fa_num_reads second_split = second_split // mamba_info.remote_num_fa_reads
for block_id in range(num_blocks): for block_id in range(num_blocks):
block_offset = block_id * page_size block_offset = block_id * page_size
addr = base_addr + block_offset + rank_offset addr = base_addr + block_offset + rank_offset
...@@ -981,15 +982,17 @@ class NixlConnectorWorker: ...@@ -981,15 +982,17 @@ class NixlConnectorWorker:
conv_offsets = [(0, xb_p), (xb_p, bb_p), (xb_p + bb_p, bb_p)] conv_offsets = [(0, xb_p), (xb_p, bb_p), (xb_p + bb_p, bb_p)]
ssm_read_size = nixl_agent_meta.ssm_sizes[1] ssm_read_size = nixl_agent_meta.ssm_sizes[1]
remote_ratio = self._mamba_phys_ratio[nixl_agent_meta.engine_id] remote_physical_per_logical = self._physical_blocks_per_logical[
num_blocks = nixl_agent_meta.num_blocks // remote_ratio nixl_agent_meta.engine_id
]
num_blocks = nixl_agent_meta.num_blocks // remote_physical_per_logical
device_id = nixl_agent_meta.device_id device_id = nixl_agent_meta.device_id
result: list[tuple[int, int, int]] = [] result: list[tuple[int, int, int]] = []
# NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case
# block lengths vary across layers (e.g. MLA). # block lengths vary across layers (e.g. MLA).
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
page_stride = nixl_agent_meta.block_lens[i] * remote_ratio page_stride = nixl_agent_meta.block_lens[i] * remote_physical_per_logical
for off, sz in conv_offsets: for off, sz in conv_offsets:
for blk in range(num_blocks): for blk in range(num_blocks):
result.append((base_addr + blk * page_stride + off, sz, device_id)) result.append((base_addr + blk * page_stride + off, sz, device_id))
...@@ -1019,8 +1022,8 @@ class NixlConnectorWorker: ...@@ -1019,8 +1022,8 @@ class NixlConnectorWorker:
register another local_xfer_handler using remote block len to ensure register another local_xfer_handler using remote block len to ensure
data copy correctness. data copy correctness.
""" """
assert self.kv_topo is not None assert self.transfer_topo is not None
kv_topo = self.kv_topo transfer_topo = self.transfer_topo
block_size_ratio = self.block_size // block_size block_size_ratio = self.block_size // block_size
blocks_data: list[tuple[int, int, int]] = [] blocks_data: list[tuple[int, int, int]] = []
...@@ -1051,7 +1054,7 @@ class NixlConnectorWorker: ...@@ -1051,7 +1054,7 @@ class NixlConnectorWorker:
# (addr, len, device id) # (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.device_id)) blocks_data.append((addr, kv_block_len, self.device_id))
if kv_topo.is_kv_layout_blocks_first: if transfer_topo.is_kv_layout_blocks_first:
second_split = self.get_backend_aware_kv_block_len( second_split = self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=False, mamba_view=mamba layer_idx=i, first_split=False, mamba_view=mamba
) )
...@@ -1153,11 +1156,29 @@ class NixlConnectorWorker: ...@@ -1153,11 +1156,29 @@ class NixlConnectorWorker:
) )
return self._remote_agents[engine_id][remote_tp_rank] return self._remote_agents[engine_id][remote_tp_rank]
### Register remote agent metadata ### Register remote engine in TransferTopology (idempotent).
if engine_id not in self._tp_size: assert self.transfer_topo is not None
self._tp_size[engine_id] = remote_tp_size transfer_topo = self.transfer_topo
if engine_id not in self._block_size: physical_blocks_per_logical = (
self._block_size[engine_id] = nixl_agent_meta.block_size compute_physical_blocks_per_logical(
nixl_agent_meta.ssm_sizes,
nixl_agent_meta.block_lens[0],
)
if self._has_mamba
else 1
)
transfer_topo.register_remote_engine(
remote_engine_id=engine_id,
remote_tp_size=remote_tp_size,
remote_block_size=nixl_agent_meta.block_size,
remote_block_len=nixl_agent_meta.block_lens[0],
remote_physical_blocks_per_logical=physical_blocks_per_logical,
local_block_len=self.block_len_per_layer[0],
)
if self._has_mamba and engine_id not in self._physical_blocks_per_logical:
self._physical_blocks_per_logical[engine_id] = physical_blocks_per_logical
logger.info("Transfer plan: %s", transfer_topo.describe(engine_id))
remote_agent_name = self.nixl_wrapper.add_remote_agent( remote_agent_name = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata nixl_agent_meta.agent_metadata
...@@ -1170,16 +1191,10 @@ class NixlConnectorWorker: ...@@ -1170,16 +1191,10 @@ class NixlConnectorWorker:
# remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
# local origin:| 0| 1| 8| 12| # local origin:| 0| 1| 8| 12|
# local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
assert self.kv_topo is not None block_size_ratio = transfer_topo.block_size_ratio(nixl_agent_meta.block_size)
kv_topo = self.kv_topo
block_size_ratio = kv_topo.block_size_ratio_from_engine_id(engine_id)
if engine_id not in self.dst_num_blocks: if engine_id not in self.dst_num_blocks:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
if self._has_mamba:
self._mamba_phys_ratio[engine_id] = compute_mamba_phys_ratio(
nixl_agent_meta.ssm_sizes, nixl_agent_meta.block_lens[0]
)
# Keep track of remote agent kv caches base addresses. # Keep track of remote agent kv caches base addresses.
self.kv_caches_base_addr[engine_id][remote_tp_rank] = ( self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
...@@ -1189,28 +1204,13 @@ class NixlConnectorWorker: ...@@ -1189,28 +1204,13 @@ class NixlConnectorWorker:
# This is 1 when P and D `--tensor-parallel-size` match. Otherwise, # This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
# this is the ratio between the two sizes. # this is the ratio between the two sizes.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id) tp_ratio = transfer_topo.tp_ratio(remote_tp_size)
# Handle tp_size>num_kv_heads: replicate KV cache. # Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote = ( indexes_into_remote = (
not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0 not transfer_topo.replicates_kv_cache(engine_id) and tp_ratio > 0
) )
# Create transfer config (single source of truth for descriptor sizes).
if self._has_mamba and engine_id not in self._transfer_configs:
self._transfer_configs[engine_id] = HeteroTPTransferConfig(
tp_ratio=tp_ratio,
K=kv_topo.total_num_kv_heads,
d_tp=self.world_size,
p_tp=remote_tp_size,
d_rank=self.tp_rank,
use_mla=self.use_mla,
d_block_len=self.block_len_per_layer[0],
p_block_len=nixl_agent_meta.block_lens[0],
is_blocks_first=kv_topo.is_kv_layout_blocks_first,
)
logger.info("Created %s", self._transfer_configs[engine_id].describe())
logger.debug( logger.debug(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s", "Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
engine_id, engine_id,
...@@ -1231,12 +1231,10 @@ class NixlConnectorWorker: ...@@ -1231,12 +1231,10 @@ class NixlConnectorWorker:
self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] self.src_xfer_handles_by_tp_ratio[tp_ratio] = []
if self._has_mamba: if self._has_mamba:
transfer_cfg = self._transfer_configs.get(engine_id) if transfer_topo.needs_split_handles(engine_id):
assert transfer_cfg is not None
if transfer_cfg.needs_split_handles:
# Mamba-HMA: FA and Mamba use different split factors. # Mamba-HMA: FA and Mamba use different split factors.
for handle_data in transfer_cfg.compute_split_handle_data( for handle_data in transfer_topo.compute_split_handle_data(
self.src_blocks_data, self.num_descs, abs_tp engine_id, self.src_blocks_data, self.num_descs, abs_tp
): ):
descs = self.nixl_wrapper.get_xfer_descs( descs = self.nixl_wrapper.get_xfer_descs(
handle_data, self.nixl_memory_type handle_data, self.nixl_memory_type
...@@ -1247,12 +1245,8 @@ class NixlConnectorWorker: ...@@ -1247,12 +1245,8 @@ class NixlConnectorWorker:
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
logger.info( logger.info(
"Mamba-HMA split handles: targets=%s, fa_reads=%s, " "Mamba-HMA split handles: %s, num_descs=%s",
"fa_entry=%s, mamba_reads=%s, num_descs=%s", transfer_topo.describe(engine_id),
transfer_cfg.transfer_targets,
transfer_cfg.physical_fa_num_reads,
transfer_cfg.fa_entry_size,
transfer_cfg.mamba_num_reads,
self.num_descs, self.num_descs,
) )
else: else:
...@@ -1321,7 +1315,7 @@ class NixlConnectorWorker: ...@@ -1321,7 +1315,7 @@ class NixlConnectorWorker:
(addr, local_block_len, nixl_agent_meta.device_id) (addr, local_block_len, nixl_agent_meta.device_id)
) )
if kv_topo.is_kv_layout_blocks_first: if transfer_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting. # With FlashInfer index V separately to allow head splitting.
second_split = self.get_backend_aware_kv_block_len( second_split = self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=False, mamba_view=mamba layer_idx=i, first_split=False, mamba_view=mamba
...@@ -1360,14 +1354,12 @@ class NixlConnectorWorker: ...@@ -1360,14 +1354,12 @@ class NixlConnectorWorker:
engine_id, engine_id,
remote_tp_rank, remote_tp_rank,
) )
transfer_cfg = self._transfer_configs.get(engine_id)
assert transfer_cfg is not None
blocks_data.extend( blocks_data.extend(
self._build_fa_remote_for_mamba( self._build_fa_remote_for_mamba(
nixl_agent_meta, nixl_agent_meta,
transfer_cfg,
block_size_ratio, block_size_ratio,
kv_topo, transfer_topo,
engine_id,
) )
) )
blocks_data.extend( blocks_data.extend(
...@@ -1403,18 +1395,19 @@ class NixlConnectorWorker: ...@@ -1403,18 +1395,19 @@ class NixlConnectorWorker:
""" """
remote_engine_id = nixl_agent_meta.engine_id remote_engine_id = nixl_agent_meta.engine_id
assert self._tp_size[remote_engine_id] == remote_tp_size assert self.transfer_topo is not None
assert self.kv_topo is not None remote_info = self.transfer_topo.get_engine_info(remote_engine_id)
assert remote_info.remote_tp_size == remote_tp_size
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) tp_ratio = self.transfer_topo.tp_ratio(remote_tp_size)
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( block_size_ratio = self.transfer_topo.block_size_ratio(
remote_engine_id nixl_agent_meta.block_size
) )
# num_kv_heads > tp_size with P_TP > D_TP not supported for non-mamba. # num_kv_heads > tp_size with P_TP > D_TP not supported for non-mamba.
# Mamba models can have replicated FA KV with tp_ratio < 0. # Mamba models can have replicated FA KV with tp_ratio < 0.
if not self._has_mamba: if not self._has_mamba:
assert not ( assert not (
tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id) tp_ratio < 0 and self.transfer_topo.is_kv_replicated(remote_engine_id)
) )
if self._is_hma_required: if self._is_hma_required:
...@@ -1467,7 +1460,7 @@ class NixlConnectorWorker: ...@@ -1467,7 +1460,7 @@ class NixlConnectorWorker:
if ( if (
abs(tp_ratio) != 1 abs(tp_ratio) != 1
and not self.use_mla and not self.use_mla
and not self.kv_topo.is_kv_replicated(remote_engine_id) and not self.transfer_topo.is_kv_replicated(remote_engine_id)
and kv_cache_layout != "HND" and kv_cache_layout != "HND"
and not self.enable_permute_local_kv and not self.enable_permute_local_kv
): ):
...@@ -1478,7 +1471,7 @@ class NixlConnectorWorker: ...@@ -1478,7 +1471,7 @@ class NixlConnectorWorker:
# Block len can only vary across layers when using MLA. # Block len can only vary across layers when using MLA.
remote_block_len = nixl_agent_meta.block_lens[0] remote_block_len = nixl_agent_meta.block_lens[0]
if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id): if self.use_mla or self.transfer_topo.is_kv_replicated(remote_engine_id):
# With replicated KV cache, only the number of blocks can differ. # With replicated KV cache, only the number of blocks can differ.
# TODO (ZhanqiuHu): For mamba models, validate FA and mamba # TODO (ZhanqiuHu): For mamba models, validate FA and mamba
# block_lens separately. # block_lens separately.
...@@ -1594,7 +1587,7 @@ class NixlConnectorWorker: ...@@ -1594,7 +1587,7 @@ class NixlConnectorWorker:
if len(self.device_kv_caches) == 0: if len(self.device_kv_caches) == 0:
return return
assert block_size_ratio >= 1, "Only nP < nD supported currently." assert block_size_ratio >= 1, "Only nP < nD supported currently."
assert self.kv_topo is not None assert self.transfer_topo is not None
if self.enable_permute_local_kv and block_size_ratio > 1: if self.enable_permute_local_kv and block_size_ratio > 1:
logger.debug( logger.debug(
"Post-processing device kv cache on receive by converting " "Post-processing device kv cache on receive by converting "
...@@ -1614,7 +1607,7 @@ class NixlConnectorWorker: ...@@ -1614,7 +1607,7 @@ class NixlConnectorWorker:
block_size_ratio, block_size_ratio,
) )
split_k_and_v = self.kv_topo.split_k_and_v split_k_and_v = self.transfer_topo.split_k_and_v
for block_ids in block_ids_list: for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long) indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
...@@ -1661,7 +1654,7 @@ class NixlConnectorWorker: ...@@ -1661,7 +1654,7 @@ class NixlConnectorWorker:
The scheduler process (via the MultiprocExecutor) will use this output The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done. to track which workers are done.
""" """
assert self.kv_topo is not None assert self.transfer_topo is not None
done_sending = self._get_new_notifs() done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers) done_recving = self._pop_done_transfers(self._recving_transfers)
...@@ -1689,8 +1682,9 @@ class NixlConnectorWorker: ...@@ -1689,8 +1682,9 @@ class NixlConnectorWorker:
self.sync_recved_kv_to_device(req_id, meta) self.sync_recved_kv_to_device(req_id, meta)
# post processing for heteroblocksize # post processing for heteroblocksize
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( remote_info = self.transfer_topo.get_engine_info(meta.remote.engine_id)
meta.remote.engine_id block_size_ratio = self.transfer_topo.block_size_ratio(
remote_info.remote_block_size
) )
if not self.use_mla and ( if not self.use_mla and (
block_size_ratio > 1 or self.enable_permute_local_kv block_size_ratio > 1 or self.enable_permute_local_kv
...@@ -1741,7 +1735,7 @@ class NixlConnectorWorker: ...@@ -1741,7 +1735,7 @@ class NixlConnectorWorker:
are reading from the same producer (heterogeneous TP scenario), wait are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling. for all consumers to be done pulling.
""" """
assert self.kv_topo is not None assert self.transfer_topo is not None
notified_req_ids: set[str] = set() notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values(): for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs: for notif in notifs:
...@@ -1760,7 +1754,7 @@ class NixlConnectorWorker: ...@@ -1760,7 +1754,7 @@ class NixlConnectorWorker:
# NOTE: `tp_ratio` is the opposite when swapping local<>remote # NOTE: `tp_ratio` is the opposite when swapping local<>remote
n_consumers = int(tp_size) n_consumers = int(tp_size)
tp_ratio = self.kv_topo.tp_ratio(n_consumers) tp_ratio = self.transfer_topo.tp_ratio(n_consumers)
# Number of reads *per producer* to wait for. # Number of reads *per producer* to wait for.
# When remote D TP > local P TP we expect `tp_ratio` reads. # When remote D TP > local P TP we expect `tp_ratio` reads.
...@@ -1901,17 +1895,17 @@ class NixlConnectorWorker: ...@@ -1901,17 +1895,17 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None and self.kv_topo is not None assert meta.remote is not None and self.transfer_topo is not None
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( engine_id = meta.remote.engine_id
meta.remote.engine_id remote_ranks = self.transfer_topo.target_remote_ranks(engine_id)
) remote_info = self.transfer_topo.get_engine_info(engine_id)
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id) tp_ratio = self.transfer_topo.tp_ratio(remote_info.remote_tp_size)
if self._has_mamba: if self._has_mamba:
# Expand remote logical → kernel block IDs. # Expand remote logical → kernel block IDs.
meta.remote.block_ids = self._logical_to_remote_kernel_block_ids( meta.remote.block_ids = self._logical_to_remote_kernel_block_ids(
meta.remote.block_ids, meta.remote.block_ids,
self._mamba_phys_ratio[meta.remote.engine_id], self._physical_blocks_per_logical[meta.remote.engine_id],
) )
else: else:
meta.remote.block_ids = self._logical_to_kernel_block_ids( meta.remote.block_ids = self._logical_to_kernel_block_ids(
...@@ -1924,7 +1918,7 @@ class NixlConnectorWorker: ...@@ -1924,7 +1918,7 @@ class NixlConnectorWorker:
# the first remote rank (cache is duplicated). # the first remote rank (cache is duplicated).
break break
remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id] remote_block_size = remote_info.remote_block_size
logger.debug( logger.debug(
"Remote agent %s available, calling _read_blocks" "Remote agent %s available, calling _read_blocks"
" on remote rank %s with remote block size %s for req %s", " on remote rank %s with remote block size %s for req %s",
...@@ -1955,9 +1949,8 @@ class NixlConnectorWorker: ...@@ -1955,9 +1949,8 @@ class NixlConnectorWorker:
remote_ids: BlockIds = meta.remote.block_ids remote_ids: BlockIds = meta.remote.block_ids
if self._has_mamba: if self._has_mamba:
# Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets. # Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets.
transfer_cfg = self._transfer_configs.get(meta.remote.engine_id) local_ids, remote_ids = self.transfer_topo.filter_block_ids_for_rank(
assert transfer_cfg is not None engine_id,
local_ids, remote_ids = transfer_cfg.filter_block_ids_for_rank(
remote_rank, remote_rank,
local_ids, local_ids,
remote_ids, remote_ids,
...@@ -1999,8 +1992,11 @@ class NixlConnectorWorker: ...@@ -1999,8 +1992,11 @@ class NixlConnectorWorker:
Post a READ point-to-point xfer request from a single local worker to Post a READ point-to-point xfer request from a single local worker to
a single remote worker. a single remote worker.
""" """
assert self.kv_topo is not None assert self.transfer_topo is not None
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) remote_info = self.transfer_topo.get_engine_info(dst_engine_id)
block_size_ratio = self.transfer_topo.block_size_ratio(
remote_info.remote_block_size
)
if block_size_ratio > 1: if block_size_ratio > 1:
# TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups. # TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups.
assert not self._is_hma_required assert not self._is_hma_required
...@@ -2190,8 +2186,8 @@ class NixlConnectorWorker: ...@@ -2190,8 +2186,8 @@ class NixlConnectorWorker:
# This is like having two "low-level views" of the same storage. # This is like having two "low-level views" of the same storage.
# `num_fa_descs` offset must be computed per-engine since P and D can # `num_fa_descs` offset must be computed per-engine since P and D can
# have different num_blocks (and thus different FA descs counts). # have different num_blocks (and thus different FA descs counts).
ratio = self._mamba_phys_ratio[engine_id] physical_per_logical = self._physical_blocks_per_logical[engine_id]
logical_blocks = num_blocks // ratio logical_blocks = num_blocks // physical_per_logical
num_fa_descs = self.num_regions * num_blocks num_fa_descs = self.num_regions * num_blocks
# 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm). # 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm).
mamba_region_ids = np.arange(len(self.block_len_per_layer) * 4)[:, None] mamba_region_ids = np.arange(len(self.block_len_per_layer) * 4)[:, None]
...@@ -2234,21 +2230,22 @@ class NixlConnectorWorker: ...@@ -2234,21 +2230,22 @@ class NixlConnectorWorker:
] ]
def _logical_to_remote_kernel_block_ids( def _logical_to_remote_kernel_block_ids(
self, block_ids: BlockIds, remote_ratio: int self, block_ids: BlockIds, remote_physical_per_logical: int
) -> BlockIds: ) -> BlockIds:
"""Map logical block IDs to physical kernel block IDs on the remote. """Map logical block IDs to physical kernel block IDs on the remote.
Args: Args:
block_ids: per-group lists of logical block IDs. block_ids: per-group lists of logical block IDs.
remote_ratio: remote engine's physical blocks per logical block. remote_physical_per_logical: remote engine's physical blocks
per logical block.
Returns: Returns:
Same structure with FA groups expanded (each logical block L Same structure with FA groups expanded (each logical block L
becomes kernel blocks [L*remote_ratio .. L*remote_ratio + becomes kernel blocks [L*ratio .. L*ratio + local_ratio - 1]).
local_ratio - 1]). Mamba groups are passed through unchanged. Mamba groups are passed through unchanged.
""" """
local_ratio = self._physical_blocks_per_logical_kv_block local_ratio = self._physical_blocks_per_logical_kv_block
if remote_ratio == 1: if remote_physical_per_logical == 1:
return block_ids return block_ids
local_arange = np.arange(local_ratio).reshape(1, -1) local_arange = np.arange(local_ratio).reshape(1, -1)
group_specs = self.kv_cache_config.kv_cache_groups group_specs = self.kv_cache_config.kv_cache_groups
...@@ -2256,7 +2253,7 @@ class NixlConnectorWorker: ...@@ -2256,7 +2253,7 @@ class NixlConnectorWorker:
for i, group in enumerate(block_ids): for i, group in enumerate(block_ids):
if not isinstance(group_specs[i].kv_cache_spec, MambaSpec): if not isinstance(group_specs[i].kv_cache_spec, MambaSpec):
arr = np.array(group).reshape(-1, 1) arr = np.array(group).reshape(-1, 1)
expanded = (arr * remote_ratio + local_arange).flatten() expanded = (arr * remote_physical_per_logical + local_arange).flatten()
result.append(expanded.tolist()) result.append(expanded.tolist())
else: else:
# Mamba blocks are 1:1 logical-to-physical (no expansion). # Mamba blocks are 1:1 logical-to-physical (no expansion).
...@@ -2296,8 +2293,8 @@ class NixlConnectorWorker: ...@@ -2296,8 +2293,8 @@ class NixlConnectorWorker:
+-------------------+ +--------------------+ +-------------------+ +--------------------+
|1st_split-2nd_split| |1st_split-2nd_split | |1st_split-2nd_split| |1st_split-2nd_split |
""" """
assert self.kv_topo is not None assert self.transfer_topo is not None
if self.kv_topo.is_kv_layout_blocks_first: if self.transfer_topo.is_kv_layout_blocks_first:
# For indexing only half (either just the K or V part). # For indexing only half (either just the K or V part).
if mamba_view: if mamba_view:
# NOTE (NickLucche) Mamba Opt: this is already skipping the padding so # NOTE (NickLucche) Mamba Opt: this is already skipping the padding so
......
...@@ -151,7 +151,9 @@ def derive_mamba_conv_split( ...@@ -151,7 +151,9 @@ def derive_mamba_conv_split(
) )
def compute_mamba_phys_ratio(ssm_sizes: tuple[int, ...], block_len: int) -> int: def compute_physical_blocks_per_logical(
ssm_sizes: tuple[int, ...], block_len: int
) -> int:
"""Derive _physical_blocks_per_logical_kv_block from remote metadata. """Derive _physical_blocks_per_logical_kv_block from remote metadata.
The remote engine's ratio is not sent directly in the handshake, so we The remote engine's ratio is not sent directly in the handshake, so we
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment