Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bfdc0a3a
Unverified
Commit
bfdc0a3a
authored
Apr 06, 2026
by
zhanqiuhu
Committed by
GitHub
Apr 06, 2026
Browse files
[NIXL][Mamba][3/N] Heterogeneous TP: 3-read conv state transfer (#37635)
parent
93bada49
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
970 additions
and
75 deletions
+970
-75
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
..._connector/nixl_integration/config_sweep_accuracy_test.sh
+2
-2
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
+39
-6
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+380
-1
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+385
-66
vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py
...ed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py
+164
-0
No files found.
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
View file @
bfdc0a3a
...
@@ -19,9 +19,9 @@ dp_ep_configs=(
...
@@ -19,9 +19,9 @@ dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP2, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP2, D-DPEP=2 (TP=1)
)
)
hybrid_ssm_configs
=(
hybrid_ssm_configs
=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
"
VLLM_SSM_CONV_STATE_LAYOUT=DS
ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
"ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
"
VLLM_SSM_CONV_STATE_LAYOUT=DS
ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
)
)
sw_attn_configs
=(
sw_attn_configs
=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
...
...
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
View file @
bfdc0a3a
...
@@ -224,6 +224,8 @@ def test_get_block_descs_ids_hybrid_ssm():
...
@@ -224,6 +224,8 @@ def test_get_block_descs_ids_hybrid_ssm():
worker
.
_has_mamba
=
True
worker
.
_has_mamba
=
True
worker
.
_is_mamba_group
=
[
False
,
True
]
worker
.
_is_mamba_group
=
[
False
,
True
]
worker
.
_physical_blocks_per_logical_kv_block
=
1
worker
.
_physical_blocks_per_logical_kv_block
=
1
worker
.
_mamba_phys_ratio
=
{
engine_id
:
1
}
worker
.
block_len_per_layer
=
[
100
]
# num_descs = num_regions * num_blocks (no blocks_first doubling)
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker
.
num_descs
=
2
*
num_blocks
worker
.
num_descs
=
2
*
num_blocks
...
@@ -234,9 +236,10 @@ def test_get_block_descs_ids_hybrid_ssm():
...
@@ -234,9 +236,10 @@ def test_get_block_descs_ids_hybrid_ssm():
# FA group: stride=num_blocks=100, offset=0
# FA group: stride=num_blocks=100, offset=0
# region0: [3, 5], region1: [103, 105]
# region0: [3, 5], region1: [103, 105]
# SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
# SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
# offset=num_descs=200
# offset=num_fa_descs=200, 4 regions per Mamba layer (x, B, C, ssm)
# region0: [201, 202], region1: [301, 302]
# region0: [201, 202], region1: [301, 302],
expected
=
[
3
,
5
,
103
,
105
,
201
,
202
,
301
,
302
]
# region2: [401, 402], region3: [501, 502]
expected
=
[
3
,
5
,
103
,
105
,
201
,
202
,
301
,
302
,
401
,
402
,
501
,
502
]
assert
list
(
result
)
==
expected
,
f
"Expected
{
expected
}
, got
{
list
(
result
)
}
"
assert
list
(
result
)
==
expected
,
f
"Expected
{
expected
}
, got
{
list
(
result
)
}
"
...
@@ -259,6 +262,8 @@ def test_get_block_descs_ids_kernel_block_mismatch():
...
@@ -259,6 +262,8 @@ def test_get_block_descs_ids_kernel_block_mismatch():
worker
.
_has_mamba
=
True
worker
.
_has_mamba
=
True
worker
.
_is_mamba_group
=
[
False
,
True
]
worker
.
_is_mamba_group
=
[
False
,
True
]
worker
.
_physical_blocks_per_logical_kv_block
=
ratio
worker
.
_physical_blocks_per_logical_kv_block
=
ratio
worker
.
_mamba_phys_ratio
=
{
engine_id
:
ratio
}
worker
.
block_len_per_layer
=
[
100
]
worker
.
num_descs
=
2
*
num_blocks
# 800
worker
.
num_descs
=
2
*
num_blocks
# 800
fa_blocks
=
[
3
,
7
]
# kernel-level block IDs
fa_blocks
=
[
3
,
7
]
# kernel-level block IDs
...
@@ -267,9 +272,11 @@ def test_get_block_descs_ids_kernel_block_mismatch():
...
@@ -267,9 +272,11 @@ def test_get_block_descs_ids_kernel_block_mismatch():
# FA group: stride=num_blocks=400, offset=0
# FA group: stride=num_blocks=400, offset=0
# region0: [3, 7], region1: [403, 407]
# region0: [3, 7], region1: [403, 407]
# SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800
# SSM group: stride=logical_blocks=400//4=100, offset=num_fa_descs=800,
# region0: [801, 802], region1: [901, 902]
# 4 regions per Mamba layer (x, B, C, ssm)
expected
=
[
3
,
7
,
403
,
407
,
801
,
802
,
901
,
902
]
# region0: [801, 802], region1: [901, 902],
# region2: [1001, 1002], region3: [1101, 1102]
expected
=
[
3
,
7
,
403
,
407
,
801
,
802
,
901
,
902
,
1001
,
1002
,
1101
,
1102
]
assert
list
(
result
)
==
expected
,
f
"Expected
{
expected
}
, got
{
list
(
result
)
}
"
assert
list
(
result
)
==
expected
,
f
"Expected
{
expected
}
, got
{
list
(
result
)
}
"
...
@@ -418,3 +425,29 @@ def test_has_mamba_init(
...
@@ -418,3 +425,29 @@ def test_has_mamba_init(
)
)
assert
scheduler
.
_has_mamba
is
expected_has_mamba
assert
scheduler
.
_has_mamba
is
expected_has_mamba
assert
scheduler
.
_is_hma_required
is
expected_is_hma
assert
scheduler
.
_is_hma_required
is
expected_is_hma
@pytest.mark.cpu_test
@pytest.mark.parametrize(
    "ssm_sizes,block_len,expected_ratio",
    [
        # Nemotron 30B TP=1: ceil((36864 + 2097152) / 8192) = 261
        ((36864, 2097152), 8192, 261),
        # Nemotron 30B TP=2: ceil((18432 + 1048576) / 4096) = 261
        ((18432, 1048576), 4096, 261),
        # Nemotron 30B TP=4: ceil((9216 + 524288) / 4096) = 131
        ((9216, 524288), 4096, 131),
    ],
)
def test_compute_mamba_phys_ratio(ssm_sizes, block_len, expected_ratio):
    """Verify that compute_mamba_phys_ratio is TP-dependent.

    With dimension-sharded Mamba state, the ratio differs across TP sizes
    (e.g. TP=1 → 261, TP=4 → 131 for Nemotron 30B). This is why
    _mamba_phys_ratio must be stored per-engine.
    """
    # Imported lazily inside the test, as in the original block.
    from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
        compute_mamba_phys_ratio,
    )

    assert compute_mamba_phys_ratio(ssm_sizes, block_len) == expected_ratio
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
bfdc0a3a
...
@@ -5,7 +5,7 @@ KV cache helper for store.
...
@@ -5,7 +5,7 @@ KV cache helper for store.
"""
"""
from
collections.abc
import
Iterator
from
collections.abc
import
Iterator
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
cast
import
torch
import
torch
...
@@ -516,6 +516,338 @@ class TpKVTopology:
...
@@ -516,6 +516,338 @@ class TpKVTopology:
return
cache
if
self
.
split_k_and_v
else
[
cache
]
return
cache
if
self
.
split_k_and_v
else
[
cache
]
# ---- Mamba-HMA hetero-TP transfer config ----
#
# Key insight: with hetero-TP (P_TP > D_TP), FA KV cache may be
# replicated across P ranks (when P_TP > num_kv_heads), but Mamba
# conv/SSM state is almost always uniquely sharded per P rank. So the
# number of P ranks D must read from can differ between FA and Mamba,
# and they must be handled separately.
def _physical_head_range(tp_size: int, num_heads: int, rank: int) -> range:
    """Physical KV head range stored in a rank's KV cache tensor.

    When ``tp_size <= num_heads``: sharded, K/TP contiguous heads per rank.
    When ``tp_size > num_heads``: 1 physical head per rank. Heads are
    distributed **contiguously** (matching vLLM's GQA weight partitioning):
    consecutive ranks share a head before moving to the next one.

    Args:
        tp_size: Tensor-parallel world size of the engine.
        num_heads: Total number of KV heads before TP sharding.
        rank: TP rank whose head range is requested.

    Returns:
        A ``range`` of head indices held by ``rank``.
    """
    if tp_size <= num_heads:
        # Sharded case: heads must divide evenly across ranks.
        assert num_heads % tp_size == 0, (
            f"num_heads={num_heads} is not divisible by tp_size={tp_size}"
        )
        per_rank = num_heads // tp_size
        return range(rank * per_rank, (rank + 1) * per_rank)

    # Replicated case: map the rank onto the single head it holds.
    head = rank * num_heads // tp_size
    return range(head, head + 1)
def _range_overlap(a: range, b: range) -> range:
    """Return the intersection of two ranges as a ``range``.

    Only ``start`` and ``stop`` of the inputs are consulted (``step`` is
    ignored). When the inputs do not intersect, an empty range anchored at
    the larger of the two starts is returned, so the result is falsy and
    has length 0.
    """
    start = max(a.start, b.start)
    stop = min(a.stop, b.stop)
    # Clamp stop so a disjoint pair yields an empty (never negative) range.
    return range(start, max(start, stop))
@dataclass
class HeteroTPTransferConfig:
    """Precomputed transfer plan for one (D rank, P engine) pair.

    Currently only instantiated for Mamba-HMA (hybrid SSM+Attention) models
    where FA and mamba require different splitting factors. Could be extended
    to other model types that need non-uniform hetero-TP transfer sizing.

    All descriptor sizes are computed here. The guarantee is:
        local_entry_size == remote_entry_size  (for NIXL)

    Attributes that start with ``fa_`` concern FlashAttention KV cache.
    Attributes that start with ``mamba_`` concern Mamba conv/SSM state.
    """

    # ---- Input parameters (from handshake) ----
    tp_ratio: int
    K: int  # total_num_kv_heads (before TP sharding)
    d_tp: int  # D engine's tensor_parallel_size
    p_tp: int  # P engine's tensor_parallel_size
    d_rank: int  # this D worker's TP rank
    use_mla: bool

    # Per-layer block lengths (bytes, K+V combined for blocks_first).
    # Uniform across layers for current models.
    d_block_len: int  # D's block_len_per_layer (representative)
    p_block_len: int  # P's block_len_per_layer (from handshake)
    is_blocks_first: bool  # kv_topo.is_kv_layout_blocks_first

    # ---- Derived: computed in __post_init__ ----
    #
    # Physical heads per rank (what the KV tensor actually stores)
    d_physical_heads: int = field(init=False)
    p_physical_heads: int = field(init=False)

    # How many distinct P ranks D needs for FA data
    physical_fa_num_reads: int = field(init=False)
    # Which P ranks contribute unique FA heads (ordered by head index)
    fa_read_targets: list[int] = field(init=False)

    # All P ranks needed for mamba (always abs_tp for tp_ratio < 0)
    mamba_num_reads: int = field(init=False)
    # All P ranks this D rank communicates with (FA ∪ mamba)
    transfer_targets: list[int] = field(init=False)

    # FA descriptor entry size (K or V side, for blocks_first layout)
    # Guaranteed: fa_entry_size is the SAME for local handle AND remote desc.
    fa_entry_size: int = field(init=False)

    # Replication flags
    is_d_replicated: bool = field(init=False)
    is_p_replicated: bool = field(init=False)

    # Pre-built set for fast lookup
    _fa_target_set: frozenset[int] = field(init=False, repr=False)
    # Map: P rank → index in fa_read_targets (for head slot offset)
    _fa_target_index: dict[int, int] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Derive read targets, read counts and FA entry size from the inputs."""
        K = self.K
        self.is_d_replicated = self.d_tp > K
        self.is_p_replicated = self.p_tp > K
        self.d_physical_heads = max(1, K // self.d_tp)
        self.p_physical_heads = max(1, K // self.p_tp)

        abs_tp = -self.tp_ratio if self.tp_ratio < 0 else 1

        # ---- Mamba range (computed first so FA can prefer ranks in it) ----
        mamba_range: range | None = None
        if self.tp_ratio < 0:
            mamba_range = range(self.d_rank * abs_tp, (self.d_rank + 1) * abs_tp)

        # ---- FA read targets ----
        if self.use_mla or self.tp_ratio >= 0:
            self.physical_fa_num_reads = 1
            self.fa_read_targets = (
                [0]
                if self.use_mla
                # Must match kv_topo.get_target_remote_ranks (d_rank // tp_ratio).
                else [
                    self.d_rank // self.tp_ratio
                    if self.tp_ratio > 0
                    else self.d_rank
                ]
            )
        else:
            d_needs = _physical_head_range(self.d_tp, K, self.d_rank)
            # When mamba range exists, prefer P ranks within it so that
            # FA targets are a subset of mamba transfer_targets (avoids
            # orphaned FA targets outside the transfer loop).
            search_range = (
                mamba_range if mamba_range is not None else range(self.p_tp)
            )
            targets = self._collect_fa_targets(d_needs, search_range)
            if not targets:
                # Fallback: search globally (should not happen in practice)
                targets = self._collect_fa_targets(d_needs, range(self.p_tp))
            self.fa_read_targets = targets
            self.physical_fa_num_reads = len(targets)

        self._fa_target_set = frozenset(self.fa_read_targets)
        self._fa_target_index = {r: i for i, r in enumerate(self.fa_read_targets)}

        # ---- Mamba targets ----
        if mamba_range is not None and abs_tp > self.physical_fa_num_reads:
            self.mamba_num_reads = abs_tp
            self.transfer_targets = list(mamba_range)
        else:
            self.mamba_num_reads = self.physical_fa_num_reads
            self.transfer_targets = list(self.fa_read_targets)

        # ---- FA entry size ----
        # For blocks_first: block_len_per_layer includes K+V; // 2 gives K (or V).
        # Use min(D, P) because D indexes into P when tp_ratio > 0,
        # and P is the natural unit when tp_ratio < 0.
        effective_block_len = min(self.d_block_len, self.p_block_len)
        if self.is_blocks_first:
            self.fa_entry_size = effective_block_len // 2
        else:
            self.fa_entry_size = effective_block_len

        self._validate()

    def _collect_fa_targets(self, d_needs: range, candidates: range) -> list[int]:
        """Return P ranks in ``candidates`` that hold distinct heads D needs.

        A candidate is kept when its physical head range overlaps ``d_needs``
        and no earlier candidate produced the same overlap, so replicated P
        ranks holding an identical head slice are listed only once.
        """
        seen: set[tuple[int, int]] = set()
        targets: list[int] = []
        for p in candidates:
            p_has = _physical_head_range(self.p_tp, self.K, p)
            ov = _range_overlap(d_needs, p_has)
            if len(ov) > 0:
                key = (ov.start, ov.stop)
                if key not in seen:
                    seen.add(key)
                    targets.append(p)
        return targets

    def _validate(self) -> None:
        """Cross-check internal consistency.

        Only logs; it never raises and never changes the computed plan.
        """
        if self.is_d_replicated and self.is_p_replicated and self.tp_ratio > 0:
            logger.info(
                "Both-replicated hetero-TP: D_TP=%d > P_TP=%d > K=%d. "
                "Using d_rank // tp_ratio routing with relative head offset.",
                self.d_tp,
                self.p_tp,
                self.K,
            )

        # FA targets must be a subset of transfer_targets
        tt_set = set(self.transfer_targets)
        for t in self.fa_read_targets:
            if t not in tt_set:
                logger.error(
                    "FA target P rank %d is NOT in transfer_targets %s. "
                    "This will cause missed FA reads!",
                    t,
                    self.transfer_targets,
                )

        # For tp_ratio < 0 with blocks_first: D_K_half / reads should == P_K_half
        if (
            self.is_blocks_first
            and self.tp_ratio < 0
            and self.physical_fa_num_reads > 0
        ):
            d_k_half = self.d_block_len // 2
            p_k_half = self.p_block_len // 2
            expected_local = d_k_half // self.physical_fa_num_reads
            if expected_local != p_k_half:
                logger.warning(
                    "FA size mismatch: D_K_half=%d / reads=%d = %d, "
                    "but P_K_half=%d. This may indicate a head count or "
                    "Mamba-HMA inflation inconsistency.",
                    d_k_half,
                    self.physical_fa_num_reads,
                    expected_local,
                    p_k_half,
                )

    # ---- Query methods ----

    def should_skip_fa(self, p_rank: int) -> bool:
        """Whether to skip FA groups for this P rank (mamba-only transfer)."""
        return p_rank not in self._fa_target_set

    def fa_head_slot(self, p_rank: int) -> int:
        """Index into D's FA block for this P rank's head data.

        For P ranks in fa_read_targets, returns 0, 1, ..., reads-1.
        For P ranks NOT in fa_read_targets (replicated duplicates),
        returns the slot of the matching FA target with the same head.
        Returns 0 when no FA target shares a head with ``p_rank``.
        """
        if p_rank in self._fa_target_index:
            return self._fa_target_index[p_rank]
        # Duplicate head: find which fa_target has the same physical head
        p_head = _physical_head_range(self.p_tp, self.K, p_rank)
        for target in self.fa_read_targets:
            t_head = _physical_head_range(self.p_tp, self.K, target)
            # An empty overlap range is falsy.
            if _range_overlap(p_head, t_head):
                return self._fa_target_index[target]
        return 0  # fallback

    def fa_rank_offset(self, remote_kv_block_len: int) -> int:
        """Byte offset into P's FA block for this D rank.

        When D is replicated (D_TP > K), multiple D ranks share a head.
        Computes offset *relative to the target P rank's first head*
        so it works regardless of how many heads P has.

        When neither side replicates, falls back to tp_rank % tp_ratio.
        Returns 0 when D does not index into P's block.
        """
        if self.use_mla or self.tp_ratio <= 0:
            return 0
        if self.is_d_replicated:
            d_head = self.d_rank * self.K // self.d_tp
            p_rank = self.fa_read_targets[0]
            p_start = p_rank * self.K // self.p_tp
            return (d_head - p_start) * remote_kv_block_len
        return self.d_rank % self.tp_ratio * remote_kv_block_len

    @property
    def needs_split_handles(self) -> bool:
        """Whether per-P-rank split handles are needed.

        True for non-MLA transfers with ``tp_ratio < 0`` that have more than
        one transfer target; in that case FA and mamba descriptors may need
        different splitting factors in the local handle.
        """
        return self.tp_ratio < 0 and not self.use_mla and len(self.transfer_targets) > 1

    def compute_split_handle_data(
        self,
        src_blocks_data: list[tuple[int, int, int]],
        num_fa_descs: int,
        abs_tp: int,
    ) -> list[list[tuple[int, int, int]]]:
        """Compute per-P-rank (addr, len, tp) triples for Mamba-HMA split handles.

        FA descriptors (indices < num_fa_descs) are sliced by
        ``physical_fa_num_reads``; mamba descriptors are sliced uniformly
        by ``abs_tp``.

        Returns one list of triples per transfer target.
        """
        all_handle_data: list[list[tuple[int, int, int]]] = []
        for p_idx, p_rank in enumerate(self.transfer_targets):
            handle_data: list[tuple[int, int, int]] = []
            skip_fa = self.should_skip_fa(p_rank)
            # Mamba-only targets still get FA entries, pinned to slot 0.
            fa_slot = self.fa_head_slot(p_rank) if not skip_fa else 0
            for j, (addr, local_len, tp) in enumerate(src_blocks_data):
                if j < num_fa_descs:
                    assert self.physical_fa_num_reads >= 1
                    fa_chunk = local_len // self.physical_fa_num_reads
                    handle_data.append((addr + fa_slot * fa_chunk, fa_chunk, tp))
                else:
                    # Mamba slice is selected by position in transfer_targets.
                    mamba_chunk = local_len // abs_tp
                    handle_data.append((addr + p_idx * mamba_chunk, mamba_chunk, tp))
            all_handle_data.append(handle_data)
        return all_handle_data

    def filter_block_ids_for_rank(
        self,
        remote_rank: int,
        local_ids: BlockIds,
        remote_ids: BlockIds,
        is_mamba_group: list[bool],
    ) -> tuple[BlockIds, BlockIds]:
        """Zero out FA groups for P ranks outside fa_read_targets.

        Returns (filtered_local_ids, filtered_remote_ids). When the
        remote rank carries FA data for this D rank, returns the inputs
        unchanged.
        """
        if not self.should_skip_fa(remote_rank):
            return local_ids, remote_ids
        num_groups = len(local_ids)
        # Non-mamba (FA) groups become empty lists; mamba groups pass through.
        filtered_local: list[list[int]] = [
            [] if not is_mamba_group[g] else local_ids[g] for g in range(num_groups)
        ]
        filtered_remote: list[list[int]] = [
            [] if not is_mamba_group[g] else remote_ids[g] for g in range(num_groups)
        ]
        return filtered_local, filtered_remote

    def describe(self) -> str:
        """One-line summary for logging."""
        return (
            f"HeteroTPTransferConfig("
            f"tp_ratio={self.tp_ratio}, K={self.K}, "
            f"d_tp={self.d_tp}, p_tp={self.p_tp}, d_rank={self.d_rank}, "
            f"physical_fa_reads={self.physical_fa_num_reads}, "
            f"mamba_reads={self.mamba_num_reads}, "
            f"fa_targets={self.fa_read_targets}, "
            f"transfer_targets={self.transfer_targets}, "
            f"fa_entry_size={self.fa_entry_size}, "
            f"d_block_len={self.d_block_len}, p_block_len={self.p_block_len})"
        )
def
get_current_attn_backends
(
def
get_current_attn_backends
(
vllm_config
:
VllmConfig
,
layer_names
:
list
[
str
]
|
None
=
None
vllm_config
:
VllmConfig
,
layer_names
:
list
[
str
]
|
None
=
None
)
->
list
[
type
[
AttentionBackend
]]:
)
->
list
[
type
[
AttentionBackend
]]:
...
@@ -559,3 +891,50 @@ def get_current_attn_backend(
...
@@ -559,3 +891,50 @@ def get_current_attn_backend(
)
->
type
[
AttentionBackend
]:
)
->
type
[
AttentionBackend
]:
"""Get the first attention backend for the given layers."""
"""Get the first attention backend for the given layers."""
return
get_current_attn_backends
(
vllm_config
,
layer_names
)[
0
]
return
get_current_attn_backends
(
vllm_config
,
layer_names
)[
0
]
# TODO (ZhanqiuHu): Consolidate TpKVTopology and HeteroTPTransferConfig
# into a single engine-agnostic TransferTopology class.
# 6 of 9 HeteroTPTransferConfig init fields duplicate TpKVTopology data.
#
# @dataclass
# class EngineTransferInfo:
# """Per-remote-engine transfer state, computed at handshake."""
# p_tp: int
# tp_ratio: int
# p_block_len: int
# block_size: int
# # Mamba-specific (None for non-mamba models)
# fa_read_targets: list[int] | None = None
# transfer_targets: list[int] | None = None
# physical_fa_num_reads: int | None = None
# mamba_num_reads: int | None = None
# fa_entry_size: int | None = None
#
# class TransferTopology:
# """Single source of truth for TP topology + transfer sizing."""
# # Shared (set once at init, replaces duplicate fields)
# tp_rank: int # == TpKVTopology.tp_rank == HeteroTP.d_rank
# tp_size: int # == TpKVTopology.tp_size == HeteroTP.d_tp
# total_num_kv_heads: int # == HeteroTP.K
# is_mla: bool # == HeteroTP.use_mla
# is_mamba: bool
# is_blocks_first: bool # == HeteroTP.is_blocks_first
# d_block_len: int
#
# # Per-engine (populated via register_engine() at handshake)
# _engines: dict[EngineId, EngineTransferInfo]
#
# def register_engine(self, engine_id, p_tp, p_block_len, ...): ...
#
# # General (from TpKVTopology)
# def tp_ratio(self, engine_id) -> int: ...
# def target_remote_ranks(self, engine_id) -> list[int]: ...
# def is_kv_replicated(self, engine_id) -> bool: ...
#
# # Mamba-specific (from HeteroTPTransferConfig, gated by is_mamba)
# def fa_rank_offset(self, engine_id, block_len) -> int: ...
# def physical_fa_num_reads(self, engine_id) -> int: ...
# def transfer_targets(self, engine_id) -> list[int]: ...
# def should_skip_fa(self, engine_id, p_rank) -> bool: ...
# def filter_block_ids_for_rank(self, engine_id, ...) -> ...: ...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
bfdc0a3a
...
@@ -25,6 +25,7 @@ from vllm.config import VllmConfig
...
@@ -25,6 +25,7 @@ from vllm.config import VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
BlockIds
,
BlockIds
,
EngineId
,
EngineId
,
HeteroTPTransferConfig
,
TpKVTopology
,
TpKVTopology
,
get_current_attn_backend
,
get_current_attn_backend
,
get_current_attn_backends
,
get_current_attn_backends
,
...
@@ -47,12 +48,18 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
...
@@ -47,12 +48,18 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
PromMetric
,
PromMetric
,
PromMetricT
,
PromMetricT
,
)
)
from
vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils
import
(
MambaConvSplitInfo
,
compute_mamba_phys_ratio
,
derive_mamba_conv_split
,
)
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
)
)
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.mamba.mamba_utils
import
is_conv_state_dim_first
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.network_utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.utils.network_utils
import
make_zmq_path
,
make_zmq_socket
...
@@ -1038,7 +1045,7 @@ class NixlConnectorWorker:
...
@@ -1038,7 +1045,7 @@ class NixlConnectorWorker:
}
}
self
.
hma_group_size
=
len
(
kv_cache_config
.
kv_cache_tensors
)
self
.
hma_group_size
=
len
(
kv_cache_config
.
kv_cache_tensors
)
# Mamba m
etadata
#
----
Mamba m
odel state (derived from model config) ----
self
.
_is_mamba_group
=
[
self
.
_is_mamba_group
=
[
isinstance
(
group
.
kv_cache_spec
,
MambaSpec
)
isinstance
(
group
.
kv_cache_spec
,
MambaSpec
)
for
group
in
kv_cache_config
.
kv_cache_groups
for
group
in
kv_cache_config
.
kv_cache_groups
...
@@ -1065,6 +1072,17 @@ class NixlConnectorWorker:
...
@@ -1065,6 +1072,17 @@ class NixlConnectorWorker:
ssm_shape
.
numel
()
*
ssm_nbytes
,
ssm_shape
.
numel
()
*
ssm_nbytes
,
)
)
self
.
_mamba_ssm_size
=
mamba_ssm_size
self
.
_mamba_ssm_size
=
mamba_ssm_size
# Conv state sub-projection decomposition (None when no Mamba).
# The 3-read transfer requires DS (dim, state_len) conv layout so
# that x/B/C sub-projections are contiguous in memory.
self
.
_conv_decomp
:
MambaConvSplitInfo
|
None
=
None
if
self
.
_has_mamba
:
assert
is_conv_state_dim_first
(),
(
"3-read Mamba conv transfer requires DS conv state layout. "
"Set VLLM_SSM_CONV_STATE_LAYOUT=DS"
)
local_tp
=
vllm_config
.
parallel_config
.
tensor_parallel_size
self
.
_conv_decomp
=
derive_mamba_conv_split
(
mamba_spec
,
local_tp
)
# Agent.
# Agent.
non_ucx_backends
=
[
b
for
b
in
self
.
nixl_backends
if
b
!=
"UCX"
]
non_ucx_backends
=
[
b
for
b
in
self
.
nixl_backends
if
b
!=
"UCX"
]
...
@@ -1175,6 +1193,16 @@ class NixlConnectorWorker:
...
@@ -1175,6 +1193,16 @@ class NixlConnectorWorker:
self
.
dst_num_blocks
:
dict
[
EngineId
,
int
]
=
{}
self
.
dst_num_blocks
:
dict
[
EngineId
,
int
]
=
{}
self
.
_registered_descs
:
list
[
Any
]
=
[]
self
.
_registered_descs
:
list
[
Any
]
=
[]
# ---- Mamba-HMA per-engine state (only used when self._has_mamba) ----
# Per-engine transfer config (source of truth for FA/mamba sizing).
self
.
_transfer_configs
:
dict
[
str
,
HeteroTPTransferConfig
]
=
{}
# NOTE (ZhanqiuHu): _mamba_phys_ratio MUST be per-engine.
# compute_mamba_phys_ratio = ceil((conv_bytes + ssm_bytes) / block_len)
# where conv/ssm bytes are per-TP-rank (dimension-sharded). With
# heterogeneous TP the per-rank sizes differ, so the ratio differs:
# e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261.
self
.
_mamba_phys_ratio
:
dict
[
EngineId
,
int
]
=
{}
# In progress transfers.
# In progress transfers.
# [req_id -> list[handle]]
# [req_id -> list[handle]]
self
.
_recving_metadata
:
dict
[
ReqId
,
ReqMeta
]
=
{}
self
.
_recving_metadata
:
dict
[
ReqId
,
ReqMeta
]
=
{}
...
@@ -1701,8 +1729,7 @@ class NixlConnectorWorker:
...
@@ -1701,8 +1729,7 @@ class NixlConnectorWorker:
# then duplicate it logically to be able to index SSM/Conv separately.
# then duplicate it logically to be able to index SSM/Conv separately.
self
.
num_regions
*=
2
self
.
num_regions
*=
2
# TODO (NickLucche) Adapt to different descs views (engine_id->tp_rank) to
# Total local FA descriptors (boundary between FA and mamba descs).
# support heterogeneous TP.
self
.
num_descs
=
self
.
num_regions
*
self
.
num_blocks
self
.
num_descs
=
self
.
num_regions
*
self
.
num_blocks
descs
=
self
.
nixl_wrapper
.
get_reg_descs
(
caches_data
,
self
.
nixl_memory_type
)
descs
=
self
.
nixl_wrapper
.
get_reg_descs
(
caches_data
,
self
.
nixl_memory_type
)
...
@@ -1715,6 +1742,9 @@ class NixlConnectorWorker:
...
@@ -1715,6 +1742,9 @@ class NixlConnectorWorker:
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
if
self
.
_has_mamba
:
if
self
.
_has_mamba
:
self
.
_mamba_phys_ratio
[
self
.
engine_id
]
=
(
self
.
_physical_blocks_per_logical_kv_block
)
logger
.
info
(
logger
.
info
(
"Hybrid SSM registration: num_blocks=%s, "
"Hybrid SSM registration: num_blocks=%s, "
"logical_num_blocks=%s, ratio=%s, num_regions=%s, "
"logical_num_blocks=%s, ratio=%s, num_regions=%s, "
...
@@ -1755,6 +1785,149 @@ class NixlConnectorWorker:
...
@@ -1755,6 +1785,149 @@ class NixlConnectorWorker:
agent_metadata_bytes
=
encoder
.
encode
(
agent_metadata
),
agent_metadata_bytes
=
encoder
.
encode
(
agent_metadata
),
)
)
def _build_mamba_local(
    self,
    base_addresses: list[int],
    block_size_ratio: int,
) -> list[tuple[int, int, int]]:
    """Build 4 desc regions (x, B, C, ssm) per layer for local mamba
    blocks, enabling the 3-read transfer with DS conv layout.

    For every base address the result holds, in order, one run of
    per-block entries for each ``(offset, size)`` pair in
    ``self._conv_decomp.local_conv_offsets`` followed by one run for the
    SSM state, which starts ``conv_size`` bytes into each page.

    Args:
        base_addresses: Per-layer base address of the local cache region.
        block_size_ratio: Must be 1; other values are rejected by assert.

    Returns:
        A list of ``(address, length, device_id)`` triples.
    """
    assert block_size_ratio == 1, (
        "Mamba 3-read transfer with block_size_ratio != 1 is not tested. "
        f"Got block_size_ratio={block_size_ratio}."
    )
    assert self._conv_decomp is not None
    conv_offsets = self._conv_decomp.local_conv_offsets
    conv_size, ssm_size = self._mamba_ssm_size
    num_blocks = self._logical_num_blocks * block_size_ratio
    phys_ratio = self._physical_blocks_per_logical_kv_block
    device_id = self.device_id

    result: list[tuple[int, int, int]] = []
    for i, base_addr in enumerate(base_addresses):
        # One logical mamba page spans phys_ratio physical blocks.
        page_stride = self.block_len_per_layer[i] // block_size_ratio * phys_ratio
        for off, sz in conv_offsets:
            result.extend(
                (base_addr + blk * page_stride + off, sz, device_id)
                for blk in range(num_blocks)
            )
        # SSM temporal state follows the conv state.
        result.extend(
            (base_addr + blk * page_stride + conv_size, ssm_size, device_id)
            for blk in range(num_blocks)
        )
    return result
def _build_fa_remote_for_mamba(
    self,
    nixl_agent_meta: NixlAgentMetadata,
    transfer_cfg: HeteroTPTransferConfig,
    block_size_ratio: int,
    kv_topo: TpKVTopology,
) -> list[tuple[int, int, int]]:
    """Build remote FA descriptors for mamba models.

    Uses transfer_cfg for GQA-aware FA divisor and head-based rank offset
    instead of the standard uniform tp_ratio split.

    For each remote layer base address this emits one entry per remote
    block for the first split and, when the KV layout is blocks-first, a
    second run of entries offset by half the remote block length.

    Returns:
        A list of ``(address, length, device_id)`` triples.
    """
    assert block_size_ratio == 1, (
        "Mamba 3-read transfer with block_size_ratio != 1 is not tested. "
        f"Got block_size_ratio={block_size_ratio}."
    )
    # TODO (ZhanqiuHu): unify with register_remote_blocks when Mamba-HMA
    # hetero-TP logic stabilizes.
    tp_ratio = transfer_cfg.tp_ratio
    # When tp_ratio < 0 (non-MLA), each read covers only a fraction of
    # the local FA block.
    split_by_reads = tp_ratio < 0 and not self.use_mla
    num_blocks = nixl_agent_meta.num_blocks
    device_id = nixl_agent_meta.device_id

    result: list[tuple[int, int, int]] = []
    for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
        local_block_len = self.get_backend_aware_kv_block_len(
            layer_idx=i, first_split=True, mamba_view=False
        )
        remote_kv_block_len = local_block_len // block_size_ratio
        if block_size_ratio > 1:
            local_block_len = remote_kv_block_len

        if split_by_reads:
            local_block_len = local_block_len // transfer_cfg.physical_fa_num_reads

        rank_offset = transfer_cfg.fa_rank_offset(remote_kv_block_len)
        page_size = nixl_agent_meta.block_lens[i]

        # Start address of this D rank's slice inside every remote block.
        block_addrs = [
            base_addr + block_id * page_size + rank_offset
            for block_id in range(num_blocks)
        ]
        result.extend((addr, local_block_len, device_id) for addr in block_addrs)

        if kv_topo.is_kv_layout_blocks_first:
            second_split = self.get_backend_aware_kv_block_len(
                layer_idx=i, first_split=False, mamba_view=False
            )
            if split_by_reads:
                second_split = second_split // transfer_cfg.physical_fa_num_reads
            # Second half sits half a remote block after the first.
            half_page = page_size // 2
            result.extend(
                (addr + half_page, second_split, device_id) for addr in block_addrs
            )
    return result
def _build_mamba_remote(
    self,
    nixl_agent_meta: NixlAgentMetadata,
    tp_ratio: int,
) -> list[tuple[int, int, int]]:
    """Build the 4 remote descriptor regions (x, B, C, ssm) of every layer.

    These are the remote-side descriptors of the 3-read transfer. With
    hetero-TP, each D rank reads only its own sub-projection slice from the
    P rank.

    Args:
        nixl_agent_meta: remote agent metadata (base addresses, block_lens,
            ssm_sizes, num_blocks, engine_id, device_id).
        tp_ratio: >= 1 when D_TP >= P_TP, negative when P_TP > D_TP.

    Returns:
        ``(addr, length, device_id)`` tuples ordered layer by layer; within
        a layer the x, B, C and ssm regions follow one another, each listing
        all remote blocks.
    """
    decomp = self._conv_decomp
    assert decomp is not None

    effective_ratio = max(tp_ratio, 1)
    # Mamba conv state is always TP-sharded, even when attention KV
    # is replicated (num_kv_heads < tp_size).
    local_offset = self.tp_rank % effective_ratio
    conv_size_remote = nixl_agent_meta.ssm_sizes[0]

    if tp_ratio < 1:
        # NOTE (ZhanqiuHu): tp_ratio < 0 means P_TP > D_TP, so P pages
        # are smaller than D's. self._conv_decomp has D-sized dimensions,
        # but we need P-sized offsets. Scale down by |tp_ratio|.
        shrink = -tp_ratio
        x_len = decomp.x_bytes // shrink
        b_len = decomp.b_bytes // shrink
        conv_offsets = [(0, x_len), (x_len, b_len), (x_len + b_len, b_len)]
        ssm_read_size = nixl_agent_meta.ssm_sizes[1]
    else:
        # D_TP >= P_TP: P page is larger, D reads its slice.
        conv_offsets = decomp.remote_conv_offsets(local_offset, effective_ratio)
        ssm_read_size = self._mamba_ssm_size[1]

    remote_ratio = self._mamba_phys_ratio[nixl_agent_meta.engine_id]
    num_blocks = nixl_agent_meta.num_blocks // remote_ratio
    device_id = nixl_agent_meta.device_id

    # SSM temporal state is also TP-sharded on the heads dimension; it sits
    # right after the remote conv state inside the page.
    ssm_offset = conv_size_remote + local_offset * ssm_read_size
    page_regions = [*conv_offsets, (ssm_offset, ssm_read_size)]

    descs: list[tuple[int, int, int]] = []
    # NOTE (ZhanqiuHu): use per-layer block_lens[i], not [0], in case
    # block lengths vary across layers (e.g. MLA).
    for layer, layer_base in enumerate(nixl_agent_meta.kv_caches_base_addr):
        page_stride = nixl_agent_meta.block_lens[layer] * remote_ratio
        for region_off, region_len in page_regions:
            descs.extend(
                (layer_base + blk * page_stride + region_off, region_len, device_id)
                for blk in range(num_blocks)
            )
    return descs
def
register_local_xfer_handler
(
def
register_local_xfer_handler
(
self
,
self
,
block_size
:
int
,
block_size
:
int
,
...
@@ -1823,13 +1996,22 @@ class NixlConnectorWorker:
...
@@ -1823,13 +1996,22 @@ class NixlConnectorWorker:
self
.
device_id
,
self
.
device_id
,
)
)
# NOTE (ZhanqiuHu): mamba=True path in register_blocks is not used
# right now — we use _build_mamba_local instead for the 3-read
# approach. However, we might still need this as a fallback for homogeneous TP.
register_blocks
(
blocks_data
,
mamba
=
False
)
register_blocks
(
blocks_data
,
mamba
=
False
)
if
self
.
_has_mamba
:
if
self
.
_has_mamba
:
assert
self
.
num_descs
==
len
(
blocks_data
)
assert
self
.
num_descs
==
len
(
blocks_data
)
logger
.
debug
(
# TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the 3-read split is
"Registering additional %s local Mamba blocks"
,
len
(
blocks_data
)
# unnecessary — a single conv desc per block suffices. Consider
# adding a fast path that falls back to the standard 2-region
# registration (register_blocks mamba=True) when no hetero-TP
# remote has been seen. Currently we always register 4 regions
# because local descs are created before knowing the remote TP.
logger
.
debug
(
"Registering local Mamba descriptors (4 regions/layer)"
)
blocks_data
.
extend
(
self
.
_build_mamba_local
(
local_base_addresses
,
block_size_ratio
)
)
)
register_blocks
(
blocks_data
,
mamba
=
True
)
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
# NIXL_INIT_AGENT to be used for preparations of local descs.
# NIXL_INIT_AGENT to be used for preparations of local descs.
...
@@ -1880,6 +2062,9 @@ class NixlConnectorWorker:
...
@@ -1880,6 +2062,9 @@ class NixlConnectorWorker:
Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0
Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0
so that the whole cache is shared by "tp_ratio" D TP workers.
so that the whole cache is shared by "tp_ratio" D TP workers.
For Mamba hetero-TP, both tp_ratio > 0 (D_TP > P_TP) and
tp_ratio < 0 (P_TP > D_TP) are supported by the 3-read transfer.
"""
# noqa: E501
"""
# noqa: E501
engine_id
=
nixl_agent_meta
.
engine_id
engine_id
=
nixl_agent_meta
.
engine_id
# TODO re-evaluate refreshing for scaling/recovery
# TODO re-evaluate refreshing for scaling/recovery
...
@@ -1915,6 +2100,10 @@ class NixlConnectorWorker:
...
@@ -1915,6 +2100,10 @@ class NixlConnectorWorker:
if
engine_id
not
in
self
.
dst_num_blocks
:
if
engine_id
not
in
self
.
dst_num_blocks
:
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
if
self
.
_has_mamba
:
self
.
_mamba_phys_ratio
[
engine_id
]
=
compute_mamba_phys_ratio
(
nixl_agent_meta
.
ssm_sizes
,
nixl_agent_meta
.
block_lens
[
0
]
)
# Keep track of remote agent kv caches base addresses.
# Keep track of remote agent kv caches base addresses.
self
.
kv_caches_base_addr
[
engine_id
][
remote_tp_rank
]
=
(
self
.
kv_caches_base_addr
[
engine_id
][
remote_tp_rank
]
=
(
...
@@ -1931,6 +2120,21 @@ class NixlConnectorWorker:
...
@@ -1931,6 +2120,21 @@ class NixlConnectorWorker:
not
self
.
kv_topo
.
replicates_kv_cache
(
engine_id
)
and
tp_ratio
>
0
not
self
.
kv_topo
.
replicates_kv_cache
(
engine_id
)
and
tp_ratio
>
0
)
)
# Create transfer config (single source of truth for descriptor sizes).
if
self
.
_has_mamba
and
engine_id
not
in
self
.
_transfer_configs
:
self
.
_transfer_configs
[
engine_id
]
=
HeteroTPTransferConfig
(
tp_ratio
=
tp_ratio
,
K
=
kv_topo
.
total_num_kv_heads
,
d_tp
=
self
.
world_size
,
p_tp
=
remote_tp_size
,
d_rank
=
self
.
tp_rank
,
use_mla
=
self
.
use_mla
,
d_block_len
=
self
.
block_len_per_layer
[
0
],
p_block_len
=
nixl_agent_meta
.
block_lens
[
0
],
is_blocks_first
=
kv_topo
.
is_kv_layout_blocks_first
,
)
logger
.
info
(
"Created %s"
,
self
.
_transfer_configs
[
engine_id
].
describe
())
logger
.
debug
(
logger
.
debug
(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s"
,
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s"
,
engine_id
,
engine_id
,
...
@@ -1947,21 +2151,48 @@ class NixlConnectorWorker:
...
@@ -1947,21 +2151,48 @@ class NixlConnectorWorker:
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
# we only do this once per remote tp_size (replica-friendly).
abs_tp
=
-
tp_ratio
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
]
=
[]
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
]
=
[]
for
i
in
range
(
-
tp_ratio
):
blocks_data
=
[]
if
self
.
_has_mamba
:
for
memory_region
in
self
.
src_blocks_data
:
transfer_cfg
=
self
.
_transfer_configs
.
get
(
engine_id
)
addr
,
local_block_len
,
own_tp_rank
=
memory_region
assert
transfer_cfg
is
not
None
# Computing block len layer by layer allows for different
if
transfer_cfg
.
needs_split_handles
:
# block sizes to be used.
# Mamba-HMA: FA and Mamba use different split factors.
remote_block_len
=
local_block_len
//
(
-
tp_ratio
)
for
handle_data
in
transfer_cfg
.
compute_split_handle_data
(
addr
=
addr
+
i
*
remote_block_len
self
.
src_blocks_data
,
self
.
num_descs
,
abs_tp
blocks_data
.
append
((
addr
,
remote_block_len
,
own_tp_rank
))
):
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
handle_data
,
self
.
nixl_memory_type
)
)
handle
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
handle
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
].
append
(
handle
)
"NIXL_INIT_AGENT"
,
descs
)
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
].
append
(
handle
)
logger
.
info
(
"Mamba-HMA split handles: targets=%s, fa_reads=%s, "
"fa_entry=%s, mamba_reads=%s, num_descs=%s"
,
transfer_cfg
.
transfer_targets
,
transfer_cfg
.
physical_fa_num_reads
,
transfer_cfg
.
fa_entry_size
,
transfer_cfg
.
mamba_num_reads
,
self
.
num_descs
,
)
else
:
# Original path: uniform divide by abs_tp (non-Mamba-HMA).
for
i
in
range
(
abs_tp
):
blocks_data
=
[]
for
memory_region
in
self
.
src_blocks_data
:
addr
,
local_block_len
,
own_tp_rank
=
memory_region
remote_block_len
=
local_block_len
//
abs_tp
addr
=
addr
+
i
*
remote_block_len
blocks_data
.
append
((
addr
,
remote_block_len
,
own_tp_rank
))
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
handle
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
].
append
(
handle
)
### Register remote agent memory regions
### Register remote agent memory regions
blocks_data
=
[]
blocks_data
=
[]
...
@@ -2044,13 +2275,33 @@ class NixlConnectorWorker:
...
@@ -2044,13 +2275,33 @@ class NixlConnectorWorker:
self
.
tp_rank
,
self
.
tp_rank
,
)
)
register_remote_blocks
(
blocks_data
,
mamba
=
False
)
if
self
.
_has_mamba
:
if
self
.
_has_mamba
:
# Create extra descs for the Mamba "view" of the same KV cache tensors.
# Mamba-HMA: separate FA registration with GQA-aware sizing,
# plus mamba 3-read registration for the Mamba "view" of the
# same KV cache tensors.
logger
.
debug
(
logger
.
debug
(
"Registering additional %s remote Mamba blocks"
,
len
(
blocks_data
)
"Registering remote Mamba blocks for engine %s rank %s"
,
engine_id
,
remote_tp_rank
,
)
transfer_cfg
=
self
.
_transfer_configs
.
get
(
engine_id
)
assert
transfer_cfg
is
not
None
blocks_data
.
extend
(
self
.
_build_fa_remote_for_mamba
(
nixl_agent_meta
,
transfer_cfg
,
block_size_ratio
,
kv_topo
,
)
)
)
register_remote_blocks
(
blocks_data
,
mamba
=
True
)
blocks_data
.
extend
(
self
.
_build_mamba_remote
(
nixl_agent_meta
,
tp_ratio
,
)
)
else
:
register_remote_blocks
(
blocks_data
,
mamba
=
False
)
# Register with NIXL.
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
...
@@ -2083,17 +2334,17 @@ class NixlConnectorWorker:
...
@@ -2083,17 +2334,17 @@ class NixlConnectorWorker:
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
remote_engine_id
remote_engine_id
)
)
# Num kv_heads > tp_size and P TP > D TP case, not supported
# num_kv_heads > tp_size with P_TP > D_TP not supported for non-mamba.
assert
not
(
tp_ratio
<
0
and
self
.
kv_topo
.
is_kv_replicated
(
remote_engine_id
))
# Mamba models can have replicated FA KV with tp_ratio < 0.
if
not
self
.
_has_mamba
:
assert
not
(
tp_ratio
<
0
and
self
.
kv_topo
.
is_kv_replicated
(
remote_engine_id
)
)
if
self
.
_is_hma_required
:
if
self
.
_is_hma_required
:
assert
block_size_ratio
==
1
,
(
assert
block_size_ratio
==
1
,
(
"HMA does not support different remote block size yet"
"HMA does not support different remote block size yet"
)
)
# Mamba additional constraints
if
self
.
_has_mamba
:
assert
tp_ratio
==
1
,
"Mamba does not support heterogeneous TP yet"
kv_cache_layout
=
(
kv_cache_layout
=
(
self
.
kv_cache_layout
self
.
kv_cache_layout
if
not
self
.
use_host_buffer
if
not
self
.
use_host_buffer
...
@@ -2138,11 +2389,14 @@ class NixlConnectorWorker:
...
@@ -2138,11 +2389,14 @@ class NixlConnectorWorker:
remote_block_len
=
nixl_agent_meta
.
block_lens
[
0
]
remote_block_len
=
nixl_agent_meta
.
block_lens
[
0
]
if
self
.
use_mla
or
self
.
kv_topo
.
is_kv_replicated
(
remote_engine_id
):
if
self
.
use_mla
or
self
.
kv_topo
.
is_kv_replicated
(
remote_engine_id
):
# With replicated KV cache, only the number of blocks can differ.
# With replicated KV cache, only the number of blocks can differ.
for
i
in
range
(
len
(
self
.
block_len_per_layer
)):
# TODO (ZhanqiuHu): For mamba models, validate FA and mamba
assert
(
# block_lens separately.
self
.
block_len_per_layer
[
i
]
//
block_size_ratio
if
not
self
.
_has_mamba
:
==
nixl_agent_meta
.
block_lens
[
i
]
for
i
in
range
(
len
(
self
.
block_len_per_layer
)):
),
"KV cache sizes must match between P and D when replicated"
assert
(
self
.
block_len_per_layer
[
i
]
//
block_size_ratio
==
nixl_agent_meta
.
block_lens
[
i
]
),
"KV cache sizes must match between P and D when replicated"
else
:
else
:
# When MLA is not used, this is a list of the same block length
# When MLA is not used, this is a list of the same block length
for
block_len
in
nixl_agent_meta
.
block_lens
:
for
block_len
in
nixl_agent_meta
.
block_lens
:
...
@@ -2150,25 +2404,31 @@ class NixlConnectorWorker:
...
@@ -2150,25 +2404,31 @@ class NixlConnectorWorker:
"All remote layers must have the same block size"
"All remote layers must have the same block size"
)
)
if
tp_ratio
>
0
:
# HMA hybrid models (mamba+attention) pad block_len to
# Remote tp is smaller: remote block_len size is bigger
# max(attn_page, mamba_page), so the linear tp_ratio scaling
assert
(
# assumption only holds for pure-attention models.
remote_block_len
if
not
self
.
_has_mamba
:
==
(
self
.
block_len_per_layer
[
0
]
*
tp_ratio
)
//
block_size_ratio
if
tp_ratio
>
0
:
),
(
assert
(
"Remote P worker KV layer cache must be of shape [2, N, "
remote_block_len
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
==
(
self
.
block_len_per_layer
[
0
]
*
tp_ratio
)
//
block_size_ratio
)
# noqa: E501
),
(
else
:
"Remote P worker KV layer cache must be of shape [2, N,"
assert
block_size_ratio
==
1
,
(
" local_kv_heads*tp_ratio, page_size, head_dim] and "
"Different local/remote block sizes are not supported when"
"same dtype."
" P TP > D TP."
)
)
else
:
# Remote tp is bigger: remote block_len size is smaller
assert
block_size_ratio
==
1
,
(
assert
remote_block_len
==
self
.
block_len_per_layer
[
0
]
//
(
-
tp_ratio
),
(
"Different local/remote block sizes are not supported"
"Remote P worker KV layer cache must be of shape [2, N, "
" when P TP > D TP."
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
)
)
# noqa: E501
assert
remote_block_len
==
self
.
block_len_per_layer
[
0
]
//
(
-
tp_ratio
),
(
"Remote P worker KV layer cache must be of shape [2, N,"
" local_kv_heads/tp_ratio, page_size, head_dim] and "
"same dtype."
)
# TP workers that handshake with same remote have same #blocks.
# TP workers that handshake with same remote have same #blocks.
assert
self
.
dst_num_blocks
[
remote_engine_id
]
==
nixl_agent_meta
.
num_blocks
assert
self
.
dst_num_blocks
[
remote_engine_id
]
==
nixl_agent_meta
.
num_blocks
...
@@ -2471,9 +2731,8 @@ class NixlConnectorWorker:
...
@@ -2471,9 +2731,8 @@ class NixlConnectorWorker:
meta
.
local_block_ids
meta
.
local_block_ids
)
)
assert
meta
.
remote
is
not
None
assert
meta
.
remote
is
not
None
meta
.
remote
.
block_ids
=
self
.
_logical_to_kernel_block_ids
(
# Remote block IDs are kept logical here; expanded in
meta
.
remote
.
block_ids
# _read_blocks_for_req using the remote engine's phys ratio.
)
remote_engine_id
=
meta
.
remote
.
engine_id
remote_engine_id
=
meta
.
remote
.
engine_id
logger
.
debug
(
logger
.
debug
(
"start_load_kv for request %s from remote engine %s. "
"start_load_kv for request %s from remote engine %s. "
...
@@ -2525,6 +2784,13 @@ class NixlConnectorWorker:
...
@@ -2525,6 +2784,13 @@ class NixlConnectorWorker:
meta
.
remote
.
engine_id
meta
.
remote
.
engine_id
)
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
meta
.
remote
.
engine_id
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
meta
.
remote
.
engine_id
)
if
self
.
_has_mamba
:
# Expand remote logical → kernel block IDs.
meta
.
remote
.
block_ids
=
self
.
_logical_to_remote_kernel_block_ids
(
meta
.
remote
.
block_ids
,
self
.
_mamba_phys_ratio
[
meta
.
remote
.
engine_id
],
)
# D may have to perform multiple reads from different remote ranks.
# D may have to perform multiple reads from different remote ranks.
for
i
,
remote_rank
in
enumerate
(
remote_ranks
):
for
i
,
remote_rank
in
enumerate
(
remote_ranks
):
if
self
.
use_mla
and
tp_ratio
<
0
and
i
>
0
:
if
self
.
use_mla
and
tp_ratio
<
0
and
i
>
0
:
...
@@ -2558,12 +2824,26 @@ class NixlConnectorWorker:
...
@@ -2558,12 +2824,26 @@ class NixlConnectorWorker:
remote_xfer_side_handle
=
self
.
dst_xfer_side_handles
[
meta
.
remote
.
engine_id
][
remote_xfer_side_handle
=
self
.
dst_xfer_side_handles
[
meta
.
remote
.
engine_id
][
remote_rank
remote_rank
]
]
local_ids
:
BlockIds
=
meta
.
local_physical_block_ids
remote_ids
:
BlockIds
=
meta
.
remote
.
block_ids
if
self
.
_has_mamba
:
# Mamba-HMA: zero out FA groups for P ranks outside fa_read_targets.
transfer_cfg
=
self
.
_transfer_configs
.
get
(
meta
.
remote
.
engine_id
)
assert
transfer_cfg
is
not
None
local_ids
,
remote_ids
=
transfer_cfg
.
filter_block_ids_for_rank
(
remote_rank
,
local_ids
,
remote_ids
,
self
.
_is_mamba_group
,
)
self
.
_read_blocks
(
self
.
_read_blocks
(
request_id
=
req_id
,
request_id
=
req_id
,
dst_engine_id
=
meta
.
remote
.
engine_id
,
dst_engine_id
=
meta
.
remote
.
engine_id
,
remote_request_id
=
meta
.
remote
.
request_id
,
remote_request_id
=
meta
.
remote
.
request_id
,
local_block_ids
=
meta
.
local_physical_block
_ids
,
local_block_ids
=
local
_ids
,
remote_block_ids
=
meta
.
remote
.
block
_ids
,
remote_block_ids
=
remote_ids
,
remote_rank
=
remote_rank
,
remote_rank
=
remote_rank
,
local_xfer_side_handle
=
local_xfer_side_handle
,
local_xfer_side_handle
=
local_xfer_side_handle
,
remote_xfer_side_handle
=
remote_xfer_side_handle
,
remote_xfer_side_handle
=
remote_xfer_side_handle
,
...
@@ -2663,9 +2943,12 @@ class NixlConnectorWorker:
...
@@ -2663,9 +2943,12 @@ class NixlConnectorWorker:
for
i
,
remote_group
in
enumerate
(
remote_block_ids
):
for
i
,
remote_group
in
enumerate
(
remote_block_ids
):
num_remote_blocks
=
len
(
remote_group
)
num_remote_blocks
=
len
(
remote_group
)
num_local_blocks
=
len
(
local_block_ids
[
i
])
num_local_blocks
=
len
(
local_block_ids
[
i
])
assert
num_local_blocks
<=
num_remote_blocks
if
not
self
.
_is_mamba_group
[
i
]:
assert
num_local_blocks
<=
num_remote_blocks
# Partial prefix cache hit: just read uncomputed blocks.
# Partial prefix cache hit: just read uncomputed blocks.
if
num_local_blocks
<
num_remote_blocks
:
# Skip mamba groups — their blocks represent full state (conv+ssm),
# not per-token data, so trimming would corrupt the transfer.
if
num_local_blocks
<
num_remote_blocks
and
not
self
.
_is_mamba_group
[
i
]:
remote_block_ids
[
i
]
=
remote_group
[
-
num_local_blocks
:]
remote_block_ids
[
i
]
=
remote_group
[
-
num_local_blocks
:]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
...
@@ -2781,16 +3064,22 @@ class NixlConnectorWorker:
...
@@ -2781,16 +3064,22 @@ class NixlConnectorWorker:
# This is like having two "low-level views" of the same storage.
# This is like having two "low-level views" of the same storage.
# `num_fa_descs` offset must be computed per-engine since P and D can
# `num_fa_descs` offset must be computed per-engine since P and D can
# have different num_blocks (and thus different FA descs counts).
# have different num_blocks (and thus different FA descs counts).
ratio
=
self
.
_physical_blocks_per_logical_kv_block
ratio
=
self
.
_mamba_phys_ratio
[
engine_id
]
# SSM may register fewer num_blocks than FA
logical_blocks
=
num_blocks
//
ratio
logical_blocks
=
num_blocks
//
ratio
num_fa_descs
=
self
.
num_regions
*
num_blocks
num_fa_descs
=
self
.
num_regions
*
num_blocks
# 3-read mamba: 4 regions per unique cache tensor (x, B, C, ssm).
mamba_region_ids
=
np
.
arange
(
len
(
self
.
block_len_per_layer
)
*
4
)[:,
None
]
all_descs
=
[]
all_descs
=
[]
for
i
,
group
in
enumerate
(
block_ids
):
for
i
,
group
in
enumerate
(
block_ids
):
stride
=
logical_blocks
if
self
.
_is_mamba_group
[
i
]
else
num_blocks
group_arr
=
np
.
asarray
(
group
)[
None
,
:]
group_arr
=
np
.
asarray
(
group
)[
None
,
:]
offset
=
num_fa_descs
if
self
.
_is_mamba_group
[
i
]
else
0
if
self
.
_is_mamba_group
[
i
]:
all_descs
.
append
((
region_ids
*
stride
+
group_arr
+
offset
).
flatten
())
all_descs
.
append
(
(
mamba_region_ids
*
logical_blocks
+
group_arr
+
num_fa_descs
).
flatten
()
)
else
:
all_descs
.
append
((
region_ids
*
num_blocks
+
group_arr
).
flatten
())
return
np
.
concatenate
(
all_descs
)
return
np
.
concatenate
(
all_descs
)
def
_logical_to_kernel_block_ids
(
self
,
block_ids
:
BlockIds
)
->
BlockIds
:
def
_logical_to_kernel_block_ids
(
self
,
block_ids
:
BlockIds
)
->
BlockIds
:
...
@@ -2818,6 +3107,36 @@ class NixlConnectorWorker:
...
@@ -2818,6 +3107,36 @@ class NixlConnectorWorker:
for
i
,
group
in
enumerate
(
block_ids
)
for
i
,
group
in
enumerate
(
block_ids
)
]
]
def _logical_to_remote_kernel_block_ids(
    self, block_ids: BlockIds, remote_ratio: int
) -> BlockIds:
    """Translate logical block IDs into the remote engine's kernel block IDs.

    Args:
        block_ids: per-group lists of logical block IDs.
        remote_ratio: remote engine's physical blocks per logical block.

    Returns:
        ``block_ids`` itself when ``remote_ratio`` is 1. Otherwise a new
        list in which every non-Mamba group is expanded so that logical
        block L becomes kernel blocks
        [L*remote_ratio .. L*remote_ratio + local_ratio - 1], while Mamba
        groups are passed through unchanged.
    """
    local_ratio = self._physical_blocks_per_logical_kv_block
    if remote_ratio == 1:
        return block_ids

    # Offsets of the kernel blocks inside one logical block, as a row vector.
    within_logical = np.arange(local_ratio)[None, :]
    group_specs = self.kv_cache_config.kv_cache_groups

    remapped: list[list[int]] = []
    for idx, group in enumerate(block_ids):
        if isinstance(group_specs[idx].kv_cache_spec, MambaSpec):
            # Mamba blocks are 1:1 logical-to-physical (no expansion).
            remapped.append(group)
            continue
        first_kernel_ids = np.array(group).ravel()[:, None] * remote_ratio
        remapped.append((first_kernel_ids + within_logical).ravel().tolist())
    return remapped
def
get_backend_aware_kv_block_len
(
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
,
first_split
:
bool
=
True
,
mamba_view
:
bool
=
False
self
,
layer_idx
:
int
,
first_split
:
bool
=
True
,
mamba_view
:
bool
=
False
)
->
int
:
)
->
int
:
...
...
vllm/distributed/kv_transfer/kv_connector/v1/ssm_conv_transfer_utils.py
0 → 100644
View file @
bfdc0a3a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Mamba conv-state sub-projection decomposition for the 3-read transfer.
With DS conv state layout (dim, state_len), x/B/C sub-projections are
contiguous in memory. Each D rank reads its x, B, C slices via 3
separate RDMA transfers — no P-side permutation needed.
"""
import
math
from
dataclasses
import
dataclass
import
torch
from
vllm.model_executor.layers.mamba.mamba_utils
import
is_conv_state_dim_first
from
vllm.v1.kv_cache_interface
import
MambaSpec
@dataclass(frozen=True)
class MambaConvSplitInfo:
    """Byte layout of the x, B and C sub-projections in one rank's conv state.

    Both the P and the D side use this for NIXL descriptor registration.
    Every field is LOCAL to this engine's TP (already divided by TP size).

    DS memory layout within one page (contiguous in memory):
        |--- x (x_local * conv_rows) ---|- B (b_local * conv_rows) -|- C -|
    """

    # conv_kernel - 1 (typically 3)
    conv_rows: int
    # intermediate_size / TP (columns for x)
    x_local: int
    # groups_ss / TP (columns for B; C is same size)
    b_local: int
    # bytes per element (e.g. 2 for float16)
    conv_dtype_size: int

    def _section_bytes(self, columns: int) -> int:
        """Return the byte size of a section that is ``columns`` wide."""
        return columns * self.conv_rows * self.conv_dtype_size

    @property
    def conv_dim_local(self) -> int:
        """Number of conv columns held by one rank (x plus B plus C)."""
        return self.x_local + 2 * self.b_local

    @property
    def x_bytes(self) -> int:
        """Bytes occupied by one rank's x sub-projection."""
        return self._section_bytes(self.x_local)

    @property
    def b_bytes(self) -> int:
        """Bytes occupied by one rank's B (equivalently C) sub-projection."""
        return self._section_bytes(self.b_local)

    @property
    def local_conv_offsets(self) -> list[tuple[int, int]]:
        """(byte_offset, byte_size) of x, B, C inside this engine's own page.

        Used by both P and D for local descriptor registration.
        """
        # The local page is the remote layout with a single slice at index 0.
        return self.remote_conv_offsets(0, 1)

    def remote_conv_offsets(
        self, local_rank_offset: int, tp_ratio: int
    ) -> list[tuple[int, int]]:
        """(byte_offset, byte_size) of this D rank's x, B, C slices in a P page.

        Used by the D side only, during remote descriptor registration.

        Args:
            local_rank_offset: index of the slice this D rank reads.
                tp_ratio > 0: tp_rank % tp_ratio (selects slice of P's page).
                tp_ratio < 0: always 0 (read P's full page).
            tp_ratio: effective ratio (>= 1 when D_TP > P_TP, 1 when
                P_TP > D_TP since each P rank is read in full).
        """
        x_slice = self.x_bytes
        b_slice = self.b_bytes
        # Start of the remote B section (= full remote x section) and of the
        # remote C section (= x section plus full remote B section).
        b_section_start = x_slice * tp_ratio
        c_section_start = b_section_start + b_slice * tp_ratio
        slice_shift = local_rank_offset * b_slice
        return [
            (local_rank_offset * x_slice, x_slice),
            (b_section_start + slice_shift, b_slice),
            (c_section_start + slice_shift, b_slice),
        ]
def derive_mamba_conv_split(
    mamba_spec: MambaSpec,
    local_tp: int,
) -> MambaConvSplitInfo:
    """Derive per-rank x/B/C byte sizes from a MambaSpec.

    Called once at init on both P and D. Decomposes the conv dimension
    (= intermediate_size + 2 * groups_ss) into its x, B, C parts.

    Args:
        mamba_spec: MambaSpec whose shapes are:
            shapes[0] = conv state: (conv_dim_local, conv_rows) in DS layout.
            shapes[1] = SSM temporal: (local_num_heads, head_dim).
        local_tp: this engine's tensor-parallel size.

    Returns:
        MambaConvSplitInfo with per-rank x_local, b_local, conv_rows, and
        conv_dtype_size.

    Raises:
        NotImplementedError: if ``mamba_spec.mamba_type`` is not "mamba2".
        AssertionError: if the conv state is not 2D, the DS layout is not
            active, or the conv dimension cannot be split exactly into
            x + B + C with equally sized B and C on this rank.
    """
    if mamba_spec.mamba_type != "mamba2":
        raise NotImplementedError(
            f"3-read conv transfer only supports Mamba2 models, "
            f"got mamba_type={mamba_spec.mamba_type!r}. "
            f"Mamba1 SSM temporal shape is (intermediate_size // tp, state_size) "
            f"which cannot be used to reconstruct intermediate_size."
        )

    conv_shape = mamba_spec.shapes[0]
    assert len(conv_shape) == 2, f"Expected 2D conv state shape, got {conv_shape}"

    # NOTE (ZhanqiuHu): 3-read requires DS layout, which is already asserted
    # in nixl_connector __init__. Use it directly instead of heuristic detection.
    assert is_conv_state_dim_first(), "3-read requires DS conv state layout"
    local_conv_dim = conv_shape[0]  # DS: (conv_dim_local, conv_rows)
    conv_rows = conv_shape[1]

    # NOTE (ZhanqiuHu): intermediate_size (= global x dim) is not stored
    # in MambaSpec, so we reconstruct it from the SSM temporal state shape:
    # shapes[1] = (local_num_heads, head_dim), already divided by TP.
    head_dim = mamba_spec.shapes[1][1]
    local_num_heads = mamba_spec.shapes[1][0]
    intermediate_size = local_num_heads * local_tp * head_dim

    # NOTE (ZhanqiuHu): global conv dim = intermediate_size + 2 * groups_ss,
    # where groups_ss is the B (= C) dimension. B and C are always the same
    # size, so we recover groups_ss from the remainder after subtracting x.
    remainder = local_conv_dim * local_tp - intermediate_size
    assert remainder > 0 and remainder % 2 == 0, (
        f"Conv dim ({local_conv_dim}*tp={local_tp}) doesn't decompose into "
        f"intermediate_size={intermediate_size} + 2*groups_ss. "
        f"remainder={remainder}"
    )
    groups_ss = remainder // 2

    conv_dtype_size = torch.tensor(
        [],
        dtype=mamba_spec.dtypes[0],  # type: ignore[misc]
    ).element_size()

    # Divide by TP to get per-rank column counts.
    x_local = intermediate_size // local_tp
    b_local = groups_ss // local_tp

    # The global check above can pass while the per-rank split is inexact
    # (even TP with an odd per-rank B+C remainder): the floor division would
    # then silently drop a column and the x/B/C descriptors would no longer
    # cover the whole local conv state. Fail loudly instead.
    assert x_local + 2 * b_local == local_conv_dim, (
        f"Per-rank conv dim {local_conv_dim} != x_local={x_local} + "
        f"2*b_local={b_local}; B/C columns do not split evenly for "
        f"tp={local_tp}."
    )

    return MambaConvSplitInfo(
        conv_rows=conv_rows,
        x_local=x_local,
        b_local=b_local,
        conv_dtype_size=conv_dtype_size,
    )
def compute_mamba_phys_ratio(ssm_sizes: tuple[int, ...], block_len: int) -> int:
    """Derive _physical_blocks_per_logical_kv_block from remote metadata.

    The remote engine's ratio is not sent directly in the handshake, so we
    reconstruct it: total mamba state per logical block / block_len,
    rounded up.

    Args:
        ssm_sizes: (conv_state_bytes, ssm_state_bytes) from NixlAgentMetadata.
        block_len: the engine's block_len in bytes (from block_lens[0]).

    Returns:
        The ceiling of ``(ssm_sizes[0] + ssm_sizes[1]) / block_len``.

    Raises:
        ZeroDivisionError: if ``block_len`` is 0.
    """
    total_state_bytes = ssm_sizes[0] + ssm_sizes[1]
    # Exact integer ceiling division: going through float division (as
    # math.ceil(a / b) does) loses precision for very large byte counts.
    return -(-total_state_bytes // block_len)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment