Unverified Commit 55e1a8e1 authored by Zhewen Li's avatar Zhewen Li Committed by GitHub
Browse files

[Mooncake] Fix mixed MLA+Eagle block-size validation (#39596)


Signed-off-by: default avatarZhewen Li <zhewenli@inferact.ai>
Co-authored-by: default avatarZhewen Li <zhewenli@inferact.ai>
Co-authored-by: default avatarOpenAI Codex <codex@openai.com>
Co-authored-by: default avatarmergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent 21e5a9f4
......@@ -609,6 +609,51 @@ def test_register_kv_caches():
assert bl == tensor1[0].nbytes // tensor1.shape[1]
def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes():
"""Mixed MLA+Eagle caches should register by byte length, not shape."""
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector", kv_role="kv_consumer"
)
with (
set_current_vllm_config(vllm_config),
patch_worker_dependencies(),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.threading.Event"
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.threading.Thread"
) as mock_thread,
):
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
worker = connector.connector_worker
mock_thread.return_value.is_alive.return_value = False
worker.use_mla = True
worker.kv_topo.is_mla = True
# MLA cache tensor: shape[-2] is the block size.
mla_cache = torch.zeros((2, 16, 96), dtype=torch.float16)
# Eagle3/GQA-like cache tensor: shape[-2] is num_kv_heads, not block size.
eagle_cache = torch.zeros((2, 16, 8, 64), dtype=torch.float16)
kv_caches = {"mla_layer": mla_cache, "eagle_layer": eagle_cache}
with patch.object(
worker.engine, "batch_register_memory", return_value=0
) as mock_batch_register:
connector.register_kv_caches(kv_caches)
mock_batch_register.assert_called_once()
registered_ptrs, registered_lens = mock_batch_register.call_args[0]
assert registered_ptrs == [mla_cache.data_ptr(), eagle_cache.data_ptr()]
assert registered_lens == [mla_cache.nbytes, eagle_cache.nbytes]
assert worker.block_len_per_layer == [
mla_cache.nbytes // mla_cache.shape[0],
eagle_cache.nbytes // eagle_cache.shape[0],
]
@pytest.mark.asyncio
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake."
......
......@@ -23,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
TpKVTopology,
get_current_attn_backend,
get_current_attn_backends,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
......@@ -47,6 +48,7 @@ from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.worker.utils import select_common_block_size
logger = init_logger(__name__)
......@@ -751,6 +753,7 @@ class MooncakeConnectorWorker:
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.use_mla = self.model_config.use_mla
self._sync_block_size_with_kernel()
# Get the attention backend from the first layer
# NOTE (NickLucche) models with multiple backends are not supported yet
......@@ -777,6 +780,23 @@ class MooncakeConnectorWorker:
self._xfer_meta_decoder = msgspec.msgpack.Decoder(MooncakeXferMetadata)
self._xfer_resp_decoder = msgspec.msgpack.Decoder(MooncakeXferResponse)
def _sync_block_size_with_kernel(self) -> None:
# When speculative decoding (e.g. Eagle) is enabled, the main model
# and draft model may use different attention backends with different
# physical block sizes. Pick the common (smallest) block size so that
# KV-cache registration and transfer work correctly for both models.
backends = get_current_attn_backends(self.vllm_config)
kernel_block_size = select_common_block_size(self.block_size, backends)
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter.",
self.block_size,
kernel_block_size,
)
assert self.block_size > kernel_block_size
self.block_size = kernel_block_size
def __del__(self):
self.shutdown()
......@@ -1268,9 +1288,6 @@ class MooncakeConnectorWorker:
self.block_len_per_layer.append(
curr_tensor_size_bytes // self.num_blocks
)
kernel_block_size = cache.shape[-2 if self.use_mla else -3]
assert self.block_size == kernel_block_size
kv_data_ptrs.append(base_addr)
kv_data_lens.append(curr_tensor_size_bytes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment