Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bfdc0a3a
Unverified
Commit
bfdc0a3a
authored
Apr 06, 2026
by
zhanqiuhu
Committed by
GitHub
Apr 06, 2026
Browse files
[NIXL][Mamba][3/N] Heterogeneous TP: 3-read conv state transfer (#37635)
parent
93bada49
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
970 additions
and
75 deletions
+970
-75
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
..._connector/nixl_integration/config_sweep_accuracy_test.sh
+2
-2
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
+39
-6
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+380
-1
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+385
-66
vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py
...ed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py
+164
-0
No files found.
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
View file @
bfdc0a3a
...
...
@@ -19,9 +19,9 @@ dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP2, D-DPEP=2 (TP=1)
)
hybrid_ssm_configs
=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
"
VLLM_SSM_CONV_STATE_LAYOUT=DS
ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
"ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
"
VLLM_SSM_CONV_STATE_LAYOUT=DS
ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
)
sw_attn_configs
=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
...
...
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
View file @
bfdc0a3a
...
...
@@ -224,6 +224,8 @@ def test_get_block_descs_ids_hybrid_ssm():
worker
.
_has_mamba
=
True
worker
.
_is_mamba_group
=
[
False
,
True
]
worker
.
_physical_blocks_per_logical_kv_block
=
1
worker
.
_mamba_phys_ratio
=
{
engine_id
:
1
}
worker
.
block_len_per_layer
=
[
100
]
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker
.
num_descs
=
2
*
num_blocks
...
...
@@ -234,9 +236,10 @@ def test_get_block_descs_ids_hybrid_ssm():
# FA group: stride=num_blocks=100, offset=0
# region0: [3, 5], region1: [103, 105]
# SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
# offset=num_descs=200
# region0: [201, 202], region1: [301, 302]
expected
=
[
3
,
5
,
103
,
105
,
201
,
202
,
301
,
302
]
# offset=num_fa_descs=200, 4 regions per Mamba layer (x, B, C, ssm)
# region0: [201, 202], region1: [301, 302],
# region2: [401, 402], region3: [501, 502]
expected
=
[
3
,
5
,
103
,
105
,
201
,
202
,
301
,
302
,
401
,
402
,
501
,
502
]
assert
list
(
result
)
==
expected
,
f
"Expected
{
expected
}
, got
{
list
(
result
)
}
"
...
...
@@ -259,6 +262,8 @@ def test_get_block_descs_ids_kernel_block_mismatch():
worker
.
_has_mamba
=
True
worker
.
_is_mamba_group
=
[
False
,
True
]
worker
.
_physical_blocks_per_logical_kv_block
=
ratio
worker
.
_mamba_phys_ratio
=
{
engine_id
:
ratio
}
worker
.
block_len_per_layer
=
[
100
]
worker
.
num_descs
=
2
*
num_blocks
# 800
fa_blocks
=
[
3
,
7
]
# kernel-level block IDs
...
...
@@ -267,9 +272,11 @@ def test_get_block_descs_ids_kernel_block_mismatch():
# FA group: stride=num_blocks=400, offset=0
# region0: [3, 7], region1: [403, 407]
# SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800
# region0: [801, 802], region1: [901, 902]
expected
=
[
3
,
7
,
403
,
407
,
801
,
802
,
901
,
902
]
# SSM group: stride=logical_blocks=400//4=100, offset=num_fa_descs=800,
# 4 regions per Mamba layer (x, B, C, ssm)
# region0: [801, 802], region1: [901, 902],
# region2: [1001, 1002], region3: [1101, 1102]
expected
=
[
3
,
7
,
403
,
407
,
801
,
802
,
901
,
902
,
1001
,
1002
,
1101
,
1102
]
assert
list
(
result
)
==
expected
,
f
"Expected
{
expected
}
, got
{
list
(
result
)
}
"
...
...
@@ -418,3 +425,29 @@ def test_has_mamba_init(
)
assert
scheduler
.
_has_mamba
is
expected_has_mamba
assert
scheduler
.
_is_hma_required
is
expected_is_hma
@
pytest
.
mark
.
cpu_test
@
pytest
.
mark
.
parametrize
(
"ssm_sizes,block_len,expected_ratio"
,
[
# Nemotron 30B TP=1: ceil((36864 + 2097152) / 8192) = 261
((
36864
,
2097152
),
8192
,
261
),
# Nemotron 30B TP=2: ceil((18432 + 1048576) / 4096) = 261
((
18432
,
1048576
),
4096
,
261
),
# Nemotron 30B TP=4: ceil((9216 + 524288) / 4096) = 131
((
9216
,
524288
),
4096
,
131
),
],
)
def
test_compute_mamba_phys_ratio
(
ssm_sizes
,
block_len
,
expected_ratio
):
"""Verify that compute_mamba_phys_ratio is TP-dependent.
With dimension-sharded Mamba state, the ratio differs across TP sizes
(e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why
_mamba_phys_ratio must be stored per-engine.
"""
from
vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils
import
(
compute_mamba_phys_ratio
,
)
assert
compute_mamba_phys_ratio
(
ssm_sizes
,
block_len
)
==
expected_ratio
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
bfdc0a3a
...
...
@@ -5,7 +5,7 @@ KV cache helper for store.
"""
from
collections.abc
import
Iterator
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
cast
import
torch
...
...
@@ -516,6 +516,338 @@ class TpKVTopology:
return
cache
if
self
.
split_k_and_v
else
[
cache
]
# ---- Mamba-HMA hetero-TP transfer config ----
#
# Key insight: with hetero-TP (P_TP > D_TP), FA KV cache may be
# replicated across P ranks (when P_TP > num_kv_heads), but Mamba
# conv/SSM state is almost always uniquely sharded per P rank. So the
# number of P ranks D must read from can differ between FA and Mamba,
# and they must be handled separately.
def
_physical_head_range
(
tp_size
:
int
,
num_heads
:
int
,
rank
:
int
)
->
range
:
"""Physical KV head range stored in a rank's KV cache tensor.
When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank.
When ``tp_size > num_heads``: 1 physical head per rank. Heads are
distributed **contiguously** (matching vLLM's GQA weight partitioning):
consecutive ranks share a head before moving to the next one.
"""
if
tp_size
<=
num_heads
:
assert
num_heads
%
tp_size
==
0
per_rank
=
num_heads
//
tp_size
return
range
(
rank
*
per_rank
,
(
rank
+
1
)
*
per_rank
)
else
:
h
=
rank
*
num_heads
//
tp_size
return
range
(
h
,
h
+
1
)
def
_range_overlap
(
a
:
range
,
b
:
range
)
->
range
:
start
=
max
(
a
.
start
,
b
.
start
)
stop
=
min
(
a
.
stop
,
b
.
stop
)
return
range
(
start
,
max
(
start
,
stop
))
@
dataclass
class
HeteroTPTransferConfig
:
"""Precomputed transfer plan for one (D rank, P engine) pair.
Currently only instantiated for Mamba-HMA (hybrid SSM+Attention) models
where FA and mamba require different splitting factors. Could be extended
to other model types that need non-uniform hetero-TP transfer sizing.
All descriptor sizes are computed here. The guarantee is:
local_entry_size == remote_entry_size (for NIXL)
Attributes that start with ``fa_`` concern FlashAttention KV cache.
Attributes that start with ``mamba_`` concern Mamba conv/SSM state.
"""
# ---- Input parameters (from handshake) ----
tp_ratio
:
int
K
:
int
# total_num_kv_heads (before TP sharding)
d_tp
:
int
# D engine's tensor_parallel_size
p_tp
:
int
# P engine's tensor_parallel_size
d_rank
:
int
# this D worker's TP rank
use_mla
:
bool
# Per-layer block lengths (bytes, K+V combined for blocks_first).
# Uniform across layers for current models.
d_block_len
:
int
# D's block_len_per_layer (representative)
p_block_len
:
int
# P's block_len_per_layer (from handshake)
is_blocks_first
:
bool
# kv_topo.is_kv_layout_blocks_first
# ---- Derived: computed in __post_init__ ----
#
# Physical heads per rank (what the KV tensor actually stores)
d_physical_heads
:
int
=
field
(
init
=
False
)
p_physical_heads
:
int
=
field
(
init
=
False
)
# How many distinct P ranks D needs for FA data
physical_fa_num_reads
:
int
=
field
(
init
=
False
)
# Which P ranks contribute unique FA heads (ordered by head index)
fa_read_targets
:
list
[
int
]
=
field
(
init
=
False
)
# All P ranks needed for mamba (always abs_tp for tp_ratio < 0)
mamba_num_reads
:
int
=
field
(
init
=
False
)
# All P ranks this D rank communicates with (FA ∪ mamba)
transfer_targets
:
list
[
int
]
=
field
(
init
=
False
)
# FA descriptor entry size (K or V side, for blocks_first layout)
# Guaranteed: fa_entry_size is the SAME for local handle AND remote desc.
fa_entry_size
:
int
=
field
(
init
=
False
)
# Replication flags
is_d_replicated
:
bool
=
field
(
init
=
False
)
is_p_replicated
:
bool
=
field
(
init
=
False
)
# Pre-built set for fast lookup
_fa_target_set
:
frozenset
[
int
]
=
field
(
init
=
False
,
repr
=
False
)
# Map: P rank → index in fa_read_targets (for head slot offset)
_fa_target_index
:
dict
[
int
,
int
]
=
field
(
init
=
False
,
repr
=
False
)
def
__post_init__
(
self
)
->
None
:
K
=
self
.
K
self
.
is_d_replicated
=
self
.
d_tp
>
K
self
.
is_p_replicated
=
self
.
p_tp
>
K
self
.
d_physical_heads
=
max
(
1
,
K
//
self
.
d_tp
)
self
.
p_physical_heads
=
max
(
1
,
K
//
self
.
p_tp
)
abs_tp
=
-
self
.
tp_ratio
if
self
.
tp_ratio
<
0
else
1
# ---- Mamba range (computed first so FA can prefer ranks in it) ----
mamba_range
:
range
|
None
=
None
if
self
.
tp_ratio
<
0
:
mamba_range
=
range
(
self
.
d_rank
*
abs_tp
,
(
self
.
d_rank
+
1
)
*
abs_tp
)
# ---- FA read targets ----
if
self
.
use_mla
or
self
.
tp_ratio
>=
0
:
self
.
physical_fa_num_reads
=
1
self
.
fa_read_targets
=
(
[
0
]
if
self
.
use_mla
# Must match kv_topo.get_target_remote_ranks (d_rank // tp_ratio).
else
[
self
.
d_rank
//
self
.
tp_ratio
if
self
.
tp_ratio
>
0
else
self
.
d_rank
]
)
else
:
d_needs
=
_physical_head_range
(
self
.
d_tp
,
K
,
self
.
d_rank
)
# When mamba range exists, prefer P ranks within it so that
# FA targets are a subset of mamba transfer_targets (avoids
# orphaned FA targets outside the transfer loop).
search_range
=
mamba_range
if
mamba_range
is
not
None
else
range
(
self
.
p_tp
)
seen
:
set
[
tuple
[
int
,
int
]]
=
set
()
targets
:
list
[
int
]
=
[]
for
p
in
search_range
:
p_has
=
_physical_head_range
(
self
.
p_tp
,
K
,
p
)
ov
=
_range_overlap
(
d_needs
,
p_has
)
if
len
(
ov
)
>
0
:
key
=
(
ov
.
start
,
ov
.
stop
)
if
key
not
in
seen
:
seen
.
add
(
key
)
targets
.
append
(
p
)
if
not
targets
:
# Fallback: search globally (should not happen in practice)
for
p
in
range
(
self
.
p_tp
):
p_has
=
_physical_head_range
(
self
.
p_tp
,
K
,
p
)
ov
=
_range_overlap
(
d_needs
,
p_has
)
if
len
(
ov
)
>
0
:
key
=
(
ov
.
start
,
ov
.
stop
)
if
key
not
in
seen
:
seen
.
add
(
key
)
targets
.
append
(
p
)
self
.
fa_read_targets
=
targets
self
.
physical_fa_num_reads
=
len
(
targets
)
self
.
_fa_target_set
=
frozenset
(
self
.
fa_read_targets
)
self
.
_fa_target_index
=
{
r
:
i
for
i
,
r
in
enumerate
(
self
.
fa_read_targets
)}
# ---- Mamba targets ----
if
mamba_range
is
not
None
and
abs_tp
>
self
.
physical_fa_num_reads
:
self
.
mamba_num_reads
=
abs_tp
self
.
transfer_targets
=
list
(
mamba_range
)
else
:
self
.
mamba_num_reads
=
self
.
physical_fa_num_reads
self
.
transfer_targets
=
list
(
self
.
fa_read_targets
)
# ---- FA entry size ----
# For blocks_first: block_len_per_layer includes K+V; // 2 gives K (or V).
# Use min(D, P) because D indexes into P when tp_ratio > 0,
# and P is the natural unit when tp_ratio < 0.
effective_block_len
=
min
(
self
.
d_block_len
,
self
.
p_block_len
)
if
self
.
is_blocks_first
:
self
.
fa_entry_size
=
effective_block_len
//
2
else
:
self
.
fa_entry_size
=
effective_block_len
self
.
_validate
()
def
_validate
(
self
)
->
None
:
"""Cross-check internal consistency."""
if
self
.
is_d_replicated
and
self
.
is_p_replicated
and
self
.
tp_ratio
>
0
:
logger
.
info
(
"Both-replicated hetero-TP: D_TP=%d > P_TP=%d > K=%d. "
"Using d_rank // tp_ratio routing with relative head offset."
,
self
.
d_tp
,
self
.
p_tp
,
self
.
K
,
)
# FA targets must be a subset of transfer_targets
tt_set
=
set
(
self
.
transfer_targets
)
for
t
in
self
.
fa_read_targets
:
if
t
not
in
tt_set
:
logger
.
error
(
"FA target P rank %d is NOT in transfer_targets %s. "
"This will cause missed FA reads!"
,
t
,
self
.
transfer_targets
,
)
# For tp_ratio < 0 with blocks_first: D_K_half / reads should == P_K_half
if
(
self
.
is_blocks_first
and
self
.
tp_ratio
<
0
and
self
.
physical_fa_num_reads
>
0
):
d_k_half
=
self
.
d_block_len
//
2
p_k_half
=
self
.
p_block_len
//
2
expected_local
=
d_k_half
//
self
.
physical_fa_num_reads
if
expected_local
!=
p_k_half
:
logger
.
warning
(
"FA size mismatch: D_K_half=%d / reads=%d = %d, "
"but P_K_half=%d. This may indicate a head count or "
"Mamba-HMA inflation inconsistency."
,
d_k_half
,
self
.
physical_fa_num_reads
,
expected_local
,
p_k_half
,
)
# ---- Query methods ----
def
should_skip_fa
(
self
,
p_rank
:
int
)
->
bool
:
"""Whether to skip FA groups for this P rank (mamba-only transfer)."""
return
p_rank
not
in
self
.
_fa_target_set
def
fa_head_slot
(
self
,
p_rank
:
int
)
->
int
:
"""Index into D's FA block for this P rank's head data.
For P ranks in fa_read_targets, returns 0, 1, ..., reads-1.
For P ranks NOT in fa_read_targets (replicated duplicates),
returns the slot of the matching FA target with the same head.
"""
if
p_rank
in
self
.
_fa_target_index
:
return
self
.
_fa_target_index
[
p_rank
]
# Duplicate head: find which fa_target has the same physical head
p_head
=
_physical_head_range
(
self
.
p_tp
,
self
.
K
,
p_rank
)
for
target
in
self
.
fa_read_targets
:
t_head
=
_physical_head_range
(
self
.
p_tp
,
self
.
K
,
target
)
if
_range_overlap
(
p_head
,
t_head
):
return
self
.
_fa_target_index
[
target
]
return
0
# fallback
def
fa_rank_offset
(
self
,
remote_kv_block_len
:
int
)
->
int
:
"""Byte offset into P's FA block for this D rank.
When D is replicated (D_TP > K), multiple D ranks share a head.
Computes offset *relative to the target P rank's first head*
so it works regardless of how many heads P has.
When neither side replicates, falls back to tp_rank % tp_ratio.
Returns 0 when D does not index into P's block.
"""
if
self
.
use_mla
or
self
.
tp_ratio
<=
0
:
return
0
if
self
.
is_d_replicated
:
d_head
=
self
.
d_rank
*
self
.
K
//
self
.
d_tp
p_rank
=
self
.
fa_read_targets
[
0
]
p_start
=
p_rank
*
self
.
K
//
self
.
p_tp
return
(
d_head
-
p_start
)
*
remote_kv_block_len
return
self
.
d_rank
%
self
.
tp_ratio
*
remote_kv_block_len
@
property
def
needs_split_handles
(
self
)
->
bool
:
"""Whether per-P-rank split handles are needed.
True when FA and mamba have different read counts, requiring
different splitting factors in the local handle.
"""
return
self
.
tp_ratio
<
0
and
not
self
.
use_mla
and
len
(
self
.
transfer_targets
)
>
1
def
compute_split_handle_data
(
self
,
src_blocks_data
:
list
[
tuple
[
int
,
int
,
int
]],
num_fa_descs
:
int
,
abs_tp
:
int
,
)
->
list
[
list
[
tuple
[
int
,
int
,
int
]]]:
"""Compute per-P-rank (addr, len, tp) triples for Mamba-HMA split handles.
FA descriptors (indices < num_fa_descs) are sliced by
``physical_fa_num_reads``; mamba descriptors are sliced uniformly
by ``abs_tp``.
Returns one list of triples per transfer target.
"""
all_handle_data
:
list
[
list
[
tuple
[
int
,
int
,
int
]]]
=
[]
for
p_idx
,
p_rank
in
enumerate
(
self
.
transfer_targets
):
handle_data
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
skip_fa
=
self
.
should_skip_fa
(
p_rank
)
fa_slot
=
self
.
fa_head_slot
(
p_rank
)
if
not
skip_fa
else
0
for
j
,
(
addr
,
local_len
,
tp
)
in
enumerate
(
src_blocks_data
):
if
j
<
num_fa_descs
:
assert
self
.
physical_fa_num_reads
>=
1
fa_chunk
=
local_len
//
self
.
physical_fa_num_reads
handle_data
.
append
((
addr
+
fa_slot
*
fa_chunk
,
fa_chunk
,
tp
))
else
:
mamba_chunk
=
local_len
//
abs_tp
handle_data
.
append
((
addr
+
p_idx
*
mamba_chunk
,
mamba_chunk
,
tp
))
all_handle_data
.
append
(
handle_data
)
return
all_handle_data
def
filter_block_ids_for_rank
(
self
,
remote_rank
:
int
,
local_ids
:
BlockIds
,
remote_ids
:
BlockIds
,
is_mamba_group
:
list
[
bool
],
)
->
tuple
[
BlockIds
,
BlockIds
]:
"""Zero out FA groups for P ranks outside fa_read_targets.
Returns (filtered_local_ids, filtered_remote_ids). When the
remote rank carries FA data for this D rank, returns the inputs
unchanged.
"""
if
not
self
.
should_skip_fa
(
remote_rank
):
return
local_ids
,
remote_ids
num_groups
=
len
(
local_ids
)
filtered_local
:
list
[
list
[
int
]]
=
[
[]
if
not
is_mamba_group
[
g
]
else
local_ids
[
g
]
for
g
in
range
(
num_groups
)
]
filtered_remote
:
list
[
list
[
int
]]
=
[
[]
if
not
is_mamba_group
[
g
]
else
remote_ids
[
g
]
for
g
in
range
(
num_groups
)
]
return
filtered_local
,
filtered_remote
def
describe
(
self
)
->
str
:
"""One-line summary for logging."""
return
(
f
"HeteroTPTransferConfig("
f
"tp_ratio=
{
self
.
tp_ratio
}
, K=
{
self
.
K
}
, "
f
"d_tp=
{
self
.
d_tp
}
, p_tp=
{
self
.
p_tp
}
, d_rank=
{
self
.
d_rank
}
, "
f
"physical_fa_reads=
{
self
.
physical_fa_num_reads
}
, "
f
"mamba_reads=
{
self
.
mamba_num_reads
}
, "
f
"fa_targets=
{
self
.
fa_read_targets
}
, "
f
"transfer_targets=
{
self
.
transfer_targets
}
, "
f
"fa_entry_size=
{
self
.
fa_entry_size
}
, "
f
"d_block_len=
{
self
.
d_block_len
}
, p_block_len=
{
self
.
p_block_len
}
)"
)
def
get_current_attn_backends
(
vllm_config
:
VllmConfig
,
layer_names
:
list
[
str
]
|
None
=
None
)
->
list
[
type
[
AttentionBackend
]]:
...
...
@@ -559,3 +891,50 @@ def get_current_attn_backend(
)
->
type
[
AttentionBackend
]:
"""Get the first attention backend for the given layers."""
return
get_current_attn_backends
(
vllm_config
,
layer_names
)[
0
]
# TODO (ZhanqiuHu): Consolidate TpKVTopology and HeteroTPTransferConfig
# into a single engine-agnostic TransferTopology class.
# 6 of 9 HeteroTPTransferConfig init fields duplicate TpKVTopology data.
#
# @dataclass
# class EngineTransferInfo:
# """Per-remote-engine transfer state, computed at handshake."""
# p_tp: int
# tp_ratio: int
# p_block_len: int
# block_size: int
# # Mamba-specific (None for non-mamba models)
# fa_read_targets: list[int] | None = None
# transfer_targets: list[int] | None = None
# physical_fa_num_reads: int | None = None
# mamba_num_reads: int | None = None
# fa_entry_size: int | None = None
#
# class TransferTopology:
# """Single source of truth for TP topology + transfer sizing."""
# # Shared (set once at init, replaces duplicate fields)
# tp_rank: int # == TpKVTopology.tp_rank == HeteroTP.d_rank
# tp_size: int # == TpKVTopology.tp_size == HeteroTP.d_tp
# total_num_kv_heads: int # == HeteroTP.K
# is_mla: bool # == HeteroTP.use_mla
# is_mamba: bool
# is_blocks_first: bool # == HeteroTP.is_blocks_first
# d_block_len: int
#
# # Per-engine (populated via register_engine() at handshake)
# _engines: dict[EngineId, EngineTransferInfo]
#
# def register_engine(self, engine_id, p_tp, p_block_len, ...): ...
#
# # General (from TpKVTopology)
# def tp_ratio(self, engine_id) -> int: ...
# def target_remote_ranks(self, engine_id) -> list[int]: ...
# def is_kv_replicated(self, engine_id) -> bool: ...
#
# # Mamba-specific (from HeteroTPTransferConfig, gated by is_mamba)
# def fa_rank_offset(self, engine_id, block_len) -> int: ...
# def physical_fa_num_reads(self, engine_id) -> int: ...
# def transfer_targets(self, engine_id) -> list[int]: ...
# def should_skip_fa(self, engine_id, p_rank) -> bool: ...
# def filter_block_ids_for_rank(self, engine_id, ...) -> ...: ...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
bfdc0a3a
This diff is collapsed.
Click to expand it.
vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py
0 → 100644
View file @
bfdc0a3a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Mamba conv-state sub-projection decomposition for the 3-read transfer.
With DS conv state layout (dim, state_len), x/B/C sub-projections are
contiguous in memory. Each D rank reads its x, B, C slices via 3
separate RDMA transfers — no P-side permutation needed.
"""
import
math
from
dataclasses
import
dataclass
import
torch
from
vllm.model_executor.layers.mamba.mamba_utils
import
is_conv_state_dim_first
from
vllm.v1.kv_cache_interface
import
MambaSpec
@
dataclass
(
frozen
=
True
)
class
MambaConvSplitInfo
:
"""Per-rank byte sizes of x, B, C sub-projections in the Mamba conv state.
Used by both P and D sides for NIXL descriptor registration.
All fields are LOCAL to this engine's TP (already divided by TP size).
DS memory layout within one page (contiguous in memory):
|--- x (x_local * conv_rows) ---|- B (b_local * conv_rows) -|- C -|
"""
conv_rows
:
int
# conv_kernel - 1 (typically 3)
x_local
:
int
# intermediate_size / TP (columns for x)
b_local
:
int
# groups_ss / TP (columns for B; C is same size)
conv_dtype_size
:
int
# bytes per element (e.g. 2 for float16)
@
property
def
conv_dim_local
(
self
)
->
int
:
"""Total conv columns per rank: x + B + C."""
return
self
.
x_local
+
2
*
self
.
b_local
@
property
def
x_bytes
(
self
)
->
int
:
"""Byte size of the x sub-projection for one rank."""
return
self
.
x_local
*
self
.
conv_rows
*
self
.
conv_dtype_size
@
property
def
b_bytes
(
self
)
->
int
:
"""Byte size of the B (or C) sub-projection for one rank."""
return
self
.
b_local
*
self
.
conv_rows
*
self
.
conv_dtype_size
@
property
def
local_conv_offsets
(
self
)
->
list
[
tuple
[
int
,
int
]]:
"""(byte_offset, byte_size) of x, B, C within this engine's page.
Used by both P and D for local descriptor registration.
"""
xb
=
self
.
x_bytes
bb
=
self
.
b_bytes
return
[(
0
,
xb
),
(
xb
,
bb
),
(
xb
+
bb
,
bb
)]
def
remote_conv_offsets
(
self
,
local_rank_offset
:
int
,
tp_ratio
:
int
)
->
list
[
tuple
[
int
,
int
]]:
"""(byte_offset, byte_size) of this D rank's x, B, C slice within
one P page.
Used by D side only, during remote descriptor registration.
Args:
local_rank_offset: which slice this D rank reads.
tp_ratio > 0: tp_rank % tp_ratio (selects slice of P's page).
tp_ratio < 0: always 0 (read P's full page).
tp_ratio: effective ratio (>= 1 when D_TP > P_TP, 1 when
P_TP > D_TP since each P rank is read in full).
"""
xb
=
self
.
x_bytes
bb
=
self
.
b_bytes
xr
=
xb
*
tp_ratio
# full remote x section in bytes
br
=
bb
*
tp_ratio
# full remote B section in bytes
return
[
(
local_rank_offset
*
xb
,
xb
),
(
xr
+
local_rank_offset
*
bb
,
bb
),
(
xr
+
br
+
local_rank_offset
*
bb
,
bb
),
]
def
derive_mamba_conv_split
(
mamba_spec
:
MambaSpec
,
local_tp
:
int
,
)
->
MambaConvSplitInfo
:
"""Derive per-rank x/B/C byte sizes from a MambaSpec.
Called once at init on both P and D. Decomposes the conv dimension
(= intermediate_size + 2 * groups_ss) into its x, B, C parts.
Args:
mamba_spec: MambaSpec whose shapes are:
shapes[0] = conv state: (conv_dim_local, conv_rows) in DS layout.
shapes[1] = SSM temporal: (local_num_heads, head_dim).
local_tp: this engine's tensor-parallel size.
Returns:
MambaConvSplitInfo with per-rank x_local, b_local, conv_rows, and
conv_dtype_size.
"""
if
mamba_spec
.
mamba_type
!=
"mamba2"
:
raise
NotImplementedError
(
f
"3-read conv transfer only supports Mamba2 models, "
f
"got mamba_type=
{
mamba_spec
.
mamba_type
!
r
}
. "
f
"Mamba1 SSM temporal shape is (intermediate_size // tp, state_size) "
f
"which cannot be used to reconstruct intermediate_size."
)
conv_shape
=
mamba_spec
.
shapes
[
0
]
assert
len
(
conv_shape
)
==
2
,
f
"Expected 2D conv state shape, got
{
conv_shape
}
"
# NOTE (ZhanqiuHu): 3-read requires DS layout, which is already asserted
# in nixl_connector __init__. Use it directly instead of heuristic detection.
assert
is_conv_state_dim_first
(),
"3-read requires DS conv state layout"
local_conv_dim
=
conv_shape
[
0
]
# DS: (conv_dim_local, conv_rows)
conv_rows
=
conv_shape
[
1
]
# NOTE (ZhanqiuHu): intermediate_size (= global x dim) is not stored
# in MambaSpec, so we reconstruct it from the SSM temporal state shape:
# shapes[1] = (local_num_heads, head_dim), already divided by TP.
head_dim
=
mamba_spec
.
shapes
[
1
][
1
]
local_num_heads
=
mamba_spec
.
shapes
[
1
][
0
]
intermediate_size
=
local_num_heads
*
local_tp
*
head_dim
# NOTE (ZhanqiuHu): global conv dim = intermediate_size + 2 * groups_ss,
# where groups_ss is the B (= C) dimension. B and C are always the same
# size, so we recover groups_ss from the remainder after subtracting x.
remainder
=
local_conv_dim
*
local_tp
-
intermediate_size
assert
remainder
>
0
and
remainder
%
2
==
0
,
(
f
"Conv dim (
{
local_conv_dim
}
*tp=
{
local_tp
}
) doesn't decompose into "
f
"intermediate_size=
{
intermediate_size
}
+ 2*groups_ss. "
f
"remainder=
{
remainder
}
"
)
groups_ss
=
remainder
//
2
conv_dtype_size
=
torch
.
tensor
(
[],
dtype
=
mamba_spec
.
dtypes
[
0
],
# type: ignore[misc]
).
element_size
()
# Divide by TP to get per-rank column counts.
return
MambaConvSplitInfo
(
conv_rows
=
conv_rows
,
x_local
=
intermediate_size
//
local_tp
,
b_local
=
groups_ss
//
local_tp
,
conv_dtype_size
=
conv_dtype_size
,
)
def
compute_mamba_phys_ratio
(
ssm_sizes
:
tuple
[
int
,
...],
block_len
:
int
)
->
int
:
"""Derive _physical_blocks_per_logical_kv_block from remote metadata.
The remote engine's ratio is not sent directly in the handshake, so we
reconstruct it: total mamba state per logical block / block_len.
Args:
ssm_sizes: (conv_state_bytes, ssm_state_bytes) from NixlAgentMetadata.
block_len: the engine's block_len in bytes (from block_lens[0]).
"""
return
math
.
ceil
((
ssm_sizes
[
0
]
+
ssm_sizes
[
1
])
/
block_len
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment