"vllm/vscode:/vscode.git/clone" did not exist on "5b681074119b970c2f99f8baea43f856cafc0251"
Unverified commit 4137c5df, authored by haosdent, committed by GitHub
Browse files

[Bug Fix] Fix MambaManager.cache_blocks() crash on null blocks in align mode (#34418)


Signed-off-by: haosdent <haosdent@gmail.com>
parent 7a8a46dd
...@@ -744,6 +744,12 @@ def _make_hybrid_kv_cache_config( ...@@ -744,6 +744,12 @@ def _make_hybrid_kv_cache_config(
shapes=(1, 1), shapes=(1, 1),
dtypes=(torch.float32,), dtypes=(torch.float32,),
), ),
"mamba_align": lambda: MambaSpec(
block_size=block_size,
shapes=(1, 1),
dtypes=(torch.float32,),
mamba_cache_mode="align",
),
} }
kv_cache_groups = [ kv_cache_groups = [
...@@ -962,6 +968,46 @@ def test_prefill_hybrid_model_combinations_eagle( ...@@ -962,6 +968,46 @@ def test_prefill_hybrid_model_combinations_eagle(
manager.free(req1) manager.free(req1)
def test_prefill_hybrid_model_mamba_align():
    """Check that allocating slots works with null blocks in mamba align mode.

    Regression test for https://github.com/vllm-project/vllm/issues/34361.

    With mamba_cache_mode="align", allocate_new_blocks() pads req_to_blocks
    with null blocks. cache_full_blocks() already skips those entries, and
    MambaManager.cache_blocks() has to skip them as well while it records
    cached_blocks_this_step; otherwise its block-hash assertion fires.
    """
    block_size = 16
    num_blocks = 30

    # Hybrid setup: one full-attention group plus one mamba group in align mode.
    manager = KVCacheManager(
        _make_hybrid_kv_cache_config(
            block_size, num_blocks, ["full", "mamba_align"]
        ),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # Prompt layout: 3 full blocks (48 tokens) followed by 7 tokens of a
    # partial block, i.e. 55 tokens in total.
    prompt_token_ids = []
    for token in range(3):
        prompt_token_ids.extend([token] * block_size)
    prompt_token_ids.extend([3] * 7)
    num_prompt_tokens = len(prompt_token_ids)  # 55

    request = make_request("0", prompt_token_ids, block_size, sha256)

    # Nothing has been cached yet, so no prefix can be reused.
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(request)
    assert num_computed_tokens == 0

    # This call must not trip the assertion in MambaManager.cache_blocks()
    # when null blocks are present among the request's blocks.
    allocated = manager.allocate_slots(
        request, num_prompt_tokens, num_computed_tokens, computed_blocks
    )
    assert allocated is not None
    # One list of block ids per KV cache group: full_attn + mamba.
    assert len(allocated.get_block_ids()) == 2

    manager.free(request)
def test_prefill_plp(): def test_prefill_plp():
"""Test prefill with APC and some prompt logprobs (plp) requests. """Test prefill with APC and some prompt logprobs (plp) requests.
......
...@@ -1000,6 +1000,8 @@ class MambaManager(SingleTypeKVCacheManager): ...@@ -1000,6 +1000,8 @@ class MambaManager(SingleTypeKVCacheManager):
for block in self.req_to_blocks[request.request_id][ for block in self.req_to_blocks[request.request_id][
num_cached_blocks_before:num_cached_blocks_after num_cached_blocks_before:num_cached_blocks_after
]: ]:
if block.is_null:
continue
assert block.block_hash is not None assert block.block_hash is not None
self.cached_blocks_this_step.add(block.block_hash) self.cached_blocks_this_step.add(block.block_hash)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment