Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
34916ae3
"vscode:/vscode.git/clone" did not exist on "87ef4618428fe2c8f756a80c271857fa6ae2623a"
Unverified
Commit
34916ae3
authored
Dec 23, 2025
by
Asaf Joseph Gardin
Committed by
GitHub
Dec 23, 2025
Browse files
[Mamba] - Consolidate Mambas Attention Logic (#28133)
parent
0736f901
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
305 additions
and
448 deletions
+305
-448
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+1
-5
vllm/v1/attention/backends/mamba1_attn.py
vllm/v1/attention/backends/mamba1_attn.py
+7
-138
vllm/v1/attention/backends/mamba2_attn.py
vllm/v1/attention/backends/mamba2_attn.py
+100
-220
vllm/v1/attention/backends/mamba_attn.py
vllm/v1/attention/backends/mamba_attn.py
+191
-1
vllm/v1/attention/backends/short_conv_attn.py
vllm/v1/attention/backends/short_conv_attn.py
+6
-84
No files found.
vllm/model_executor/layers/mamba/short_conv.py
View file @
34916ae3
...
@@ -118,6 +118,7 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -118,6 +118,7 @@ class ShortConv(MambaBase, CustomOp):
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
BCx
,
_
=
self
.
in_proj
(
hidden_states
)
BCx
,
_
=
self
.
in_proj
(
hidden_states
)
...
@@ -165,11 +166,6 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -165,11 +166,6 @@ class ShortConv(MambaBase, CustomOp):
[
num_decodes
,
num_prefills
],
[
num_decodes
,
num_prefills
],
dim
=
0
,
dim
=
0
,
)
)
query_start_loc_p
=
(
attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decodes
if
has_prefill
else
None
)
conv_output_list
=
[]
conv_output_list
=
[]
...
...
vllm/v1/attention/backends/mamba1_attn.py
View file @
34916ae3
...
@@ -3,17 +3,11 @@
...
@@ -3,17 +3,11 @@
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.v1.attention.backends.mamba_attn
import
(
from
vllm.config
import
VllmConfig
BaseMambaAttentionMetadata
,
from
vllm.v1.attention.backends.mamba_attn
import
BaseMambaAttentionMetadataBuilder
BaseMambaAttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
split_decodes_and_prefills
,
)
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
MambaSpec
class
Mamba1AttentionBackend
(
AttentionBackend
):
class
Mamba1AttentionBackend
(
AttentionBackend
):
...
@@ -23,137 +17,12 @@ class Mamba1AttentionBackend(AttentionBackend):
...
@@ -23,137 +17,12 @@ class Mamba1AttentionBackend(AttentionBackend):
@
dataclass
@
dataclass
class
Mamba1AttentionMetadata
:
class
Mamba1AttentionMetadata
(
BaseMambaAttentionMetadata
):
query_start_loc_p
:
torch
.
Tensor
pass
state_indices_tensor
:
torch
.
Tensor
has_initial_states_p
:
torch
.
Tensor
|
None
num_prefills
:
int
num_prefill_tokens
:
int
num_decodes
:
int
num_decode_tokens
:
int
block_idx_last_scheduled_token
:
torch
.
Tensor
# shape: [batch,]
block_idx_first_scheduled_token_p
:
torch
.
Tensor
# shape: [batch,]
block_idx_last_computed_token
:
torch
.
Tensor
# shape: [batch,]
num_computed_tokens_p
:
torch
.
Tensor
# shape: [batch,]
class
Mamba1AttentionMetadataBuilder
(
class
Mamba1AttentionMetadataBuilder
(
BaseMambaAttentionMetadataBuilder
[
Mamba1AttentionMetadata
]
BaseMambaAttentionMetadataBuilder
[
Mamba1AttentionMetadata
]
):
):
def
__init__
(
metadata_cls
=
Mamba1AttentionMetadata
self
,
supports_update_block_table
:
bool
=
False
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
):
super
().
__init__
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
assert
isinstance
(
kv_cache_spec
,
MambaSpec
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
Mamba1AttentionMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
)
)
has_initial_states_p
=
None
query_start_loc_p
=
None
num_computed_tokens
,
num_computed_tokens_p
=
None
,
None
block_idx_first_scheduled_token
=
None
block_idx_first_scheduled_token_p
=
None
# TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
# We should consolidate this code
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
mamba_block_size
=
self
.
kv_cache_spec
.
block_size
num_computed_tokens
=
common_attn_metadata
.
num_computed_tokens_cpu
.
to
(
self
.
device
)
(
block_idx_last_computed_token
,
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
)
=
self
.
_compute_prefix_caching_block_indices
(
common_attn_metadata
,
mamba_block_size
)
else
:
# Always return just a single block per each request:
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
block_idx_last_scheduled_token
=
None
block_idx_last_computed_token
=
None
if
num_prefills
>
0
:
query_start_loc_p
=
(
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
has_initial_states_cpu
=
(
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
)
has_initial_states_p
=
has_initial_states_cpu
.
to
(
common_attn_metadata
.
query_start_loc
.
device
)
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
assert
num_computed_tokens
is
not
None
num_computed_tokens_p
=
num_computed_tokens
[
num_reqs
-
num_prefills
:
num_reqs
]
assert
block_idx_first_scheduled_token
is
not
None
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token
[
num_reqs
-
num_prefills
:
num_reqs
]
elif
(
num_decodes
>
0
and
num_decodes
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
state_indices_tensor
,
non_blocking
=
True
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_decode_tokens
]
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
self
.
block_idx_last_scheduled_token
[:
num_decodes
].
copy_
(
block_idx_last_scheduled_token
,
non_blocking
=
True
)
block_idx_last_scheduled_token
=
self
.
block_idx_last_scheduled_token
[
:
num_decode_tokens
]
self
.
block_idx_last_computed_token
[:
num_decodes
].
copy_
(
block_idx_last_computed_token
,
non_blocking
=
True
)
block_idx_last_computed_token
=
self
.
block_idx_last_computed_token
[
:
num_decode_tokens
]
return
Mamba1AttentionMetadata
(
query_start_loc_p
=
query_start_loc_p
,
has_initial_states_p
=
has_initial_states_p
,
state_indices_tensor
=
state_indices_tensor
,
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
block_idx_last_scheduled_token
=
block_idx_last_scheduled_token
,
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token_p
,
block_idx_last_computed_token
=
block_idx_last_computed_token
,
num_computed_tokens_p
=
num_computed_tokens_p
,
)
vllm/v1/attention/backends/mamba2_attn.py
View file @
34916ae3
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
itertools
import
itertools
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
replace
import
torch
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backends.mamba_attn
import
BaseMambaAttentionMetadataBuilder
from
vllm.v1.attention.backends.mamba_attn
import
(
BaseMambaAttentionMetadata
,
BaseMambaAttentionMetadataBuilder
,
)
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
CommonAttentionMetadata
,
compute_causal_conv1d_metadata
,
split_decodes_and_prefills
,
)
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
@@ -94,48 +94,26 @@ class Mamba2AttentionBackend(AttentionBackend):
...
@@ -94,48 +94,26 @@ class Mamba2AttentionBackend(AttentionBackend):
@
dataclass
@
dataclass
class
Mamba2AttentionMetadata
:
class
Mamba2AttentionMetadata
(
BaseMambaAttentionMetadata
):
num_prefills
:
int
prep_initial_states
:
bool
=
False
num_prefill_tokens
:
int
chunk_size
:
int
=
0
num_decodes
:
int
num_decode_tokens
:
int
query_start_loc_p
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
prep_initial_states
:
bool
chunk_size
:
int
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill request.
has_initial_states_p
:
torch
.
Tensor
|
None
seq_idx_p
:
torch
.
Tensor
|
None
# Chunk-related metadata (only for prefill)
seq_idx_p
:
torch
.
Tensor
|
None
=
None
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offests into the varlen sequence dimension. It is defined
# each chunk, its offests into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p
:
torch
.
Tensor
|
None
cu_chunk_seqlen_p
:
torch
.
Tensor
|
None
=
None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p
:
torch
.
Tensor
|
None
last_chunk_indices_p
:
torch
.
Tensor
|
None
=
None
state_indices_tensor
:
torch
.
Tensor
# shape: [batch,]
block_idx_last_scheduled_token
:
torch
.
Tensor
# shape: [batch,]
block_idx_first_scheduled_token_p
:
torch
.
Tensor
# shape: [batch,]
block_idx_last_computed_token
:
torch
.
Tensor
# shape: [batch,]
num_computed_tokens_p
:
torch
.
Tensor
# shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict
:
dict
|
None
=
None
batch_ptr
:
torch
.
Tensor
|
None
=
None
token_chunk_offset_ptr
:
torch
.
Tensor
|
None
=
None
class
Mamba2AttentionMetadataBuilder
(
class
Mamba2AttentionMetadataBuilder
(
BaseMambaAttentionMetadataBuilder
[
Mamba2AttentionMetadata
]
BaseMambaAttentionMetadataBuilder
[
Mamba2AttentionMetadata
]
):
):
supports_update_block_table
:
bool
=
True
metadata_cls
=
Mamba2AttentionMetadata
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -150,109 +128,31 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -150,109 +128,31 @@ class Mamba2AttentionMetadataBuilder(
"chunk_size needs to be set in the model config for Mamba2 models"
"chunk_size needs to be set in the model config for Mamba2 models"
)
)
def
build
(
def
_compute_chunk_metadata
(
self
,
self
,
common_prefix_len
:
int
,
num_prefills
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
num_computed_tokens_p_cpu
:
torch
.
Tensor
,
fast_build
:
bool
=
False
,
query_start_loc_p_cpu
:
torch
.
Tensor
,
)
->
Mamba2AttentionMetadata
:
)
->
tuple
[
list
[
int
],
list
[
int
],
list
[
int
]]:
num_reqs
=
common_attn_metadata
.
num_reqs
"""
seq_lens
=
common_attn_metadata
.
seq_lens
Compute chunk-specific metadata for Mamba2.
query_start_loc_p
=
None
The code below carefully constructs the chunks such that:
seq_idx_p
=
None
1. Chunks contain tokens from a *single* sequence only.
cu_chunk_seqlen_p
=
None
2. For every sequence, we are guaranteed that we can
last_chunk_indices_p
=
None
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba2 kernels.
# Need flags to indicate if there are initial states
Constraint (2) dramatically simplifies the implementation
has_initial_states_p
=
None
of prefix caching for mamba2 (wip). We need to take care
prep_initial_states
=
False
of the interaction with chunked prefill in order to
satisfy constraint (2).
# for causal_conv1d
"""
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
num_computed_tokens
,
num_computed_tokens_p
=
None
,
None
block_idx_first_scheduled_token
=
None
block_idx_first_scheduled_token_p
=
None
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
# Additional cache-related varaiables:
mamba_block_size
=
self
.
kv_cache_spec
.
block_size
num_computed_tokens
=
common_attn_metadata
.
num_computed_tokens_cpu
.
to
(
self
.
device
)
(
block_idx_last_computed_token
,
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
)
=
self
.
_compute_prefix_caching_block_indices
(
common_attn_metadata
,
mamba_block_size
)
else
:
# Always return just a single block per each request:
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
# Additional cache-related varaiables:
block_idx_last_scheduled_token
=
None
block_idx_last_computed_token
=
None
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
)
)
# Compute seq_idx for prefill only
if
num_prefills
>
0
:
# [batch,]
has_initial_states_cpu
=
(
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
)
prep_initial_states
=
torch
.
any
(
has_initial_states_cpu
).
item
()
has_initial_states_p
=
has_initial_states_cpu
.
to
(
common_attn_metadata
.
query_start_loc
.
device
)
query_start_loc_p
=
(
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
assert
num_computed_tokens
is
not
None
num_computed_tokens_p
=
num_computed_tokens
[
num_reqs
-
num_prefills
:
num_reqs
]
assert
block_idx_first_scheduled_token
is
not
None
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token
[
num_reqs
-
num_prefills
:
num_reqs
]
num_computed_tokens_p_cpu
=
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
query_start_loc_p_cpu
=
(
common_attn_metadata
.
query_start_loc_cpu
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
# The code below carefully constructs the chunks such that:
# 1. Chunks contain tokens from a *single* sequence only.
# 2. For every sequence, we are guaranteed that we can
# retrieve the mamba state *every* chunk_size tokens.
# Constraint (1) dramatically simplifies the mamba2 kernels.
# Constraint (2) dramatically simplifies the implementation
# of prefix caching for mamba2 (wip). We need to take care
# of the interaction with chunked prefill in order to
# satisfy constraint (2).
# TODO (tdoublep): This code could probably be optimized.
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen
=
[]
cu_chunk_seqlen
=
[]
seq_idx
=
[]
seq_idx
=
[]
last_chunk_indices
=
[]
last_chunk_indices
=
[]
seqlen_pos
=
0
seqlen_pos
=
0
for
req_idx
in
range
(
num_prefills
):
for
req_idx
in
range
(
num_prefills
):
this_num_computed
=
num_computed_tokens_p_cpu
[
req_idx
].
item
()
this_num_computed
=
num_computed_tokens_p_cpu
[
req_idx
].
item
()
this_new_tokens
=
(
this_new_tokens
=
(
...
@@ -288,88 +188,68 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -288,88 +188,68 @@ class Mamba2AttentionMetadataBuilder(
cu_chunk_seqlen
.
append
(
seqlen_pos
)
cu_chunk_seqlen
.
append
(
seqlen_pos
)
seq_idx_p
=
torch
.
as_tensor
(
return
cu_chunk_seqlen
,
seq_idx
,
last_chunk_indices
seq_idx
,
device
=
query_start_loc_p
.
device
,
dtype
=
torch
.
int32
)
cu_chunk_seqlen_p
=
torch
.
as_tensor
(
cu_chunk_seqlen
,
device
=
query_start_loc_p
.
device
,
dtype
=
torch
.
int32
)
last_chunk_indices_p
=
torch
.
as_tensor
(
last_chunk_indices
,
device
=
query_start_loc_p
.
device
,
dtype
=
torch
.
int32
)
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
(
def
build
(
compute_causal_conv1d_metadata
(
query_start_loc_p
)
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
Mamba2AttentionMetadata
:
common
=
self
.
_compute_common_metadata
(
common_attn_metadata
)
seq_idx_p
=
None
cu_chunk_seqlen_p
=
None
last_chunk_indices_p
=
None
prep_initial_states
=
False
# Compute seq_idx for prefill only
if
common
.
num_prefills
>
0
:
prep_initial_states
=
(
torch
.
any
(
common
.
has_initial_states_p
).
item
()
if
common
.
has_initial_states_p
is
not
None
else
False
)
)
elif
(
num_reqs
=
common
.
num_reqs
num_decodes
<=
self
.
decode_cudagraph_max_bs
num_prefills
=
common
.
num_prefills
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
num_decode_tokens
=
common
.
num_decode_tokens
):
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
num_computed_tokens_p_cpu
=
common_attn_metadata
.
num_computed_tokens_cpu
[
state_indices_tensor
,
non_blocking
=
True
num_reqs
-
num_prefills
:
num_reqs
]
query_start_loc_p_cpu
=
(
common_attn_metadata
.
query_start_loc_cpu
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_decode_tokens
]
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
cu_chunk_seqlen
,
seq_idx
,
last_chunk_indices
=
self
.
_compute_chunk_metadata
(
self
.
block_idx_last_scheduled_token
[:
num_decodes
].
copy_
(
num_prefills
,
block_idx_last_scheduled_token
,
non_blocking
=
True
num_computed_tokens_p_cpu
,
query_start_loc_p_cpu
,
)
)
block_idx_last_scheduled_token
=
self
.
block_idx_last_scheduled_token
[
:
num_decode_tokens
]
self
.
block_idx_last_computed_token
[:
num_decodes
].
copy_
(
seq_idx_p
=
torch
.
as_tensor
(
block_idx_last_computed_token
,
non_blocking
=
True
seq_idx
,
device
=
common_attn_metadata
.
query_start_loc
.
device
,
dtype
=
torch
.
int32
,
)
cu_chunk_seqlen_p
=
torch
.
as_tensor
(
cu_chunk_seqlen
,
device
=
common_attn_metadata
.
query_start_loc
.
device
,
dtype
=
torch
.
int32
,
)
last_chunk_indices_p
=
torch
.
as_tensor
(
last_chunk_indices
,
device
=
common_attn_metadata
.
query_start_loc
.
device
,
dtype
=
torch
.
int32
,
)
)
block_idx_last_computed_token
=
self
.
block_idx_last_computed_token
[
:
num_decode_tokens
]
attn_metadata
=
Mamba2AttentionMetadata
(
return
replace
(
num_prefills
=
num_prefills
,
common
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
query_start_loc_p
=
query_start_loc_p
,
seq_lens
=
seq_lens
,
prep_initial_states
=
prep_initial_states
,
prep_initial_states
=
prep_initial_states
,
chunk_size
=
self
.
chunk_size
,
chunk_size
=
self
.
chunk_size
,
has_initial_states_p
=
has_initial_states_p
,
seq_idx_p
=
seq_idx_p
,
seq_idx_p
=
seq_idx_p
,
state_indices_tensor
=
state_indices_tensor
,
cu_chunk_seqlen_p
=
cu_chunk_seqlen_p
,
cu_chunk_seqlen_p
=
cu_chunk_seqlen_p
,
last_chunk_indices_p
=
last_chunk_indices_p
,
last_chunk_indices_p
=
last_chunk_indices_p
,
nums_dict
=
nums_dict
,
batch_ptr
=
batch_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
block_idx_last_scheduled_token
=
block_idx_last_scheduled_token
,
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token_p
,
block_idx_last_computed_token
=
block_idx_last_computed_token
,
num_computed_tokens_p
=
num_computed_tokens_p
,
)
)
return
attn_metadata
def
update_block_table
(
self
,
metadata
:
Mamba2AttentionMetadata
,
blk_table
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
Mamba2AttentionMetadata
:
new_metadata
=
copy
.
copy
(
metadata
)
prefix_caching
=
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
state_indices_t
=
blk_table
if
prefix_caching
else
blk_table
[:,
0
]
num_reqs
=
blk_table
.
shape
[
0
]
# For CUDA graphs, copy to persistent buffer
if
(
metadata
.
num_prefills
==
0
and
num_reqs
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
persistent_state_indices_t
=
self
.
state_indices_tensor
[:
num_reqs
]
persistent_state_indices_t
.
copy_
(
state_indices_t
,
non_blocking
=
True
)
state_indices_t
=
persistent_state_indices_t
new_metadata
.
state_indices_tensor
=
state_indices_t
return
new_metadata
vllm/v1/attention/backends/mamba_attn.py
View file @
34916ae3
...
@@ -2,6 +2,8 @@
...
@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
abc
import
abc
import
copy
from
dataclasses
import
dataclass
from
typing
import
ClassVar
,
TypeVar
from
typing
import
ClassVar
,
TypeVar
import
torch
import
torch
...
@@ -9,20 +11,52 @@ import torch
...
@@ -9,20 +11,52 @@ import torch
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
PAD_SLOT_ID
,
AttentionCGSupport
,
AttentionCGSupport
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
compute_causal_conv1d_metadata
,
split_decodes_and_prefills
,
)
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
MambaSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
MambaSpec
M
=
TypeVar
(
"M"
)
M
=
TypeVar
(
"M"
,
bound
=
"BaseMambaAttentionMetadata"
)
@
dataclass
class
BaseMambaAttentionMetadata
:
num_prefills
:
int
num_prefill_tokens
:
int
num_decodes
:
int
num_decode_tokens
:
int
num_reqs
:
int
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill request.
has_initial_states_p
:
torch
.
Tensor
|
None
query_start_loc_p
:
torch
.
Tensor
|
None
num_computed_tokens_p
:
torch
.
Tensor
|
None
state_indices_tensor
:
torch
.
Tensor
# The following tensors are only used for prefix caching and are None if disabled
block_idx_last_scheduled_token
:
torch
.
Tensor
|
None
block_idx_first_scheduled_token_p
:
torch
.
Tensor
|
None
block_idx_last_computed_token
:
torch
.
Tensor
|
None
# The following attributes are for triton implementation of causal_conv1d
nums_dict
:
dict
|
None
=
None
batch_ptr
:
torch
.
Tensor
|
None
=
None
token_chunk_offset_ptr
:
torch
.
Tensor
|
None
=
None
class
BaseMambaAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
M
],
abc
.
ABC
):
class
BaseMambaAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
M
],
abc
.
ABC
):
metadata_cls
:
type
[
M
]
reorder_batch_threshold
:
int
=
1
reorder_batch_threshold
:
int
=
1
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
(
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
(
AttentionCGSupport
.
UNIFORM_SINGLE_TOKEN_DECODE
AttentionCGSupport
.
UNIFORM_SINGLE_TOKEN_DECODE
)
)
supports_update_block_table
:
bool
=
True
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -87,6 +121,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
...
@@ -87,6 +121,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
return
self
.
build
(
0
,
m
)
return
self
.
build
(
0
,
m
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
M
:
"""
Default build implementation for Mamba-like attention backends.
Subclasses (e.g., Mamba2) can override to add additional metadata.
"""
return
self
.
_compute_common_metadata
(
common_attn_metadata
)
def
_compute_prefix_caching_block_indices
(
def
_compute_prefix_caching_block_indices
(
self
,
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
...
@@ -115,3 +161,147 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
...
@@ -115,3 +161,147 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
block_idx_first_scheduled_token
,
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
block_idx_last_scheduled_token
,
)
)
def
_compute_common_metadata
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
)
->
M
:
"""
Compute metadata common to both Mamba1 and Mamba2.
"""
num_reqs
=
common_attn_metadata
.
num_reqs
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
)
)
# Need flags to indicate if there are initial states
has_initial_states_p
=
None
query_start_loc_p
=
None
num_computed_tokens
=
None
num_computed_tokens_p
=
None
# for prefix caching
block_idx_first_scheduled_token
=
None
block_idx_first_scheduled_token_p
=
None
block_idx_last_computed_token
=
None
block_idx_last_scheduled_token
=
None
# for causal_conv1d
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
# Additional cache-related varaiables:
mamba_block_size
=
self
.
kv_cache_spec
.
block_size
num_computed_tokens
=
common_attn_metadata
.
num_computed_tokens_cpu
.
to
(
self
.
device
)
(
block_idx_last_computed_token
,
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
)
=
self
.
_compute_prefix_caching_block_indices
(
common_attn_metadata
,
mamba_block_size
)
else
:
# Always return just a single block per each request:
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
if
num_prefills
>
0
:
query_start_loc_p
=
(
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
has_initial_states_cpu
=
(
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
)
has_initial_states_p
=
has_initial_states_cpu
.
to
(
common_attn_metadata
.
query_start_loc
.
device
)
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
(
compute_causal_conv1d_metadata
(
query_start_loc_p
)
)
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
assert
num_computed_tokens
is
not
None
num_computed_tokens_p
=
num_computed_tokens
[
num_reqs
-
num_prefills
:
num_reqs
]
assert
block_idx_first_scheduled_token
is
not
None
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token
[
num_reqs
-
num_prefills
:
num_reqs
]
elif
(
num_decodes
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
state_indices_tensor
,
non_blocking
=
True
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_decode_tokens
]
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
self
.
block_idx_last_scheduled_token
[:
num_decodes
].
copy_
(
block_idx_last_scheduled_token
,
non_blocking
=
True
)
block_idx_last_scheduled_token
=
self
.
block_idx_last_scheduled_token
[
:
num_decode_tokens
]
self
.
block_idx_last_computed_token
[:
num_decodes
].
copy_
(
block_idx_last_computed_token
,
non_blocking
=
True
)
block_idx_last_computed_token
=
self
.
block_idx_last_computed_token
[
:
num_decode_tokens
]
return
self
.
metadata_cls
(
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
query_start_loc_p
=
query_start_loc_p
,
has_initial_states_p
=
has_initial_states_p
,
state_indices_tensor
=
state_indices_tensor
,
block_idx_last_scheduled_token
=
block_idx_last_scheduled_token
,
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token_p
,
block_idx_last_computed_token
=
block_idx_last_computed_token
,
num_computed_tokens_p
=
num_computed_tokens_p
,
num_reqs
=
num_reqs
,
nums_dict
=
nums_dict
,
batch_ptr
=
batch_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
)
def
update_block_table
(
self
,
metadata
:
M
,
blk_table
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
M
:
new_metadata
=
copy
.
copy
(
metadata
)
prefix_caching
=
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
state_indices_t
=
blk_table
if
prefix_caching
else
blk_table
[:,
0
]
num_reqs
=
blk_table
.
shape
[
0
]
# For CUDA graphs, copy to persistent buffer
if
(
metadata
.
num_prefills
==
0
and
num_reqs
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
persistent_state_indices_t
=
self
.
state_indices_tensor
[:
num_reqs
]
persistent_state_indices_t
.
copy_
(
state_indices_t
,
non_blocking
=
True
)
state_indices_t
=
persistent_state_indices_t
new_metadata
.
state_indices_tensor
=
state_indices_t
return
new_metadata
vllm/v1/attention/backends/short_conv_attn.py
View file @
34916ae3
...
@@ -2,15 +2,10 @@
...
@@ -2,15 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.v1.attention.backends.mamba_attn
import
BaseMambaAttentionMetadataBuilder
from
vllm.v1.attention.backends.mamba_attn
import
(
from
vllm.v1.attention.backends.utils
import
(
BaseMambaAttentionMetadata
,
PAD_SLOT_ID
,
BaseMambaAttentionMetadataBuilder
,
CommonAttentionMetadata
,
compute_causal_conv1d_metadata
,
split_decodes_and_prefills
,
)
)
...
@@ -21,84 +16,11 @@ class ShortConvAttentionBackend(AttentionBackend):
...
@@ -21,84 +16,11 @@ class ShortConvAttentionBackend(AttentionBackend):
@
dataclass
@
dataclass
class
ShortConvAttentionMetadata
:
class
ShortConvAttentionMetadata
(
BaseMambaAttentionMetadata
):
num_prefills
:
int
pass
num_prefill_tokens
:
int
num_decodes
:
int
num_decode_tokens
:
int
query_start_loc
:
torch
.
Tensor
state_indices_tensor
:
torch
.
Tensor
has_initial_states_p
:
torch
.
Tensor
|
None
# For causal_conv1d
nums_dict
:
dict
|
None
=
None
batch_ptr
:
torch
.
Tensor
|
None
=
None
token_chunk_offset_ptr
:
torch
.
Tensor
|
None
=
None
class
ShortConvAttentionMetadataBuilder
(
class
ShortConvAttentionMetadataBuilder
(
BaseMambaAttentionMetadataBuilder
[
ShortConvAttentionMetadata
]
BaseMambaAttentionMetadataBuilder
[
ShortConvAttentionMetadata
]
):
):
def
build
(
metadata_cls
=
ShortConvAttentionMetadata
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
ShortConvAttentionMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
query_start_loc
=
common_attn_metadata
.
query_start_loc
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
# for causal_conv1d
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
)
)
has_initial_states_p
=
None
if
num_prefills
>
0
:
has_initial_states_cpu
=
(
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
)
has_initial_states_p
=
has_initial_states_cpu
.
to
(
query_start_loc
.
device
)
query_start_loc_p
=
(
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
(
compute_causal_conv1d_metadata
(
query_start_loc_p
)
)
elif
(
num_decodes
>
0
and
num_decodes
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
state_indices_tensor
,
non_blocking
=
True
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_decode_tokens
]
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
attn_metadata
=
ShortConvAttentionMetadata
(
query_start_loc
=
query_start_loc
,
state_indices_tensor
=
state_indices_tensor
,
has_initial_states_p
=
has_initial_states_p
,
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
nums_dict
=
nums_dict
,
batch_ptr
=
batch_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
)
return
attn_metadata
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment