Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a2ba6a52
Unverified
Commit
a2ba6a52
authored
Feb 24, 2026
by
haosdent
Committed by
GitHub
Feb 23, 2026
Browse files
[Bugfix] Fix prefix caching for Mamba 'all' mode (Nemotron models) (#34874)
Signed-off-by:
haosdent
<
haosdent@gmail.com
>
parent
c4f38696
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
166 additions
and
0 deletions
+166
-0
tests/v1/attention/test_mamba_update_block_table.py
tests/v1/attention/test_mamba_update_block_table.py
+145
-0
vllm/v1/attention/backends/mamba_attn.py
vllm/v1/attention/backends/mamba_attn.py
+21
-0
No files found.
tests/v1/attention/test_mamba_update_block_table.py
0 → 100644
View file @
a2ba6a52
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Regression test for https://github.com/vllm-project/vllm/issues/34865
When multiple KV cache groups share the same MambaSpec (as in Nemotron
hybrid models), the metadata caching optimization reuses metadata from
an earlier group via update_block_table(). In 'all' mode with CUDA graphs,
update_block_table() must copy block_idx_last_scheduled_token and
block_idx_last_computed_token to the *current* builder's persistent
buffers, otherwise CUDA graph replay reads stale values from uninitialized
buffers.
"""
from
types
import
SimpleNamespace
import
torch
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.v1.attention.backends.mamba_attn
import
(
BaseMambaAttentionMetadata
,
BaseMambaAttentionMetadataBuilder
,
)
from
vllm.v1.kv_cache_interface
import
MambaSpec
class _ConcreteMambaBuilder(
    BaseMambaAttentionMetadataBuilder[BaseMambaAttentionMetadata]
):
    """Smallest instantiable builder used by the test below.

    The base builder is abstract, so this subclass exists only to make it
    constructible; it binds ``metadata_cls`` to the base metadata type and
    adds no behavior of its own.
    """

    metadata_cls = BaseMambaAttentionMetadata
def _make_vllm_config(block_size, max_model_len, max_num_seqs):
    """Build a lightweight stand-in for ``VllmConfig``.

    Only the fields the builder accesses are populated, which avoids any
    model download / HF config inspection. ``block_size`` is accepted for
    call-site symmetry but is not consulted here.

    Returns a ``SimpleNamespace`` exposing ``cache_config``,
    ``compilation_config``, ``scheduler_config`` and ``model_config``.
    """
    # Prefix-caching mode under test.
    cache_cfg = SimpleNamespace(mamba_cache_mode="all")

    # Full CUDA graphs with no explicit capture-size limit.
    compilation_cfg = SimpleNamespace(
        cudagraph_mode=CUDAGraphMode.FULL,
        max_cudagraph_capture_size=None,
    )

    scheduler_cfg = SimpleNamespace(max_num_seqs=max_num_seqs)
    model_cfg = SimpleNamespace(max_model_len=max_model_len)

    return SimpleNamespace(
        cache_config=cache_cfg,
        compilation_config=compilation_cfg,
        scheduler_config=scheduler_cfg,
        model_config=model_cfg,
    )
def test_update_block_table_copies_block_idx_to_persistent_buffers():
    """update_block_table() must write block_idx tensors to the current
    builder's persistent buffers, not leave them pointing to a different
    builder's buffers.

    Both ``block_idx_last_scheduled_token`` and
    ``block_idx_last_computed_token`` are checked symmetrically: each must
    live in builder_b's storage, must not alias builder_a's storage, and
    must carry the values taken from the source metadata.
    """
    block_size = 16
    max_model_len = 256
    num_reqs = 4
    device = torch.device("cpu")

    vllm_config = _make_vllm_config(block_size, max_model_len, num_reqs)

    spec = MambaSpec(
        block_size=block_size,
        shapes=((1,), (1,)),
        dtypes=(torch.float32,),
        mamba_cache_mode="all",
    )

    # Two builders simulating two KV cache groups with the same MambaSpec.
    builder_a = _ConcreteMambaBuilder(spec, ["layer0"], vllm_config, device)
    builder_b = _ConcreteMambaBuilder(spec, ["layer1"], vllm_config, device)

    # Sanity: each builder has its own persistent buffers (both of them,
    # otherwise the aliasing checks further down would be meaningless).
    assert (
        builder_a.block_idx_last_scheduled_token.data_ptr()
        != builder_b.block_idx_last_scheduled_token.data_ptr()
    )
    assert (
        builder_a.block_idx_last_computed_token.data_ptr()
        != builder_b.block_idx_last_computed_token.data_ptr()
    )

    # Construct decode-only metadata as if builder_a.build() produced it.
    max_blocks = max_model_len // block_size
    seq_lens = torch.full((num_reqs,), 64, dtype=torch.int32, device=device)
    block_idx_vals = (seq_lens - 1) // block_size  # [3, 3, 3, 3]

    builder_a.block_idx_last_scheduled_token[:num_reqs].copy_(block_idx_vals)
    builder_a.block_idx_last_computed_token[:num_reqs].copy_(block_idx_vals)

    metadata_a = BaseMambaAttentionMetadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decodes=num_reqs,
        num_decode_tokens=num_reqs,
        num_reqs=num_reqs,
        has_initial_states_p=None,
        query_start_loc_p=None,
        num_computed_tokens_p=None,
        state_indices_tensor=builder_a.state_indices_tensor[:num_reqs],
        block_idx_last_scheduled_token=(
            builder_a.block_idx_last_scheduled_token[:num_reqs]
        ),
        block_idx_first_scheduled_token_p=None,
        block_idx_last_computed_token=(
            builder_a.block_idx_last_computed_token[:num_reqs]
        ),
        seq_lens=seq_lens,
    )

    # Call update_block_table on builder_b (simulates the metadata caching
    # optimization reusing metadata from builder_a's group).
    blk_table = torch.randint(
        0, 100, (num_reqs, max_blocks), dtype=torch.int32, device=device
    )
    slot_mapping = torch.zeros(num_reqs, dtype=torch.int64, device=device)

    metadata_b = builder_b.update_block_table(metadata_a, blk_table, slot_mapping)

    # block_idx tensors must live in builder_b's persistent buffers.
    def shares_storage(tensor, buffer):
        # Compare the base address of the underlying storage so that a
        # sliced view is still recognised as belonging to its buffer.
        return (
            tensor.untyped_storage().data_ptr()
            == buffer.untyped_storage().data_ptr()
        )

    assert shares_storage(
        metadata_b.block_idx_last_scheduled_token,
        builder_b.block_idx_last_scheduled_token,
    ), "block_idx_last_scheduled_token not in builder_b's persistent buffer"
    assert shares_storage(
        metadata_b.block_idx_last_computed_token,
        builder_b.block_idx_last_computed_token,
    ), "block_idx_last_computed_token not in builder_b's persistent buffer"

    # Must NOT point to builder_a's buffers.
    assert not shares_storage(
        metadata_b.block_idx_last_scheduled_token,
        builder_a.block_idx_last_scheduled_token,
    ), "block_idx_last_scheduled_token still points to builder_a's buffer"
    assert not shares_storage(
        metadata_b.block_idx_last_computed_token,
        builder_a.block_idx_last_computed_token,
    ), "block_idx_last_computed_token still points to builder_a's buffer"

    # Values must be correct (copied from metadata_a).
    torch.testing.assert_close(
        metadata_b.block_idx_last_scheduled_token,
        block_idx_vals,
    )
    torch.testing.assert_close(
        metadata_b.block_idx_last_computed_token,
        block_idx_vals,
    )
vllm/v1/attention/backends/mamba_attn.py
View file @
a2ba6a52
...
@@ -331,5 +331,26 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
...
@@ -331,5 +331,26 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
persistent_state_indices_t
.
copy_
(
state_indices_t
,
non_blocking
=
True
)
persistent_state_indices_t
.
copy_
(
state_indices_t
,
non_blocking
=
True
)
state_indices_t
=
persistent_state_indices_t
state_indices_t
=
persistent_state_indices_t
# For 'all' mode, also update prefix caching block indices
# to use this builder's persistent buffers (required for CUDA
# graph replay to read from the correct memory addresses).
if
self
.
vllm_config
.
cache_config
.
mamba_cache_mode
==
"all"
:
assert
metadata
.
block_idx_last_scheduled_token
is
not
None
assert
metadata
.
block_idx_last_computed_token
is
not
None
self
.
block_idx_last_scheduled_token
[:
num_reqs
].
copy_
(
metadata
.
block_idx_last_scheduled_token
[:
num_reqs
],
non_blocking
=
True
,
)
new_metadata
.
block_idx_last_scheduled_token
=
(
self
.
block_idx_last_scheduled_token
[:
metadata
.
num_decode_tokens
]
)
self
.
block_idx_last_computed_token
[:
num_reqs
].
copy_
(
metadata
.
block_idx_last_computed_token
[:
num_reqs
],
non_blocking
=
True
,
)
new_metadata
.
block_idx_last_computed_token
=
(
self
.
block_idx_last_computed_token
[:
metadata
.
num_decode_tokens
]
)
new_metadata
.
state_indices_tensor
=
state_indices_t
new_metadata
.
state_indices_tensor
=
state_indices_t
return
new_metadata
return
new_metadata
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment