Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
778f5541
Unverified
Commit
778f5541
authored
Oct 06, 2025
by
Thomas Parnell
Committed by
GitHub
Oct 06, 2025
Browse files
[V1] [Hybrid] Some additional clean-up in Mamba2 prefix caching (#26222)
Signed-off-by:
Thomas Parnell
<
tpa@zurich.ibm.com
>
parent
d3c84297
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
171 additions
and
136 deletions
+171
-136
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+96
-57
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+22
-22
vllm/v1/attention/backends/mamba2_attn.py
vllm/v1/attention/backends/mamba2_attn.py
+52
-55
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+1
-2
No files found.
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
778f5541
...
@@ -595,21 +595,32 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -595,21 +595,32 @@ class MambaMixer2(MambaBase, CustomOp):
if
prefix_caching_enabled
:
if
prefix_caching_enabled
:
# If prefix caching is enabled, retrieve the relevant variables
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
# for prefill and decode
last_state_idx_d
,
last_state_idx_p
=
torch
.
split
(
block_idx_last_computed_token_d
,
block_idx_last_computed_token_p
=
(
attn_metadata
.
last_state_idx
,
[
num_decodes
,
num_prefills
],
dim
=
0
torch
.
split
(
attn_metadata
.
block_idx_last_computed_token
,
[
num_decodes
,
num_prefills
],
dim
=
0
,
)
)
)
current_last_idx_d
,
current_last_idx_p
=
torch
.
split
(
block_idx_last_scheduled_token_d
,
block_idx_last_scheduled_token_p
=
(
attn_metadata
.
current_last_idx
,
[
num_decodes
,
num_prefills
],
dim
=
0
torch
.
split
(
attn_metadata
.
block_idx_last_scheduled_token
,
[
num_decodes
,
num_prefills
],
dim
=
0
,
)
)
)
# Prefill-only variables:
# Prefill-only variables:
current_first_idx_p
=
attn_metadata
.
current_first_idx_p
block_idx_first_scheduled_token_p
=
(
context_lens_p
=
attn_metadata
.
context_lens_p
attn_metadata
.
block_idx_first_scheduled_token_p
last_computed_offset_p
=
attn_metadata
.
last_computed_offset_p
)
num_computed_tokens_p
=
attn_metadata
.
num_computed_tokens_p
else
:
else
:
last_state_idx_d
,
last_state_idx_p
=
None
,
None
block_idx_last_computed_token_d
=
None
current_last_idx_d
,
current_last_idx_p
=
None
,
None
block_idx_last_computed_token_p
=
None
current_first_idx_p
=
None
block_idx_last_scheduled_token_d
=
None
context_lens_p
=
None
block_idx_last_scheduled_token_p
=
None
block_idx_first_scheduled_token_p
=
None
num_computed_tokens_p
=
None
# Preallocate output tensor to avoid memcpy cost for merging prefill
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
# and decode outputs
...
@@ -637,7 +648,8 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -637,7 +648,8 @@ class MambaMixer2(MambaBase, CustomOp):
# to by "state_indices_tensor_p".
# to by "state_indices_tensor_p".
# In particular, it will always write the state at the
# In particular, it will always write the state at the
# sequence end.
# sequence end.
# In addition, "current_first_idx_p" and "current_last_idx_p"
# In addition, "block_idx_first_scheduled_token_p" and
# "block_idx_last_scheduled_token_p"
# are provided (which are pointers into
# are provided (which are pointers into
# "state_indices_tensor_p"), it will write additional cache
# "state_indices_tensor_p"), it will write additional cache
# states aligned at "block_size_to_align".
# states aligned at "block_size_to_align".
...
@@ -652,10 +664,10 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -652,10 +664,10 @@ class MambaMixer2(MambaBase, CustomOp):
conv_states
=
conv_state
,
conv_states
=
conv_state
,
has_initial_state
=
has_initial_states_p
,
has_initial_state
=
has_initial_states_p
,
cache_indices
=
state_indices_tensor_p
,
cache_indices
=
state_indices_tensor_p
,
current_first_idx
=
current_first_idx
_p
,
block_idx_first_scheduled_token
=
block_idx_first_scheduled_token
_p
,
current_last_idx
=
current_last_idx
_p
,
block_idx_last_scheduled_token
=
block_idx_last_scheduled_token
_p
,
initial_state_idx
=
last_state_idx
_p
,
initial_state_idx
=
block_idx_last_computed_token
_p
,
context_lens
=
context_l
ens_p
,
num_computed_tokens
=
num_computed_tok
ens_p
,
block_size_to_align
=
mamba_block_size
,
block_size_to_align
=
mamba_block_size
,
metadata
=
attn_metadata
,
metadata
=
attn_metadata
,
query_start_loc
=
query_start_loc_p
,
query_start_loc
=
query_start_loc_p
,
...
@@ -669,7 +681,7 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -669,7 +681,7 @@ class MambaMixer2(MambaBase, CustomOp):
kernel_ssm_indices
=
state_indices_tensor_p
kernel_ssm_indices
=
state_indices_tensor_p
if
prefix_caching_enabled
:
if
prefix_caching_enabled
:
kernel_ssm_indices
=
state_indices_tensor_p
.
gather
(
kernel_ssm_indices
=
state_indices_tensor_p
.
gather
(
1
,
last_state_idx
_p
.
unsqueeze
(
1
)
1
,
block_idx_last_computed_token
_p
.
unsqueeze
(
1
)
).
squeeze
(
1
)
).
squeeze
(
1
)
initial_states
=
torch
.
where
(
initial_states
=
torch
.
where
(
has_initial_states_p
[:,
None
,
None
,
None
],
has_initial_states_p
[:,
None
,
None
,
None
],
...
@@ -703,52 +715,76 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -703,52 +715,76 @@ class MambaMixer2(MambaBase, CustomOp):
)
)
if
prefix_caching_enabled
:
if
prefix_caching_enabled
:
# Save states for sequences with more than just the final state:
# The chunk_stride is the number of chunks per mamba block
n_blocks_to_fill
=
current_last_idx_p
-
current_first_idx_p
# e.g., if mamba_block_size = 512 and chunk_size = 256,
for
seq_idx
in
(
n_blocks_to_fill
>
0
).
nonzero
().
squeeze
(
1
):
# then chunk_stride = 2
chunk_stride
=
mamba_block_size
//
chunk_size
# Save state for sequences with more than just final state
for
seq_idx
in
range
(
num_prefills
):
# Block index for the first scheduled token
block_idx_first_scheduled_token
=
block_idx_first_scheduled_token_p
[
seq_idx
]
# Block index for the last scheduled token
block_idx_last_scheduled_token
=
block_idx_last_scheduled_token_p
[
seq_idx
]
# Number of blocks that need to be written
n_blocks_to_fill
=
(
block_idx_last_scheduled_token
-
block_idx_first_scheduled_token
)
# Skip sequences that don't have any blocks to fill
if
n_blocks_to_fill
==
0
:
continue
# Look up the state indices
cache_blocks_to_fill
=
state_indices_tensor_p
[
cache_blocks_to_fill
=
state_indices_tensor_p
[
seq_idx
,
seq_idx
,
current_first_idx_p
[
seq_idx
]
:
current_first_idx_p
[
seq_idx
]
block_idx_first_scheduled_token
:
block_idx_last_scheduled_token
,
+
n_blocks_to_fill
[
seq_idx
],
]
]
# chunks = [0 1 2 3 4 5 6 ...]
# First aligned chunk would typically be:
# First chunk index for this sequence
# mamba_block_size = 1024, chunk_size = 256
if
seq_idx
==
0
:
# 1024 // 256 - 1 --> chunks[3]
first_chunk
=
0
# But when last chunk wasn't block aligned:
else
:
# - last_computed_offset_p[seq_idx] // chunk_size
first_chunk
=
1
+
last_chunk_indices_p
[
seq_idx
-
1
]
# e.g. 1000 // 256 -> 3 completed --> store chunk[0]
# e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
# First chunk that is aligned on the mamba block boundary
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
first_aligned_chunk
=
first_chunk
+
chunk_stride
-
1
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
chunk_stride
=
mamba_block_size
//
chunk_size
# Calculate the number of computed tokens that were not
first_aligned_chunk
=
(
# already cached
torch
.
concat
(
num_unaligned_computed_tokens
=
(
[
num_computed_tokens_p
[
seq_idx
]
%
mamba_block_size
torch
.
zeros
(
1
,
dtype
=
last_chunk_indices_p
.
dtype
,
device
=
last_chunk_indices_p
.
device
,
),
last_chunk_indices_p
+
1
,
]
)[
seq_idx
]
+
chunk_stride
-
1
-
last_computed_offset_p
[
seq_idx
]
//
chunk_size
)
)
if
num_unaligned_computed_tokens
>
0
:
# If the number of computed tokens is not block aligned,
# then we need to shift the index accordingly
first_aligned_chunk
-=
(
num_unaligned_computed_tokens
//
chunk_size
)
# Get states to write
from_where
=
varlen_states
[
from_where
=
varlen_states
[
first_aligned_chunk
:
first_aligned_chunk
first_aligned_chunk
:
first_aligned_chunk
+
n_blocks_to_fill
[
seq_idx
]
*
chunk_stride
:
chunk_stride
+
n_blocks_to_fill
*
chunk_stride
:
chunk_stride
]
]
# Write the states
ssm_state
[
cache_blocks_to_fill
]
=
from_where
ssm_state
[
cache_blocks_to_fill
]
=
from_where
# For all seqs, store the last state (
N
ote: might be partial):
# For all seqs, store the last state (
n
ote: might be partial):
ssm_state
[
ssm_state
[
state_indices_tensor_p
.
gather
(
state_indices_tensor_p
.
gather
(
1
,
current_last_idx
_p
.
unsqueeze
(
1
)
1
,
block_idx_last_scheduled_token
_p
.
unsqueeze
(
1
)
).
squeeze
(
1
)
).
squeeze
(
1
)
]
=
varlen_states
[
last_chunk_indices_p
]
]
=
varlen_states
[
last_chunk_indices_p
]
else
:
else
:
# update ssm states
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate)
# - varlen state is a (num_prefills, nheads, headdim, dstate)
...
@@ -759,14 +795,17 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -759,14 +795,17 @@ class MambaMixer2(MambaBase, CustomOp):
if
has_decode
:
if
has_decode
:
if
prefix_caching_enabled
:
if
prefix_caching_enabled
:
state_indices_tensor_d_input
=
state_indices_tensor_d
.
gather
(
state_indices_tensor_d_input
=
state_indices_tensor_d
.
gather
(
1
,
last_state_idx
_d
.
unsqueeze
(
1
)
1
,
block_idx_last_computed_token
_d
.
unsqueeze
(
1
)
).
squeeze
(
1
)
).
squeeze
(
1
)
state_indices_tensor_d_output
=
state_indices_tensor_d
.
gather
(
state_indices_tensor_d_output
=
state_indices_tensor_d
.
gather
(
1
,
current_last_idx
_d
.
unsqueeze
(
1
)
1
,
block_idx_last_scheduled_token
_d
.
unsqueeze
(
1
)
).
squeeze
(
1
)
).
squeeze
(
1
)
# Note:
# for decode:
# for decode always: current_first_idx_d == current_last_idx_d
# block_idx_first_scheduled_token_d ==
# at block boundaries: current_first_idx_d > last_state_idx_d
# block_idx_last_scheduled_token_d
# at block boundaries:
# block_idx_first_scheduled_token_d >
# block_idx_last_computed_token_d
else
:
else
:
# Without caching, read and write in-place to the same blocks:
# Without caching, read and write in-place to the same blocks:
state_indices_tensor_d_input
=
state_indices_tensor_d
state_indices_tensor_d_input
=
state_indices_tensor_d
...
@@ -780,8 +819,8 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -780,8 +819,8 @@ class MambaMixer2(MambaBase, CustomOp):
self
.
conv1d
.
bias
,
self
.
conv1d
.
bias
,
self
.
activation
,
self
.
activation
,
conv_state_indices
=
state_indices_tensor_d
,
conv_state_indices
=
state_indices_tensor_d
,
current_last_idx
=
current_last_idx
_d
,
block_idx_last_scheduled_token
=
block_idx_last_scheduled_token
_d
,
initial_state_idx
=
last_state_idx
_d
,
initial_state_idx
=
block_idx_last_computed_token
_d
,
)
)
hidden_states_d
,
B_d
,
C_d
=
split_hidden_states_B_C_fn
(
hidden_states_B_C_d
)
hidden_states_d
,
B_d
,
C_d
=
split_hidden_states_B_C_fn
(
hidden_states_B_C_d
)
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
778f5541
...
@@ -27,10 +27,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching
...
@@ -27,10 +27,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching
query_start_loc_ptr
,
query_start_loc_ptr
,
batch_ptr
,
batch_ptr
,
token_chunk_offset_ptr
,
token_chunk_offset_ptr
,
current_first_idx
,
# (batch,)
block_idx_first_scheduled_token
,
# (batch,)
current_last_idx
,
# (batch,)
block_idx_last_scheduled_token
,
# (batch,)
initial_state_idx
,
# (batch,)
initial_state_idx
,
# (batch,)
context_l
ens
,
# (batch,)
num_computed_tok
ens
,
# (batch,)
o_ptr
,
# (dim, seqlen) - actually pointing to x_ptr
o_ptr
,
# (dim, seqlen) - actually pointing to x_ptr
# Matrix dimensions
# Matrix dimensions
dim
:
tl
.
constexpr
,
dim
:
tl
.
constexpr
,
...
@@ -94,9 +94,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
...
@@ -94,9 +94,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
# In particular, if prefix caching is enabled, the program write additional cache states to "cache_indices_ptr"
# In particular, if prefix caching is enabled, the program write additional cache states to "cache_indices_ptr"
# Get the length of the completed sequence so far and compute the offset.
# Get the length of the completed sequence so far and compute the offset.
current_first_index
=
tl
.
load
(
current_first_idx
+
idx_seq
)
current_first_index
=
tl
.
load
(
block_idx_first_scheduled_token
+
idx_seq
)
current_last_index
=
tl
.
load
(
current_last_idx
+
idx_seq
)
current_last_index
=
tl
.
load
(
block_idx_last_scheduled_token
+
idx_seq
)
sequence_completed_index
=
tl
.
load
(
context_l
ens
+
idx_seq
)
sequence_completed_index
=
tl
.
load
(
num_computed_tok
ens
+
idx_seq
)
# Compute the offset where the first stride_block_m-aligned first full block is
# Compute the offset where the first stride_block_m-aligned first full block is
# Value in "token-space"
# Value in "token-space"
...
@@ -476,10 +476,10 @@ def causal_conv1d_fn(
...
@@ -476,10 +476,10 @@ def causal_conv1d_fn(
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
activation
:
Optional
[
str
]
=
"silu"
,
pad_slot_id
:
int
=
PAD_SLOT_ID
,
pad_slot_id
:
int
=
PAD_SLOT_ID
,
current_first_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
block_idx_first_scheduled_token
:
Optional
[
torch
.
Tensor
]
=
None
,
current_last_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
block_idx_last_scheduled_token
:
Optional
[
torch
.
Tensor
]
=
None
,
initial_state_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
initial_state_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
context_l
ens
:
Optional
[
torch
.
Tensor
]
=
None
,
num_computed_tok
ens
:
Optional
[
torch
.
Tensor
]
=
None
,
block_size_to_align
=
0
,
block_size_to_align
=
0
,
metadata
=
None
,
metadata
=
None
,
validate_data
=
False
,
validate_data
=
False
,
...
@@ -523,13 +523,13 @@ def causal_conv1d_fn(
...
@@ -523,13 +523,13 @@ def causal_conv1d_fn(
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
in this case, the kernel will not process entries at
indices 0 and 3
indices 0 and 3
current_first_idx
: (batch,), dtype int32
block_idx_first_scheduled_token
: (batch,), dtype int32
The pointer into cache_indices, where the first cache block to be filled is located.
The pointer into cache_indices, where the first cache block to be filled is located.
current_last_idx
: (batch,), dtype int32
block_idx_last_scheduled_token
: (batch,), dtype int32
The pointer into cache_indices, where the last cache block to be filled is located.
The pointer into cache_indices, where the last cache block to be filled is located.
initial_state_idx: (batch,), dtype int32
initial_state_idx: (batch,), dtype int32
The pointer into cache_indices, where the cache block containing the initial state is located.
The pointer into cache_indices, where the cache block containing the initial state is located.
context_l
ens: (batch,), dtype int32
num_computed_tok
ens: (batch,), dtype int32
The number of tokens already completed for each sequence
The number of tokens already completed for each sequence
block_size_to_align: int
block_size_to_align: int
The block size to align the cached states to
The block size to align the cached states to
...
@@ -708,10 +708,10 @@ def causal_conv1d_fn(
...
@@ -708,10 +708,10 @@ def causal_conv1d_fn(
query_start_loc
,
query_start_loc
,
batch_ptr
,
batch_ptr
,
token_chunk_offset_ptr
,
token_chunk_offset_ptr
,
current_first_idx
,
block_idx_first_scheduled_token
,
current_last_idx
,
block_idx_last_scheduled_token
,
initial_state_idx
,
initial_state_idx
,
context_l
ens
,
num_computed_tok
ens
,
out
,
out
,
# Matrix dimensions
# Matrix dimensions
dim
,
dim
,
...
@@ -735,7 +735,7 @@ def causal_conv1d_fn(
...
@@ -735,7 +735,7 @@ def causal_conv1d_fn(
HAS_BIAS
=
bias
is
not
None
,
HAS_BIAS
=
bias
is
not
None
,
KERNEL_WIDTH
=
width
,
KERNEL_WIDTH
=
width
,
SILU_ACTIVATION
=
activation
in
[
"silu"
,
"swish"
],
SILU_ACTIVATION
=
activation
in
[
"silu"
,
"swish"
],
IS_APC_ENABLED
=
current_last_idx
is
not
None
,
IS_APC_ENABLED
=
block_idx_last_scheduled_token
is
not
None
,
USE_PAD_SLOT
=
pad_slot_id
is
not
None
,
USE_PAD_SLOT
=
pad_slot_id
is
not
None
,
NP2_STATELEN
=
np2_statelen
,
NP2_STATELEN
=
np2_statelen
,
# launch_cooperative_grid=True
# launch_cooperative_grid=True
...
@@ -756,7 +756,7 @@ def _causal_conv1d_update_kernel(
...
@@ -756,7 +756,7 @@ def _causal_conv1d_update_kernel(
conv_state_indices_ptr
,
conv_state_indices_ptr
,
num_accepted_tokens_ptr
,
num_accepted_tokens_ptr
,
query_start_loc_ptr
,
# (batch + 1)
query_start_loc_ptr
,
# (batch + 1)
current_last_idx
,
# (batch,)
block_idx_last_scheduled_token
,
# (batch,)
initial_state_idx
,
# (batch,)
initial_state_idx
,
# (batch,)
o_ptr
,
# (batch, dim, seqlen)
o_ptr
,
# (batch, dim, seqlen)
# Matrix dimensions
# Matrix dimensions
...
@@ -802,7 +802,7 @@ def _causal_conv1d_update_kernel(
...
@@ -802,7 +802,7 @@ def _causal_conv1d_update_kernel(
if
IS_APC_ENABLED
:
if
IS_APC_ENABLED
:
# Get the state from the initial_state_idx
# Get the state from the initial_state_idx
conv_state_init
=
tl
.
load
(
initial_state_idx
+
idx_seq
)
conv_state_init
=
tl
.
load
(
initial_state_idx
+
idx_seq
)
current_last_index
=
tl
.
load
(
current_last_idx
+
idx_seq
)
current_last_index
=
tl
.
load
(
block_idx_last_scheduled_token
+
idx_seq
)
else
:
else
:
conv_state_init
=
0
conv_state_init
=
0
current_last_index
=
0
current_last_index
=
0
...
@@ -1078,7 +1078,7 @@ def causal_conv1d_update(
...
@@ -1078,7 +1078,7 @@ def causal_conv1d_update(
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
max_query_len
:
int
=
-
1
,
max_query_len
:
int
=
-
1
,
pad_slot_id
:
int
=
PAD_SLOT_ID
,
pad_slot_id
:
int
=
PAD_SLOT_ID
,
current_last_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
block_idx_last_scheduled_token
:
Optional
[
torch
.
Tensor
]
=
None
,
initial_state_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
initial_state_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
validate_data
=
False
,
validate_data
=
False
,
):
):
...
@@ -1097,7 +1097,7 @@ def causal_conv1d_update(
...
@@ -1097,7 +1097,7 @@ def causal_conv1d_update(
If not None, the conv_state is a larger tensor along the batch dim,
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
Useful for a continuous batching scenario.
current_last_idx
: (batch,), dtype int32
block_idx_last_scheduled_token
: (batch,), dtype int32
The pointer into conv_state_indices, where the last cache block to be filled is located.
The pointer into conv_state_indices, where the last cache block to be filled is located.
initial_state_idx: (batch,), dtype int32
initial_state_idx: (batch,), dtype int32
The pointer into conv_state_indices, where the cache block containing the initial state is located.
The pointer into conv_state_indices, where the cache block containing the initial state is located.
...
@@ -1201,7 +1201,7 @@ def causal_conv1d_update(
...
@@ -1201,7 +1201,7 @@ def causal_conv1d_update(
conv_state_indices
,
conv_state_indices
,
num_accepted_tokens
,
num_accepted_tokens
,
query_start_loc
,
query_start_loc
,
current_last_idx
,
block_idx_last_scheduled_token
,
initial_state_idx
,
initial_state_idx
,
out
,
out
,
# Matrix dimensions
# Matrix dimensions
...
@@ -1230,7 +1230,7 @@ def causal_conv1d_update(
...
@@ -1230,7 +1230,7 @@ def causal_conv1d_update(
KERNEL_WIDTH
=
width
,
KERNEL_WIDTH
=
width
,
SILU_ACTIVATION
=
activation
in
[
"silu"
,
"swish"
],
SILU_ACTIVATION
=
activation
in
[
"silu"
,
"swish"
],
IS_VARLEN
=
query_start_loc
is
not
None
,
IS_VARLEN
=
query_start_loc
is
not
None
,
IS_APC_ENABLED
=
current_last_idx
is
not
None
,
IS_APC_ENABLED
=
block_idx_last_scheduled_token
is
not
None
,
IS_SPEC_DECODING
=
num_accepted_tokens
is
not
None
,
IS_SPEC_DECODING
=
num_accepted_tokens
is
not
None
,
NP2_STATELEN
=
np2_statelen
,
NP2_STATELEN
=
np2_statelen
,
USE_PAD_SLOT
=
pad_slot_id
is
not
None
,
USE_PAD_SLOT
=
pad_slot_id
is
not
None
,
...
...
vllm/v1/attention/backends/mamba2_attn.py
View file @
778f5541
...
@@ -122,11 +122,10 @@ class Mamba2AttentionMetadata:
...
@@ -122,11 +122,10 @@ class Mamba2AttentionMetadata:
last_chunk_indices_p
:
Optional
[
torch
.
Tensor
]
last_chunk_indices_p
:
Optional
[
torch
.
Tensor
]
state_indices_tensor
:
torch
.
Tensor
# shape: [batch,]
state_indices_tensor
:
torch
.
Tensor
# shape: [batch,]
current_last_idx
:
torch
.
Tensor
block_idx_last_scheduled_token
:
torch
.
Tensor
# shape: [batch,]
current_first_idx_p
:
torch
.
Tensor
block_idx_first_scheduled_token_p
:
torch
.
Tensor
# shape: [batch,]
last_state_idx
:
torch
.
Tensor
block_idx_last_computed_token
:
torch
.
Tensor
# shape: [batch,]
context_lens_p
:
torch
.
Tensor
num_computed_tokens_p
:
torch
.
Tensor
# shape: [batch,]
last_computed_offset_p
:
torch
.
Tensor
# The following attributes are for triton implementation of causal_conv1d
# The following attributes are for triton implementation of causal_conv1d
nums_dict
:
Optional
[
dict
]
=
None
nums_dict
:
Optional
[
dict
]
=
None
...
@@ -160,12 +159,12 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -160,12 +159,12 @@ class Mamba2AttentionMetadataBuilder(
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
)
)
self
.
current_last_idx
=
torch
.
empty
(
self
.
block_idx_last_scheduled_token
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
,),
(
self
.
decode_cudagraph_max_bs
,),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
)
)
self
.
last_state_idx
=
torch
.
empty
(
self
.
block_idx_last_computed_token
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
,),
(
self
.
decode_cudagraph_max_bs
,),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
...
@@ -192,43 +191,38 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -192,43 +191,38 @@ class Mamba2AttentionMetadataBuilder(
# for causal_conv1d
# for causal_conv1d
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
context_lens
,
context_l
ens_p
=
None
,
None
num_computed_tokens
,
num_computed_tok
ens_p
=
None
,
None
current_first_idx
,
current_first_idx_p
=
None
,
None
block_idx_first_scheduled_token
=
None
last_computed_offset
,
last_computed_offset
_p
=
None
,
None
block_idx_first_scheduled_token
_p
=
None
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
# Return a tensor of shape (#requests, #max blocks)
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
# Additional cache-related varaiables:
# Additional cache-related varaiables:
mamba_block_size
=
self
.
kv_cache_spec
.
block_size
mamba_block_size
=
self
.
kv_cache_spec
.
block_size
seq_lens_pending
=
(
num_computed_tokens
=
common_attn_metadata
.
num_computed_tokens_cpu
.
to
(
torch
.
roll
(
common_attn_metadata
.
query_start_loc
,
-
1
,
-
1
)
self
.
device
-
common_attn_metadata
.
query_start_loc
)
)[:
-
1
]
# Block index of the last computed token
context_lens
=
common_attn_metadata
.
seq_lens
-
seq_lens_pending
block_idx_last_computed_token
=
(
last_computed_offset
=
context_lens
%
mamba_block_size
cdiv
(
num_computed_tokens
,
mamba_block_size
)
-
1
# Indices: last_computed <= current_first <= current_last
)
# Cases:
# which is <= block index for the first scheduled token
# last_computed == current_first if last state was partially
block_idx_first_scheduled_token
=
(
# computed and needs to be updated
cdiv
(
num_computed_tokens
+
1
,
mamba_block_size
)
-
1
# current_first == current_last if no block crossing occurs, and
)
# only one state will be stored
# which is <= block index of the last scheduled token
# 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]:
block_idx_last_scheduled_token
=
(
current_last_idx
=
(
cdiv
(
common_attn_metadata
.
seq_lens
,
mamba_block_size
)
-
1
cdiv
(
context_lens
+
seq_lens_pending
,
mamba_block_size
)
-
1
)
)
current_first_idx
=
cdiv
(
context_lens
+
1
,
mamba_block_size
)
-
1
last_state_idx
=
cdiv
(
context_lens
,
mamba_block_size
)
-
1
# -1 in case it's non-computed and causes later issues with indexing
# -1 in case it's non-computed and causes later issues with indexing
last_state_idx
=
last_state_idx
.
clamp
(
min
=
0
)
block_idx_last_computed_token
=
block_idx_last_computed_token
.
clamp
(
min
=
0
)
else
:
else
:
# Always return just a single block per each request:
# Always return just a single block per each request:
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
# Additional cache-related varaiables:
# Additional cache-related varaiables:
current_last_idx
=
None
block_idx_last_scheduled_token
=
None
last_state_idx
=
None
block_idx_last_computed_token
=
None
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
split_decodes_and_prefills
(
...
@@ -256,18 +250,15 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -256,18 +250,15 @@ class Mamba2AttentionMetadataBuilder(
)
)
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
assert
context_lens
is
not
None
assert
num_computed_tokens
is
not
None
context_lens_p
=
context_lens
[
num_reqs
-
num_prefills
:
num_reqs
]
num_computed_tokens_p
=
num_computed_tokens
[
assert
last_computed_offset
is
not
None
last_computed_offset_p
=
last_computed_offset
[
num_reqs
-
num_prefills
:
num_reqs
num_reqs
-
num_prefills
:
num_reqs
]
]
assert
current_first_idx
is
not
None
assert
block_idx_first_scheduled_token
is
not
None
current_first_idx_p
=
current_first_idx
[
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token
[
num_reqs
-
num_prefills
:
num_reqs
num_reqs
-
num_prefills
:
num_reqs
]
]
num_computed_tokens_p_cpu
=
common_attn_metadata
.
num_computed_tokens_cpu
[
num_computed_tokens_p
=
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
num_reqs
-
num_prefills
:
num_reqs
]
]
query_start_loc_p_cpu
=
(
query_start_loc_p_cpu
=
(
...
@@ -290,7 +281,7 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -290,7 +281,7 @@ class Mamba2AttentionMetadataBuilder(
last_chunk_indices
=
[]
last_chunk_indices
=
[]
seqlen_pos
=
0
seqlen_pos
=
0
for
req_idx
in
range
(
num_prefills
):
for
req_idx
in
range
(
num_prefills
):
this_num_computed
=
num_computed_tokens_p
[
req_idx
].
item
()
this_num_computed
=
num_computed_tokens_p
_cpu
[
req_idx
].
item
()
this_new_tokens
=
(
this_new_tokens
=
(
query_start_loc_p_cpu
[
req_idx
+
1
].
item
()
query_start_loc_p_cpu
[
req_idx
+
1
].
item
()
-
query_start_loc_p_cpu
[
req_idx
].
item
()
-
query_start_loc_p_cpu
[
req_idx
].
item
()
...
@@ -338,7 +329,10 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -338,7 +329,10 @@ class Mamba2AttentionMetadataBuilder(
compute_causal_conv1d_metadata
(
query_start_loc_p
)
compute_causal_conv1d_metadata
(
query_start_loc_p
)
)
)
elif
num_decodes
<=
self
.
decode_cudagraph_max_bs
:
elif
(
num_decodes
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
full_cuda_graph
):
# Pad state tensor for CUDA graph
# Pad state tensor for CUDA graph
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_decodes
)
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_decodes
)
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
...
@@ -348,17 +342,21 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -348,17 +342,21 @@ class Mamba2AttentionMetadataBuilder(
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
self
.
current_last_idx
[:
num_decodes
].
copy_
(
self
.
block_idx_last_scheduled_token
[:
num_decodes
].
copy_
(
current_last_idx
,
non_blocking
=
True
block_idx_last_scheduled_token
,
non_blocking
=
True
)
)
current_last_idx
=
self
.
current_last_idx
[:
num_input_tokens
]
block_idx_last_scheduled_token
=
self
.
block_idx_last_scheduled_token
[
current_last_idx
[
num_decodes
:]
=
0
:
num_input_tokens
]
block_idx_last_scheduled_token
[
num_decodes
:]
=
0
self
.
last_state_idx
[:
num_decodes
].
copy_
(
self
.
block_idx_last_computed_token
[:
num_decodes
].
copy_
(
last_state_idx
,
non_blocking
=
True
block_idx_last_computed_token
,
non_blocking
=
True
)
)
last_state_idx
=
self
.
last_state_idx
[:
num_input_tokens
]
block_idx_last_computed_token
=
self
.
block_idx_last_computed_token
[
last_state_idx
[
num_decodes
:]
=
0
:
num_input_tokens
]
block_idx_last_computed_token
[
num_decodes
:]
=
0
attn_metadata
=
Mamba2AttentionMetadata
(
attn_metadata
=
Mamba2AttentionMetadata
(
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
...
@@ -377,10 +375,9 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -377,10 +375,9 @@ class Mamba2AttentionMetadataBuilder(
nums_dict
=
nums_dict
,
nums_dict
=
nums_dict
,
batch_ptr
=
batch_ptr
,
batch_ptr
=
batch_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
current_last_idx
=
current_last_idx
,
block_idx_last_scheduled_token
=
block_idx_last_scheduled_token
,
current_first_idx_p
=
current_first_idx_p
,
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token_p
,
last_state_idx
=
last_state_idx
,
block_idx_last_computed_token
=
block_idx_last_computed_token
,
context_lens_p
=
context_lens_p
,
num_computed_tokens_p
=
num_computed_tokens_p
,
last_computed_offset_p
=
last_computed_offset_p
,
)
)
return
attn_metadata
return
attn_metadata
vllm/v1/core/single_type_kv_cache_manager.py
View file @
778f5541
...
@@ -584,8 +584,7 @@ class MambaManager(SingleTypeKVCacheManager):
...
@@ -584,8 +584,7 @@ class MambaManager(SingleTypeKVCacheManager):
# hit_length = len(hit_blocks_other_attn[0])
# hit_length = len(hit_blocks_other_attn[0])
# * self.other_block_size
# * self.other_block_size
# so we insert dummy blocks at the beginning:
# so we insert dummy blocks at the beginning:
if
i
>
0
:
computed
.
extend
([
block_pool
.
null_block
]
*
i
)
computed
.
extend
([
block_pool
.
null_block
]
*
i
)
computed
.
append
(
cached
)
computed
.
append
(
cached
)
break
# we just need the last match - early stopping
break
# we just need the last match - early stopping
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment