Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
34916ae3
Unverified
Commit
34916ae3
authored
Dec 23, 2025
by
Asaf Joseph Gardin
Committed by
GitHub
Dec 23, 2025
Browse files
[Mamba] - Consolidate Mambas Attention Logic (#28133)
parent
0736f901
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
305 additions
and
448 deletions
+305
-448
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+1
-5
vllm/v1/attention/backends/mamba1_attn.py
vllm/v1/attention/backends/mamba1_attn.py
+7
-138
vllm/v1/attention/backends/mamba2_attn.py
vllm/v1/attention/backends/mamba2_attn.py
+100
-220
vllm/v1/attention/backends/mamba_attn.py
vllm/v1/attention/backends/mamba_attn.py
+191
-1
vllm/v1/attention/backends/short_conv_attn.py
vllm/v1/attention/backends/short_conv_attn.py
+6
-84
No files found.
vllm/model_executor/layers/mamba/short_conv.py
View file @
34916ae3
...
...
@@ -118,6 +118,7 @@ class ShortConv(MambaBase, CustomOp):
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
BCx
,
_
=
self
.
in_proj
(
hidden_states
)
...
...
@@ -165,11 +166,6 @@ class ShortConv(MambaBase, CustomOp):
[
num_decodes
,
num_prefills
],
dim
=
0
,
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decodes
if
has_prefill
else
None
)
conv_output_list
=
[]
...
...
vllm/v1/attention/backends/mamba1_attn.py
View file @
34916ae3
...
...
@@ -3,17 +3,11 @@
from
dataclasses
import
dataclass
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.config
import
VllmConfig
from
vllm.v1.attention.backends.mamba_attn
import
BaseMambaAttentionMetadataBuilder
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
split_decodes_and_prefills
,
from
vllm.v1.attention.backends.mamba_attn
import
(
BaseMambaAttentionMetadata
,
BaseMambaAttentionMetadataBuilder
,
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
MambaSpec
class
Mamba1AttentionBackend
(
AttentionBackend
):
...
...
@@ -23,137 +17,12 @@ class Mamba1AttentionBackend(AttentionBackend):
@
dataclass
class
Mamba1AttentionMetadata
:
query_start_loc_p
:
torch
.
Tensor
state_indices_tensor
:
torch
.
Tensor
has_initial_states_p
:
torch
.
Tensor
|
None
num_prefills
:
int
num_prefill_tokens
:
int
num_decodes
:
int
num_decode_tokens
:
int
block_idx_last_scheduled_token
:
torch
.
Tensor
# shape: [batch,]
block_idx_first_scheduled_token_p
:
torch
.
Tensor
# shape: [batch,]
block_idx_last_computed_token
:
torch
.
Tensor
# shape: [batch,]
num_computed_tokens_p
:
torch
.
Tensor
# shape: [batch,]
class
Mamba1AttentionMetadata
(
BaseMambaAttentionMetadata
):
pass
class
Mamba1AttentionMetadataBuilder
(
BaseMambaAttentionMetadataBuilder
[
Mamba1AttentionMetadata
]
):
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
):
super
().
__init__
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
assert
isinstance
(
kv_cache_spec
,
MambaSpec
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
Mamba1AttentionMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
)
)
has_initial_states_p
=
None
query_start_loc_p
=
None
num_computed_tokens
,
num_computed_tokens_p
=
None
,
None
block_idx_first_scheduled_token
=
None
block_idx_first_scheduled_token_p
=
None
# TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
# We should consolidate this code
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
mamba_block_size
=
self
.
kv_cache_spec
.
block_size
num_computed_tokens
=
common_attn_metadata
.
num_computed_tokens_cpu
.
to
(
self
.
device
)
(
block_idx_last_computed_token
,
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
)
=
self
.
_compute_prefix_caching_block_indices
(
common_attn_metadata
,
mamba_block_size
)
else
:
# Always return just a single block per each request:
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
block_idx_last_scheduled_token
=
None
block_idx_last_computed_token
=
None
if
num_prefills
>
0
:
query_start_loc_p
=
(
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
has_initial_states_cpu
=
(
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
)
has_initial_states_p
=
has_initial_states_cpu
.
to
(
common_attn_metadata
.
query_start_loc
.
device
)
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
assert
num_computed_tokens
is
not
None
num_computed_tokens_p
=
num_computed_tokens
[
num_reqs
-
num_prefills
:
num_reqs
]
assert
block_idx_first_scheduled_token
is
not
None
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token
[
num_reqs
-
num_prefills
:
num_reqs
]
elif
(
num_decodes
>
0
and
num_decodes
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
state_indices_tensor
,
non_blocking
=
True
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_decode_tokens
]
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
self
.
block_idx_last_scheduled_token
[:
num_decodes
].
copy_
(
block_idx_last_scheduled_token
,
non_blocking
=
True
)
block_idx_last_scheduled_token
=
self
.
block_idx_last_scheduled_token
[
:
num_decode_tokens
]
self
.
block_idx_last_computed_token
[:
num_decodes
].
copy_
(
block_idx_last_computed_token
,
non_blocking
=
True
)
block_idx_last_computed_token
=
self
.
block_idx_last_computed_token
[
:
num_decode_tokens
]
return
Mamba1AttentionMetadata
(
query_start_loc_p
=
query_start_loc_p
,
has_initial_states_p
=
has_initial_states_p
,
state_indices_tensor
=
state_indices_tensor
,
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
block_idx_last_scheduled_token
=
block_idx_last_scheduled_token
,
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token_p
,
block_idx_last_computed_token
=
block_idx_last_computed_token
,
num_computed_tokens_p
=
num_computed_tokens_p
,
)
metadata_cls
=
Mamba1AttentionMetadata
supports_update_block_table
:
bool
=
False
vllm/v1/attention/backends/mamba2_attn.py
View file @
34916ae3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
itertools
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
replace
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.config
import
VllmConfig
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backends.mamba_attn
import
BaseMambaAttentionMetadataBuilder
from
vllm.v1.attention.backends.mamba_attn
import
(
BaseMambaAttentionMetadata
,
BaseMambaAttentionMetadataBuilder
,
)
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
compute_causal_conv1d_metadata
,
split_decodes_and_prefills
,
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
...
@@ -94,48 +94,26 @@ class Mamba2AttentionBackend(AttentionBackend):
@
dataclass
class
Mamba2AttentionMetadata
:
num_prefills
:
int
num_prefill_tokens
:
int
num_decodes
:
int
num_decode_tokens
:
int
query_start_loc_p
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
prep_initial_states
:
bool
chunk_size
:
int
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill request.
has_initial_states_p
:
torch
.
Tensor
|
None
seq_idx_p
:
torch
.
Tensor
|
None
class
Mamba2AttentionMetadata
(
BaseMambaAttentionMetadata
):
prep_initial_states
:
bool
=
False
chunk_size
:
int
=
0
# Chunk-related metadata (only for prefill)
seq_idx_p
:
torch
.
Tensor
|
None
=
None
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offests into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p
:
torch
.
Tensor
|
None
cu_chunk_seqlen_p
:
torch
.
Tensor
|
None
=
None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p
:
torch
.
Tensor
|
None
state_indices_tensor
:
torch
.
Tensor
# shape: [batch,]
block_idx_last_scheduled_token
:
torch
.
Tensor
# shape: [batch,]
block_idx_first_scheduled_token_p
:
torch
.
Tensor
# shape: [batch,]
block_idx_last_computed_token
:
torch
.
Tensor
# shape: [batch,]
num_computed_tokens_p
:
torch
.
Tensor
# shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict
:
dict
|
None
=
None
batch_ptr
:
torch
.
Tensor
|
None
=
None
token_chunk_offset_ptr
:
torch
.
Tensor
|
None
=
None
last_chunk_indices_p
:
torch
.
Tensor
|
None
=
None
class
Mamba2AttentionMetadataBuilder
(
BaseMambaAttentionMetadataBuilder
[
Mamba2AttentionMetadata
]
):
supports_update_block_table
:
bool
=
True
metadata_cls
=
Mamba2AttentionMetadata
def
__init__
(
self
,
...
...
@@ -150,109 +128,31 @@ class Mamba2AttentionMetadataBuilder(
"chunk_size needs to be set in the model config for Mamba2 models"
)
def
build
(
def
_compute_chunk_metadata
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
Mamba2AttentionMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
seq_lens
=
common_attn_metadata
.
seq_lens
query_start_loc_p
=
None
seq_idx_p
=
None
cu_chunk_seqlen_p
=
None
last_chunk_indices_p
=
None
# Need flags to indicate if there are initial states
has_initial_states_p
=
None
prep_initial_states
=
False
# for causal_conv1d
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
num_computed_tokens
,
num_computed_tokens_p
=
None
,
None
block_idx_first_scheduled_token
=
None
block_idx_first_scheduled_token_p
=
None
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
# Additional cache-related varaiables:
mamba_block_size
=
self
.
kv_cache_spec
.
block_size
num_computed_tokens
=
common_attn_metadata
.
num_computed_tokens_cpu
.
to
(
self
.
device
)
(
block_idx_last_computed_token
,
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
)
=
self
.
_compute_prefix_caching_block_indices
(
common_attn_metadata
,
mamba_block_size
)
else
:
# Always return just a single block per each request:
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
# Additional cache-related varaiables:
block_idx_last_scheduled_token
=
None
block_idx_last_computed_token
=
None
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
)
)
# Compute seq_idx for prefill only
if
num_prefills
>
0
:
# [batch,]
has_initial_states_cpu
=
(
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
)
prep_initial_states
=
torch
.
any
(
has_initial_states_cpu
).
item
()
has_initial_states_p
=
has_initial_states_cpu
.
to
(
common_attn_metadata
.
query_start_loc
.
device
)
query_start_loc_p
=
(
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
assert
num_computed_tokens
is
not
None
num_computed_tokens_p
=
num_computed_tokens
[
num_reqs
-
num_prefills
:
num_reqs
]
assert
block_idx_first_scheduled_token
is
not
None
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token
[
num_reqs
-
num_prefills
:
num_reqs
]
num_computed_tokens_p_cpu
=
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
query_start_loc_p_cpu
=
(
common_attn_metadata
.
query_start_loc_cpu
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
# The code below carefully constructs the chunks such that:
# 1. Chunks contain tokens from a *single* sequence only.
# 2. For every sequence, we are guaranteed that we can
# retrieve the mamba state *every* chunk_size tokens.
# Constraint (1) dramatically simplifies the mamba2 kernels.
# Constraint (2) dramatically simplifies the implementation
# of prefix caching for mamba2 (wip). We need to take care
# of the interaction with chunked prefill in order to
# satisfy constraint (2).
num_prefills
:
int
,
num_computed_tokens_p_cpu
:
torch
.
Tensor
,
query_start_loc_p_cpu
:
torch
.
Tensor
,
)
->
tuple
[
list
[
int
],
list
[
int
],
list
[
int
]]:
"""
Compute chunk-specific metadata for Mamba2.
The code below carefully constructs the chunks such that:
1. Chunks contain tokens from a *single* sequence only.
2. For every sequence, we are guaranteed that we can
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba2 kernels.
Constraint (2) dramatically simplifies the implementation
of prefix caching for mamba2 (wip). We need to take care
of the interaction with chunked prefill in order to
satisfy constraint (2).
"""
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen
=
[]
seq_idx
=
[]
last_chunk_indices
=
[]
seqlen_pos
=
0
for
req_idx
in
range
(
num_prefills
):
this_num_computed
=
num_computed_tokens_p_cpu
[
req_idx
].
item
()
this_new_tokens
=
(
...
...
@@ -288,88 +188,68 @@ class Mamba2AttentionMetadataBuilder(
cu_chunk_seqlen
.
append
(
seqlen_pos
)
seq_idx_p
=
torch
.
as_tensor
(
seq_idx
,
device
=
query_start_loc_p
.
device
,
dtype
=
torch
.
int32
)
cu_chunk_seqlen_p
=
torch
.
as_tensor
(
cu_chunk_seqlen
,
device
=
query_start_loc_p
.
device
,
dtype
=
torch
.
int32
)
last_chunk_indices_p
=
torch
.
as_tensor
(
last_chunk_indices
,
device
=
query_start_loc_p
.
device
,
dtype
=
torch
.
int32
)
return
cu_chunk_seqlen
,
seq_idx
,
last_chunk_indices
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
(
compute_causal_conv1d_metadata
(
query_start_loc_p
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
Mamba2AttentionMetadata
:
common
=
self
.
_compute_common_metadata
(
common_attn_metadata
)
seq_idx_p
=
None
cu_chunk_seqlen_p
=
None
last_chunk_indices_p
=
None
prep_initial_states
=
False
# Compute seq_idx for prefill only
if
common
.
num_prefills
>
0
:
prep_initial_states
=
(
torch
.
any
(
common
.
has_initial_states_p
).
item
()
if
common
.
has_initial_states_p
is
not
None
else
False
)
elif
(
num_decodes
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
state_indices_tensor
,
non_blocking
=
True
num_reqs
=
common
.
num_reqs
num_prefills
=
common
.
num_prefills
num_decode_tokens
=
common
.
num_decode_tokens
num_computed_tokens_p_cpu
=
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
query_start_loc_p_cpu
=
(
common_attn_metadata
.
query_start_loc_cpu
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_decode_tokens
]
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
self
.
block_idx_last_scheduled_token
[:
num_decodes
].
copy_
(
block_idx_last_scheduled_token
,
non_blocking
=
True
cu_chunk_seqlen
,
seq_idx
,
last_chunk_indices
=
self
.
_compute_chunk_metadata
(
num_prefills
,
num_computed_tokens_p_cpu
,
query_start_loc_p_cpu
,
)
block_idx_last_scheduled_token
=
self
.
block_idx_last_scheduled_token
[
:
num_decode_tokens
]
self
.
block_idx_last_computed_token
[:
num_decodes
].
copy_
(
block_idx_last_computed_token
,
non_blocking
=
True
seq_idx_p
=
torch
.
as_tensor
(
seq_idx
,
device
=
common_attn_metadata
.
query_start_loc
.
device
,
dtype
=
torch
.
int32
,
)
cu_chunk_seqlen_p
=
torch
.
as_tensor
(
cu_chunk_seqlen
,
device
=
common_attn_metadata
.
query_start_loc
.
device
,
dtype
=
torch
.
int32
,
)
last_chunk_indices_p
=
torch
.
as_tensor
(
last_chunk_indices
,
device
=
common_attn_metadata
.
query_start_loc
.
device
,
dtype
=
torch
.
int32
,
)
block_idx_last_computed_token
=
self
.
block_idx_last_computed_token
[
:
num_decode_tokens
]
attn_metadata
=
Mamba2AttentionMetadata
(
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
query_start_loc_p
=
query_start_loc_p
,
seq_lens
=
seq_lens
,
return
replace
(
common
,
prep_initial_states
=
prep_initial_states
,
chunk_size
=
self
.
chunk_size
,
has_initial_states_p
=
has_initial_states_p
,
seq_idx_p
=
seq_idx_p
,
state_indices_tensor
=
state_indices_tensor
,
cu_chunk_seqlen_p
=
cu_chunk_seqlen_p
,
last_chunk_indices_p
=
last_chunk_indices_p
,
nums_dict
=
nums_dict
,
batch_ptr
=
batch_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
block_idx_last_scheduled_token
=
block_idx_last_scheduled_token
,
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token_p
,
block_idx_last_computed_token
=
block_idx_last_computed_token
,
num_computed_tokens_p
=
num_computed_tokens_p
,
)
return
attn_metadata
def
update_block_table
(
self
,
metadata
:
Mamba2AttentionMetadata
,
blk_table
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
Mamba2AttentionMetadata
:
new_metadata
=
copy
.
copy
(
metadata
)
prefix_caching
=
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
state_indices_t
=
blk_table
if
prefix_caching
else
blk_table
[:,
0
]
num_reqs
=
blk_table
.
shape
[
0
]
# For CUDA graphs, copy to persistent buffer
if
(
metadata
.
num_prefills
==
0
and
num_reqs
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
persistent_state_indices_t
=
self
.
state_indices_tensor
[:
num_reqs
]
persistent_state_indices_t
.
copy_
(
state_indices_t
,
non_blocking
=
True
)
state_indices_t
=
persistent_state_indices_t
new_metadata
.
state_indices_tensor
=
state_indices_t
return
new_metadata
vllm/v1/attention/backends/mamba_attn.py
View file @
34916ae3
...
...
@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
abc
import
copy
from
dataclasses
import
dataclass
from
typing
import
ClassVar
,
TypeVar
import
torch
...
...
@@ -9,20 +11,52 @@ import torch
from
vllm.config
import
VllmConfig
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
(
PAD_SLOT_ID
,
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
compute_causal_conv1d_metadata
,
split_decodes_and_prefills
,
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
MambaSpec
M
=
TypeVar
(
"M"
)
M
=
TypeVar
(
"M"
,
bound
=
"BaseMambaAttentionMetadata"
)
@
dataclass
class
BaseMambaAttentionMetadata
:
num_prefills
:
int
num_prefill_tokens
:
int
num_decodes
:
int
num_decode_tokens
:
int
num_reqs
:
int
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill request.
has_initial_states_p
:
torch
.
Tensor
|
None
query_start_loc_p
:
torch
.
Tensor
|
None
num_computed_tokens_p
:
torch
.
Tensor
|
None
state_indices_tensor
:
torch
.
Tensor
# The following tensors are only used for prefix caching and are None if disabled
block_idx_last_scheduled_token
:
torch
.
Tensor
|
None
block_idx_first_scheduled_token_p
:
torch
.
Tensor
|
None
block_idx_last_computed_token
:
torch
.
Tensor
|
None
# The following attributes are for triton implementation of causal_conv1d
nums_dict
:
dict
|
None
=
None
batch_ptr
:
torch
.
Tensor
|
None
=
None
token_chunk_offset_ptr
:
torch
.
Tensor
|
None
=
None
class
BaseMambaAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
M
],
abc
.
ABC
):
metadata_cls
:
type
[
M
]
reorder_batch_threshold
:
int
=
1
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
(
AttentionCGSupport
.
UNIFORM_SINGLE_TOKEN_DECODE
)
supports_update_block_table
:
bool
=
True
def
__init__
(
self
,
...
...
@@ -87,6 +121,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
return
self
.
build
(
0
,
m
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
M
:
"""
Default build implementation for Mamba-like attention backends.
Subclasses (e.g., Mamba2) can override to add additional metadata.
"""
return
self
.
_compute_common_metadata
(
common_attn_metadata
)
def
_compute_prefix_caching_block_indices
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
...
...
@@ -115,3 +161,147 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
)
def
_compute_common_metadata
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
)
->
M
:
"""
Compute metadata common to both Mamba1 and Mamba2.
"""
num_reqs
=
common_attn_metadata
.
num_reqs
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
)
)
# Need flags to indicate if there are initial states
has_initial_states_p
=
None
query_start_loc_p
=
None
num_computed_tokens
=
None
num_computed_tokens_p
=
None
# for prefix caching
block_idx_first_scheduled_token
=
None
block_idx_first_scheduled_token_p
=
None
block_idx_last_computed_token
=
None
block_idx_last_scheduled_token
=
None
# for causal_conv1d
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
# Additional cache-related varaiables:
mamba_block_size
=
self
.
kv_cache_spec
.
block_size
num_computed_tokens
=
common_attn_metadata
.
num_computed_tokens_cpu
.
to
(
self
.
device
)
(
block_idx_last_computed_token
,
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
)
=
self
.
_compute_prefix_caching_block_indices
(
common_attn_metadata
,
mamba_block_size
)
else
:
# Always return just a single block per each request:
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
if
num_prefills
>
0
:
query_start_loc_p
=
(
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
has_initial_states_cpu
=
(
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
)
has_initial_states_p
=
has_initial_states_cpu
.
to
(
common_attn_metadata
.
query_start_loc
.
device
)
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
(
compute_causal_conv1d_metadata
(
query_start_loc_p
)
)
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
assert
num_computed_tokens
is
not
None
num_computed_tokens_p
=
num_computed_tokens
[
num_reqs
-
num_prefills
:
num_reqs
]
assert
block_idx_first_scheduled_token
is
not
None
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token
[
num_reqs
-
num_prefills
:
num_reqs
]
elif
(
num_decodes
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
state_indices_tensor
,
non_blocking
=
True
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_decode_tokens
]
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
self
.
block_idx_last_scheduled_token
[:
num_decodes
].
copy_
(
block_idx_last_scheduled_token
,
non_blocking
=
True
)
block_idx_last_scheduled_token
=
self
.
block_idx_last_scheduled_token
[
:
num_decode_tokens
]
self
.
block_idx_last_computed_token
[:
num_decodes
].
copy_
(
block_idx_last_computed_token
,
non_blocking
=
True
)
block_idx_last_computed_token
=
self
.
block_idx_last_computed_token
[
:
num_decode_tokens
]
return
self
.
metadata_cls
(
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
query_start_loc_p
=
query_start_loc_p
,
has_initial_states_p
=
has_initial_states_p
,
state_indices_tensor
=
state_indices_tensor
,
block_idx_last_scheduled_token
=
block_idx_last_scheduled_token
,
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token_p
,
block_idx_last_computed_token
=
block_idx_last_computed_token
,
num_computed_tokens_p
=
num_computed_tokens_p
,
num_reqs
=
num_reqs
,
nums_dict
=
nums_dict
,
batch_ptr
=
batch_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
)
def
update_block_table
(
self
,
metadata
:
M
,
blk_table
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
M
:
new_metadata
=
copy
.
copy
(
metadata
)
prefix_caching
=
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
state_indices_t
=
blk_table
if
prefix_caching
else
blk_table
[:,
0
]
num_reqs
=
blk_table
.
shape
[
0
]
# For CUDA graphs, copy to persistent buffer
if
(
metadata
.
num_prefills
==
0
and
num_reqs
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
persistent_state_indices_t
=
self
.
state_indices_tensor
[:
num_reqs
]
persistent_state_indices_t
.
copy_
(
state_indices_t
,
non_blocking
=
True
)
state_indices_t
=
persistent_state_indices_t
new_metadata
.
state_indices_tensor
=
state_indices_t
return
new_metadata
vllm/v1/attention/backends/short_conv_attn.py
View file @
34916ae3
...
...
@@ -2,15 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.v1.attention.backends.mamba_attn
import
BaseMambaAttentionMetadataBuilder
from
vllm.v1.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionMetadata
,
compute_causal_conv1d_metadata
,
split_decodes_and_prefills
,
from
vllm.v1.attention.backends.mamba_attn
import
(
BaseMambaAttentionMetadata
,
BaseMambaAttentionMetadataBuilder
,
)
...
...
@@ -21,84 +16,11 @@ class ShortConvAttentionBackend(AttentionBackend):
@
dataclass
class
ShortConvAttentionMetadata
:
num_prefills
:
int
num_prefill_tokens
:
int
num_decodes
:
int
num_decode_tokens
:
int
query_start_loc
:
torch
.
Tensor
state_indices_tensor
:
torch
.
Tensor
has_initial_states_p
:
torch
.
Tensor
|
None
# For causal_conv1d
nums_dict
:
dict
|
None
=
None
batch_ptr
:
torch
.
Tensor
|
None
=
None
token_chunk_offset_ptr
:
torch
.
Tensor
|
None
=
None
class
ShortConvAttentionMetadata
(
BaseMambaAttentionMetadata
):
pass
class
ShortConvAttentionMetadataBuilder
(
BaseMambaAttentionMetadataBuilder
[
ShortConvAttentionMetadata
]
):
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
ShortConvAttentionMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
query_start_loc
=
common_attn_metadata
.
query_start_loc
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
# for causal_conv1d
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
)
)
has_initial_states_p
=
None
if
num_prefills
>
0
:
has_initial_states_cpu
=
(
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
)
has_initial_states_p
=
has_initial_states_cpu
.
to
(
query_start_loc
.
device
)
query_start_loc_p
=
(
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
(
compute_causal_conv1d_metadata
(
query_start_loc_p
)
)
elif
(
num_decodes
>
0
and
num_decodes
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
state_indices_tensor
,
non_blocking
=
True
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_decode_tokens
]
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
attn_metadata
=
ShortConvAttentionMetadata
(
query_start_loc
=
query_start_loc
,
state_indices_tensor
=
state_indices_tensor
,
has_initial_states_p
=
has_initial_states_p
,
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
nums_dict
=
nums_dict
,
batch_ptr
=
batch_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
)
return
attn_metadata
metadata_cls
=
ShortConvAttentionMetadata
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment