Unverified Commit bfdc0a3a authored by zhanqiuhu's avatar zhanqiuhu Committed by GitHub
Browse files

[NIXL][Mamba][3/N] Heterogeneous TP: 3-read conv state transfer (#37635)

parent 93bada49
......@@ -19,9 +19,9 @@ dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
)
hybrid_ssm_configs=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
"ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
)
sw_attn_configs=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
......
......@@ -224,6 +224,8 @@ def test_get_block_descs_ids_hybrid_ssm():
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = 1
worker._mamba_phys_ratio = {engine_id: 1}
worker.block_len_per_layer = [100]
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker.num_descs = 2 * num_blocks
......@@ -234,9 +236,10 @@ def test_get_block_descs_ids_hybrid_ssm():
# FA group: stride=num_blocks=100, offset=0
# region0: [3, 5], region1: [103, 105]
# SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
# offset=num_descs=200
# region0: [201, 202], region1: [301, 302]
expected = [3, 5, 103, 105, 201, 202, 301, 302]
# offset=num_fa_descs=200, 4 regions per Mamba layer (x, B, C, ssm)
# region0: [201, 202], region1: [301, 302],
# region2: [401, 402], region3: [501, 502]
expected = [3, 5, 103, 105, 201, 202, 301, 302, 401, 402, 501, 502]
assert list(result) == expected, f"Expected {expected}, got {list(result)}"
......@@ -259,6 +262,8 @@ def test_get_block_descs_ids_kernel_block_mismatch():
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = ratio
worker._mamba_phys_ratio = {engine_id: ratio}
worker.block_len_per_layer = [100]
worker.num_descs = 2 * num_blocks # 800
fa_blocks = [3, 7] # kernel-level block IDs
......@@ -267,9 +272,11 @@ def test_get_block_descs_ids_kernel_block_mismatch():
# FA group: stride=num_blocks=400, offset=0
# region0: [3, 7], region1: [403, 407]
# SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800
# region0: [801, 802], region1: [901, 902]
expected = [3, 7, 403, 407, 801, 802, 901, 902]
# SSM group: stride=logical_blocks=400//4=100, offset=num_fa_descs=800,
# 4 regions per Mamba layer (x, B, C, ssm)
# region0: [801, 802], region1: [901, 902],
# region2: [1001, 1002], region3: [1101, 1102]
expected = [3, 7, 403, 407, 801, 802, 901, 902, 1001, 1002, 1101, 1102]
assert list(result) == expected, f"Expected {expected}, got {list(result)}"
......@@ -418,3 +425,29 @@ def test_has_mamba_init(
)
assert scheduler._has_mamba is expected_has_mamba
assert scheduler._is_hma_required is expected_is_hma
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"ssm_sizes,block_len,expected_ratio",
[
# Nemotron 30B TP=1: ceil((36864 + 2097152) / 8192) = 261
((36864, 2097152), 8192, 261),
# Nemotron 30B TP=2: ceil((18432 + 1048576) / 4096) = 261
((18432, 1048576), 4096, 261),
# Nemotron 30B TP=4: ceil((9216 + 524288) / 4096) = 131
((9216, 524288), 4096, 131),
],
)
def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio):
"""Verify that compute_mamba_phys_ratio is TP-dependent.
With dimension-sharded Mamba state, the ratio differs across TP sizes
(e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why
_mamba_phys_ratio must be stored per-engine.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
compute_mamba_phys_ratio,
)
assert compute_mamba_phys_ratio(ssm_sizes, block_len) == expected_ratio
......@@ -5,7 +5,7 @@ KV cache helper for store.
"""
from collections.abc import Iterator
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, cast
import torch
......@@ -516,6 +516,338 @@ class TpKVTopology:
return cache if self.split_k_and_v else [cache]
# ---- Mamba-HMA hetero-TP transfer config ----
#
# Key insight: with hetero-TP (P_TP > D_TP), FA KV cache may be
# replicated across P ranks (when P_TP > num_kv_heads), but Mamba
# conv/SSM state is almost always uniquely sharded per P rank. So the
# number of P ranks D must read from can differ between FA and Mamba,
# and they must be handled separately.
def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range:
"""Physical KV head range stored in a rank's KV cache tensor.
When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank.
When ``tp_size > num_heads``: 1 physical head per rank. Heads are
distributed **contiguously** (matching vLLM's GQA weight partitioning):
consecutive ranks share a head before moving to the next one.
"""
if tp_size <= num_heads:
assert num_heads % tp_size == 0
per_rank = num_heads // tp_size
return range(rank * per_rank, (rank + 1) * per_rank)
else:
h = rank * num_heads // tp_size
return range(h, h + 1)
def _range_overlap(a: range, b: range) -> range:
start = max(a.start, b.start)
stop = min(a.stop, b.stop)
return range(start, max(start, stop))
@dataclass
class HeteroTPTransferConfig:
"""Precomputed transfer plan for one (D rank, P engine) pair.
Currently only instantiated for Mamba-HMA (hybrid SSM+Attention) models
where FA and mamba require different splitting factors. Could be extended
to other model types that need non-uniform hetero-TP transfer sizing.
All descriptor sizes are computed here. The guarantee is:
local_entry_size == remote_entry_size (for NIXL)
Attributes that start with ``fa_`` concern FlashAttention KV cache.
Attributes that start with ``mamba_`` concern Mamba conv/SSM state.
"""
# ---- Input parameters (from handshake) ----
tp_ratio: int
K: int # total_num_kv_heads (before TP sharding)
d_tp: int # D engine's tensor_parallel_size
p_tp: int # P engine's tensor_parallel_size
d_rank: int # this D worker's TP rank
use_mla: bool
# Per-layer block lengths (bytes, K+V combined for blocks_first).
# Uniform across layers for current models.
d_block_len: int # D's block_len_per_layer (representative)
p_block_len: int # P's block_len_per_layer (from handshake)
is_blocks_first: bool # kv_topo.is_kv_layout_blocks_first
# ---- Derived: computed in __post_init__ ----
#
# Physical heads per rank (what the KV tensor actually stores)
d_physical_heads: int = field(init=False)
p_physical_heads: int = field(init=False)
# How many distinct P ranks D needs for FA data
physical_fa_num_reads: int = field(init=False)
# Which P ranks contribute unique FA heads (ordered by head index)
fa_read_targets: list[int] = field(init=False)
# All P ranks needed for mamba (always abs_tp for tp_ratio < 0)
mamba_num_reads: int = field(init=False)
# All P ranks this D rank communicates with (FA ∪ mamba)
transfer_targets: list[int] = field(init=False)
# FA descriptor entry size (K or V side, for blocks_first layout)
# Guaranteed: fa_entry_size is the SAME for local handle AND remote desc.
fa_entry_size: int = field(init=False)
# Replication flags
is_d_replicated: bool = field(init=False)
is_p_replicated: bool = field(init=False)
# Pre-built set for fast lookup
_fa_target_set: frozenset[int] = field(init=False, repr=False)
# Map: P rank → index in fa_read_targets (for head slot offset)
_fa_target_index: dict[int, int] = field(init=False, repr=False)
def __post_init__(self) -> None:
K = self.K
self.is_d_replicated = self.d_tp > K
self.is_p_replicated = self.p_tp > K
self.d_physical_heads = max(1, K // self.d_tp)
self.p_physical_heads = max(1, K // self.p_tp)
abs_tp = -self.tp_ratio if self.tp_ratio < 0 else 1
# ---- Mamba range (computed first so FA can prefer ranks in it) ----
mamba_range: range | None = None
if self.tp_ratio < 0:
mamba_range = range(self.d_rank * abs_tp, (self.d_rank + 1) * abs_tp)
# ---- FA read targets ----
if self.use_mla or self.tp_ratio >= 0:
self.physical_fa_num_reads = 1
self.fa_read_targets = (
[0]
if self.use_mla
# Must match kv_topo.get_target_remote_ranks (d_rank // tp_ratio).
else [
self.d_rank // self.tp_ratio if self.tp_ratio > 0 else self.d_rank
]
)
else:
d_needs = _physical_head_range(self.d_tp, K, self.d_rank)
# When mamba range exists, prefer P ranks within it so that
# FA targets are a subset of mamba transfer_targets (avoids
# orphaned FA targets outside the transfer loop).
search_range = mamba_range if mamba_range is not None else range(self.p_tp)
seen: set[tuple[int, int]] = set()
targets: list[int] = []
for p in search_range:
p_has = _physical_head_range(self.p_tp, K, p)
ov = _range_overlap(d_needs, p_has)
if len(ov) > 0:
key = (ov.start, ov.stop)
if key not in seen:
seen.add(key)
targets.append(p)
if not targets:
# Fallback: search globally (should not happen in practice)
for p in range(self.p_tp):
p_has = _physical_head_range(self.p_tp, K, p)
ov = _range_overlap(d_needs, p_has)
if len(ov) > 0:
key = (ov.start, ov.stop)
if key not in seen:
seen.add(key)
targets.append(p)
self.fa_read_targets = targets
self.physical_fa_num_reads = len(targets)
self._fa_target_set = frozenset(self.fa_read_targets)
self._fa_target_index = {r: i for i, r in enumerate(self.fa_read_targets)}
# ---- Mamba targets ----
if mamba_range is not None and abs_tp > self.physical_fa_num_reads:
self.mamba_num_reads = abs_tp
self.transfer_targets = list(mamba_range)
else:
self.mamba_num_reads = self.physical_fa_num_reads
self.transfer_targets = list(self.fa_read_targets)
# ---- FA entry size ----
# For blocks_first: block_len_per_layer includes K+V; // 2 gives K (or V).
# Use min(D, P) because D indexes into P when tp_ratio > 0,
# and P is the natural unit when tp_ratio < 0.
effective_block_len = min(self.d_block_len, self.p_block_len)
if self.is_blocks_first:
self.fa_entry_size = effective_block_len // 2
else:
self.fa_entry_size = effective_block_len
self._validate()
def _validate(self) -> None:
"""Cross-check internal consistency."""
if self.is_d_replicated and self.is_p_replicated and self.tp_ratio > 0:
logger.info(
"Both-replicated hetero-TP: D_TP=%d > P_TP=%d > K=%d. "
"Using d_rank // tp_ratio routing with relative head offset.",
self.d_tp,
self.p_tp,
self.K,
)
# FA targets must be a subset of transfer_targets
tt_set = set(self.transfer_targets)
for t in self.fa_read_targets:
if t not in tt_set:
logger.error(
"FA target P rank %d is NOT in transfer_targets %s. "
"This will cause missed FA reads!",
t,
self.transfer_targets,
)
# For tp_ratio < 0 with blocks_first: D_K_half / reads should == P_K_half
if (
self.is_blocks_first
and self.tp_ratio < 0
and self.physical_fa_num_reads > 0
):
d_k_half = self.d_block_len // 2
p_k_half = self.p_block_len // 2
expected_local = d_k_half // self.physical_fa_num_reads
if expected_local != p_k_half:
logger.warning(
"FA size mismatch: D_K_half=%d / reads=%d = %d, "
"but P_K_half=%d. This may indicate a head count or "
"Mamba-HMA inflation inconsistency.",
d_k_half,
self.physical_fa_num_reads,
expected_local,
p_k_half,
)
# ---- Query methods ----
def should_skip_fa(self, p_rank: int) -> bool:
"""Whether to skip FA groups for this P rank (mamba-only transfer)."""
return p_rank not in self._fa_target_set
def fa_head_slot(self, p_rank: int) -> int:
"""Index into D's FA block for this P rank's head data.
For P ranks in fa_read_targets, returns 0, 1, ..., reads-1.
For P ranks NOT in fa_read_targets (replicated duplicates),
returns the slot of the matching FA target with the same head.
"""
if p_rank in self._fa_target_index:
return self._fa_target_index[p_rank]
# Duplicate head: find which fa_target has the same physical head
p_head = _physical_head_range(self.p_tp, self.K, p_rank)
for target in self.fa_read_targets:
t_head = _physical_head_range(self.p_tp, self.K, target)
if _range_overlap(p_head, t_head):
return self._fa_target_index[target]
return 0 # fallback
def fa_rank_offset(self, remote_kv_block_len: int) -> int:
"""Byte offset into P's FA block for this D rank.
When D is replicated (D_TP > K), multiple D ranks share a head.
Computes offset *relative to the target P rank's first head*
so it works regardless of how many heads P has.
When neither side replicates, falls back to tp_rank % tp_ratio.
Returns 0 when D does not index into P's block.
"""
if self.use_mla or self.tp_ratio <= 0:
return 0
if self.is_d_replicated:
d_head = self.d_rank * self.K // self.d_tp
p_rank = self.fa_read_targets[0]
p_start = p_rank * self.K // self.p_tp
return (d_head - p_start) * remote_kv_block_len
return self.d_rank % self.tp_ratio * remote_kv_block_len
@property
def needs_split_handles(self) -> bool:
"""Whether per-P-rank split handles are needed.
True when FA and mamba have different read counts, requiring
different splitting factors in the local handle.
"""
return self.tp_ratio < 0 and not self.use_mla and len(self.transfer_targets) > 1
def compute_split_handle_data(
self,
src_blocks_data: list[tuple[int, int, int]],
num_fa_descs: int,
abs_tp: int,
) -> list[list[tuple[int, int, int]]]:
"""Compute per-P-rank (addr, len, tp) triples for Mamba-HMA split handles.
FA descriptors (indices < num_fa_descs) are sliced by
``physical_fa_num_reads``; mamba descriptors are sliced uniformly
by ``abs_tp``.
Returns one list of triples per transfer target.
"""
all_handle_data: list[list[tuple[int, int, int]]] = []
for p_idx, p_rank in enumerate(self.transfer_targets):
handle_data: list[tuple[int, int, int]] = []
skip_fa = self.should_skip_fa(p_rank)
fa_slot = self.fa_head_slot(p_rank) if not skip_fa else 0
for j, (addr, local_len, tp) in enumerate(src_blocks_data):
if j < num_fa_descs:
assert self.physical_fa_num_reads >= 1
fa_chunk = local_len // self.physical_fa_num_reads
handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, tp))
else:
mamba_chunk = local_len // abs_tp
handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, tp))
all_handle_data.append(handle_data)
return all_handle_data
def filter_block_ids_for_rank(
self,
remote_rank: int,
local_ids: BlockIds,
remote_ids: BlockIds,
is_mamba_group: list[bool],
) -> tuple[BlockIds, BlockIds]:
"""Zero out FA groups for P ranks outside fa_read_targets.
Returns (filtered_local_ids, filtered_remote_ids). When the
remote rank carries FA data for this D rank, returns the inputs
unchanged.
"""
if not self.should_skip_fa(remote_rank):
return local_ids, remote_ids
num_groups = len(local_ids)
filtered_local: list[list[int]] = [
[] if not is_mamba_group[g] else local_ids[g] for g in range(num_groups)
]
filtered_remote: list[list[int]] = [
[] if not is_mamba_group[g] else remote_ids[g] for g in range(num_groups)
]
return filtered_local, filtered_remote
def describe(self) -> str:
"""One-line summary for logging."""
return (
f"HeteroTPTransferConfig("
f"tp_ratio={self.tp_ratio}, K={self.K}, "
f"d_tp={self.d_tp}, p_tp={self.p_tp}, d_rank={self.d_rank}, "
f"physical_fa_reads={self.physical_fa_num_reads}, "
f"mamba_reads={self.mamba_num_reads}, "
f"fa_targets={self.fa_read_targets}, "
f"transfer_targets={self.transfer_targets}, "
f"fa_entry_size={self.fa_entry_size}, "
f"d_block_len={self.d_block_len}, p_block_len={self.p_block_len})"
)
def get_current_attn_backends(
vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> list[type[AttentionBackend]]:
......@@ -559,3 +891,50 @@ def get_current_attn_backend(
) -> type[AttentionBackend]:
"""Get the first attention backend for the given layers."""
return get_current_attn_backends(vllm_config, layer_names)[0]
# TODO (ZhanqiuHu): Consolidate TpKVTopology and HeteroTPTransferConfig
# into a single engine-agnostic TransferTopology class.
# 6 of 9 HeteroTPTransferConfig init fields duplicate TpKVTopology data.
#
# @dataclass
# class EngineTransferInfo:
# """Per-remote-engine transfer state, computed at handshake."""
# p_tp: int
# tp_ratio: int
# p_block_len: int
# block_size: int
# # Mamba-specific (None for non-mamba models)
# fa_read_targets: list[int] | None = None
# transfer_targets: list[int] | None = None
# physical_fa_num_reads: int | None = None
# mamba_num_reads: int | None = None
# fa_entry_size: int | None = None
#
# class TransferTopology:
# """Single source of truth for TP topology + transfer sizing."""
# # Shared (set once at init, replaces duplicate fields)
# tp_rank: int # == TpKVTopology.tp_rank == HeteroTP.d_rank
# tp_size: int # == TpKVTopology.tp_size == HeteroTP.d_tp
# total_num_kv_heads: int # == HeteroTP.K
# is_mla: bool # == HeteroTP.use_mla
# is_mamba: bool
# is_blocks_first: bool # == HeteroTP.is_blocks_first
# d_block_len: int
#
# # Per-engine (populated via register_engine() at handshake)
# _engines: dict[EngineId, EngineTransferInfo]
#
# def register_engine(self, engine_id, p_tp, p_block_len, ...): ...
#
# # General (from TpKVTopology)
# def tp_ratio(self, engine_id) -> int: ...
# def target_remote_ranks(self, engine_id) -> list[int]: ...
# def is_kv_replicated(self, engine_id) -> bool: ...
#
# # Mamba-specific (from HeteroTPTransferConfig, gated by is_mamba)
# def fa_rank_offset(self, engine_id, block_len) -> int: ...
# def physical_fa_num_reads(self, engine_id) -> int: ...
# def transfer_targets(self, engine_id) -> list[int]: ...
# def should_skip_fa(self, engine_id, p_rank) -> bool: ...
# def filter_block_ids_for_rank(self, engine_id, ...) -> ...: ...
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Mamba conv-state sub-projection decomposition for the 3-read transfer.
With DS conv state layout (dim, state_len), x/B/C sub-projections are
contiguous in memory. Each D rank reads its x, B, C slices via 3
separate RDMA transfers — no P-side permutation needed.
"""
import math
from dataclasses import dataclass
import torch
from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first
from vllm.v1.kv_cache_interface import MambaSpec
@dataclass(frozen=True)
class MambaConvSplitInfo:
"""Per-rank byte sizes of x, B, C sub-projections in the Mamba conv state.
Used by both P and D sides for NIXL descriptor registration.
All fields are LOCAL to this engine's TP (already divided by TP size).
DS memory layout within one page (contiguous in memory):
|--- x (x_local * conv_rows) ---|- B (b_local * conv_rows) -|- C -|
"""
conv_rows: int # conv_kernel - 1 (typically 3)
x_local: int # intermediate_size / TP (columns for x)
b_local: int # groups_ss / TP (columns for B; C is same size)
conv_dtype_size: int # bytes per element (e.g. 2 for float16)
@property
def conv_dim_local(self) -> int:
"""Total conv columns per rank: x + B + C."""
return self.x_local + 2 * self.b_local
@property
def x_bytes(self) -> int:
"""Byte size of the x sub-projection for one rank."""
return self.x_local * self.conv_rows * self.conv_dtype_size
@property
def b_bytes(self) -> int:
"""Byte size of the B (or C) sub-projection for one rank."""
return self.b_local * self.conv_rows * self.conv_dtype_size
@property
def local_conv_offsets(self) -> list[tuple[int, int]]:
"""(byte_offset, byte_size) of x, B, C within this engine's page.
Used by both P and D for local descriptor registration.
"""
xb = self.x_bytes
bb = self.b_bytes
return [(0, xb), (xb, bb), (xb + bb, bb)]
def remote_conv_offsets(
self, local_rank_offset: int, tp_ratio: int
) -> list[tuple[int, int]]:
"""(byte_offset, byte_size) of this D rank's x, B, C slice within
one P page.
Used by D side only, during remote descriptor registration.
Args:
local_rank_offset: which slice this D rank reads.
tp_ratio > 0: tp_rank % tp_ratio (selects slice of P's page).
tp_ratio < 0: always 0 (read P's full page).
tp_ratio: effective ratio (>= 1 when D_TP > P_TP, 1 when
P_TP > D_TP since each P rank is read in full).
"""
xb = self.x_bytes
bb = self.b_bytes
xr = xb * tp_ratio # full remote x section in bytes
br = bb * tp_ratio # full remote B section in bytes
return [
(local_rank_offset * xb, xb),
(xr + local_rank_offset * bb, bb),
(xr + br + local_rank_offset * bb, bb),
]
def derive_mamba_conv_split(
mamba_spec: MambaSpec,
local_tp: int,
) -> MambaConvSplitInfo:
"""Derive per-rank x/B/C byte sizes from a MambaSpec.
Called once at init on both P and D. Decomposes the conv dimension
(= intermediate_size + 2 * groups_ss) into its x, B, C parts.
Args:
mamba_spec: MambaSpec whose shapes are:
shapes[0] = conv state: (conv_dim_local, conv_rows) in DS layout.
shapes[1] = SSM temporal: (local_num_heads, head_dim).
local_tp: this engine's tensor-parallel size.
Returns:
MambaConvSplitInfo with per-rank x_local, b_local, conv_rows, and
conv_dtype_size.
"""
if mamba_spec.mamba_type != "mamba2":
raise NotImplementedError(
f"3-read conv transfer only supports Mamba2 models, "
f"got mamba_type={mamba_spec.mamba_type!r}. "
f"Mamba1 SSM temporal shape is (intermediate_size // tp, state_size) "
f"which cannot be used to reconstruct intermediate_size."
)
conv_shape = mamba_spec.shapes[0]
assert len(conv_shape) == 2, f"Expected 2D conv state shape, got {conv_shape}"
# NOTE (ZhanqiuHu): 3-read requires DS layout, which is already asserted
# in nixl_connector __init__. Use it directly instead of heuristic detection.
assert is_conv_state_dim_first(), "3-read requires DS conv state layout"
local_conv_dim = conv_shape[0] # DS: (conv_dim_local, conv_rows)
conv_rows = conv_shape[1]
# NOTE (ZhanqiuHu): intermediate_size (= global x dim) is not stored
# in MambaSpec, so we reconstruct it from the SSM temporal state shape:
# shapes[1] = (local_num_heads, head_dim), already divided by TP.
head_dim = mamba_spec.shapes[1][1]
local_num_heads = mamba_spec.shapes[1][0]
intermediate_size = local_num_heads * local_tp * head_dim
# NOTE (ZhanqiuHu): global conv dim = intermediate_size + 2 * groups_ss,
# where groups_ss is the B (= C) dimension. B and C are always the same
# size, so we recover groups_ss from the remainder after subtracting x.
remainder = local_conv_dim * local_tp - intermediate_size
assert remainder > 0 and remainder % 2 == 0, (
f"Conv dim ({local_conv_dim}*tp={local_tp}) doesn't decompose into "
f"intermediate_size={intermediate_size} + 2*groups_ss. "
f"remainder={remainder}"
)
groups_ss = remainder // 2
conv_dtype_size = torch.tensor(
[],
dtype=mamba_spec.dtypes[0], # type: ignore[misc]
).element_size()
# Divide by TP to get per-rank column counts.
return MambaConvSplitInfo(
conv_rows=conv_rows,
x_local=intermediate_size // local_tp,
b_local=groups_ss // local_tp,
conv_dtype_size=conv_dtype_size,
)
def compute_mamba_phys_ratio(ssm_sizes: tuple[int, ...], block_len: int) -> int:
"""Derive _physical_blocks_per_logical_kv_block from remote metadata.
The remote engine's ratio is not sent directly in the handshake, so we
reconstruct it: total mamba state per logical block / block_len.
Args:
ssm_sizes: (conv_state_bytes, ssm_state_bytes) from NixlAgentMetadata.
block_len: the engine's block_len in bytes (from block_lens[0]).
"""
return math.ceil((ssm_sizes[0] + ssm_sizes[1]) / block_len)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment