Unverified Commit 41488f2a authored by zhanqiuhu's avatar zhanqiuhu Committed by GitHub
Browse files

[Bugfix][NIXL] Fix `_logical_to_kernel_block_ids` conversion for non-mamba models (#39724)


Signed-off-by: default avatarZhanqiu Hu <zhu@redhat.com>
parent 102d51c9
......@@ -89,6 +89,99 @@ def test_logical_to_kernel_block_ids_with_hma():
)
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"has_mamba,swa_enabled,mamba_enabled,remote_ratio,"
"remote_block_ids,expected_remote_block_ids",
[
# Non-mamba (FA+SWA): both groups expanded via _logical_to_kernel_block_ids.
# Regression for https://github.com/vllm-project/vllm/pull/39724
(
False,
True,
False,
1,
([0, 1, 2], [3, 4]),
[[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]],
),
# Mamba (FA+Mamba): FA expanded via _logical_to_remote_kernel_block_ids,
# Mamba passed through unchanged.
# remote_ratio=261 (Nemotron 30B TP=1) != local_ratio=2 so that using
# the wrong conversion method produces different FA results.
(
True,
False,
True,
261,
([0, 1, 2], [10, 11]),
[[0, 1, 261, 262, 522, 523], [10, 11]],
),
],
ids=["non_mamba_fa_swa", "mamba_fa_ssm"],
)
def test_read_blocks_for_req_expands_remote_ids(
has_mamba,
swa_enabled,
mamba_enabled,
remote_ratio,
remote_block_ids,
expected_remote_block_ids,
):
"""_read_blocks_for_req must expand remote logical block IDs to kernel
block IDs when kernel block size != logical block size.
Non-mamba path uses _logical_to_kernel_block_ids (all groups expanded).
Mamba path uses _logical_to_remote_kernel_block_ids (FA expanded, Mamba
passed through).
"""
from unittest.mock import MagicMock
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
NixlConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)
worker = object.__new__(NixlConnectorWorker)
worker._has_mamba = has_mamba
worker._physical_blocks_per_logical_kv_block = 2
worker.kv_cache_config = make_kv_cache_config(
block_size=16, swa_enabled=swa_enabled, mamba_enabled=mamba_enabled
)
remote_engine_id = "remote-engine"
if has_mamba:
worker._mamba_phys_ratio = {remote_engine_id: remote_ratio}
# Mock kv_topo: empty remote ranks skips the transfer machinery entirely,
# isolating the block-ID expansion logic.
worker.kv_topo = MagicMock()
worker.kv_topo.get_target_remote_ranks_from_engine_id.return_value = []
worker.kv_topo.tp_ratio_from_engine_id.return_value = 1
metadata = NixlConnectorMetadata()
metadata.add_new_req_to_recv(
request_id="test-req",
local_block_ids=([0, 1], [2, 3]),
kv_transfer_params={
"remote_block_ids": remote_block_ids,
"remote_engine_id": remote_engine_id,
"remote_request_id": "prefill-test-req",
"remote_host": "localhost",
"remote_port": 1234,
"tp_size": 1,
},
)
meta = metadata.reqs_to_recv["test-req"]
worker._read_blocks_for_req("test-req", meta)
assert meta.remote.block_ids == expected_remote_block_ids, (
f"Expected {expected_remote_block_ids}, got {meta.remote.block_ids}"
)
@pytest.mark.parametrize("model_name, sw_size", [("google/gemma-3-1b-it", 512)])
def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size):
"""Test that a prefill instance returns fewer "remote blocks" for the SWA groups
......
......@@ -1914,6 +1914,10 @@ class NixlConnectorWorker:
meta.remote.block_ids,
self._mamba_phys_ratio[meta.remote.engine_id],
)
else:
meta.remote.block_ids = self._logical_to_kernel_block_ids(
meta.remote.block_ids
)
# D may have to perform multiple reads from different remote ranks.
for i, remote_rank in enumerate(remote_ranks):
if self.use_mla and tp_ratio < 0 and i > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment