Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
41488f2a
Unverified
Commit
41488f2a
authored
Apr 15, 2026
by
zhanqiuhu
Committed by
GitHub
Apr 15, 2026
Browse files
[Bugfix][NIXL] Fix `_logical_to_kernel_block_ids` conversion for non-mamba models (#39724)
Signed-off-by:
Zhanqiu Hu
<
zhu@redhat.com
>
parent
102d51c9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
97 additions
and
0 deletions
+97
-0
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
+93
-0
vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
+4
-0
No files found.
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
View file @
41488f2a
...
...
@@ -89,6 +89,99 @@ def test_logical_to_kernel_block_ids_with_hma():
)
@pytest.mark.cpu_test
@pytest.mark.parametrize(
    "has_mamba,swa_enabled,mamba_enabled,remote_ratio,"
    "remote_block_ids,expected_remote_block_ids",
    [
        # Non-mamba (FA+SWA): both groups expanded via
        # _logical_to_kernel_block_ids.
        # Regression for https://github.com/vllm-project/vllm/pull/39724
        pytest.param(
            False,
            True,
            False,
            1,
            ([0, 1, 2], [3, 4]),
            [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]],
            id="non_mamba_fa_swa",
        ),
        # Mamba (FA+Mamba): FA expanded via
        # _logical_to_remote_kernel_block_ids, Mamba passed through unchanged.
        # remote_ratio=261 (Nemotron 30B TP=1) != local_ratio=2 so that using
        # the wrong conversion method produces different FA results.
        pytest.param(
            True,
            False,
            True,
            261,
            ([0, 1, 2], [10, 11]),
            [[0, 1, 261, 262, 522, 523], [10, 11]],
            id="mamba_fa_ssm",
        ),
    ],
)
def test_read_blocks_for_req_expands_remote_ids(
    has_mamba,
    swa_enabled,
    mamba_enabled,
    remote_ratio,
    remote_block_ids,
    expected_remote_block_ids,
):
    """Check that _read_blocks_for_req rewrites the remote block IDs.

    The worker is configured with two physical (kernel) blocks per logical
    KV block, so the logical IDs received in the transfer params have to be
    expanded before they are used.

    * Non-mamba case: every group goes through
      _logical_to_kernel_block_ids.
    * Mamba case: the FA group goes through
      _logical_to_remote_kernel_block_ids with the per-engine ratio, while
      the Mamba group is left as-is.
    """
    from unittest.mock import MagicMock

    from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
        NixlConnectorMetadata,
    )
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
        NixlConnectorWorker,
    )

    engine_id = "remote-engine"
    req_id = "test-req"

    # Bypass __init__ and set only the attributes the test assigns below.
    nixl_worker = object.__new__(NixlConnectorWorker)
    nixl_worker._has_mamba = has_mamba
    nixl_worker._physical_blocks_per_logical_kv_block = 2
    nixl_worker.kv_cache_config = make_kv_cache_config(
        block_size=16,
        swa_enabled=swa_enabled,
        mamba_enabled=mamba_enabled,
    )
    if has_mamba:
        nixl_worker._mamba_phys_ratio = {engine_id: remote_ratio}

    # Mock kv_topo: empty remote ranks skips the transfer machinery entirely,
    # isolating the block-ID expansion logic.
    topo = MagicMock()
    topo.get_target_remote_ranks_from_engine_id.return_value = []
    topo.tp_ratio_from_engine_id.return_value = 1
    nixl_worker.kv_topo = topo

    transfer_params = {
        "remote_block_ids": remote_block_ids,
        "remote_engine_id": engine_id,
        "remote_request_id": "prefill-test-req",
        "remote_host": "localhost",
        "remote_port": 1234,
        "tp_size": 1,
    }
    connector_metadata = NixlConnectorMetadata()
    connector_metadata.add_new_req_to_recv(
        request_id=req_id,
        local_block_ids=([0, 1], [2, 3]),
        kv_transfer_params=transfer_params,
    )
    req_meta = connector_metadata.reqs_to_recv[req_id]

    nixl_worker._read_blocks_for_req(req_id, req_meta)

    assert req_meta.remote.block_ids == expected_remote_block_ids, (
        f"Expected {expected_remote_block_ids}, got {req_meta.remote.block_ids}"
    )
@
pytest
.
mark
.
parametrize
(
"model_name, sw_size"
,
[(
"google/gemma-3-1b-it"
,
512
)])
def
test_fewer_blocks_with_hma
(
monkeypatch
,
model_name
,
sw_size
):
"""Test that a prefill instance returns fewer "remote blocks" for the SWA groups
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
View file @
41488f2a
...
...
@@ -1914,6 +1914,10 @@ class NixlConnectorWorker:
meta
.
remote
.
block_ids
,
self
.
_mamba_phys_ratio
[
meta
.
remote
.
engine_id
],
)
else
:
meta
.
remote
.
block_ids
=
self
.
_logical_to_kernel_block_ids
(
meta
.
remote
.
block_ids
)
# D may have to perform multiple reads from different remote ranks.
for
i
,
remote_rank
in
enumerate
(
remote_ranks
):
if
self
.
use_mla
and
tp_ratio
<
0
and
i
>
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment