Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fceafaf5
Unverified
Commit
fceafaf5
authored
Aug 13, 2025
by
Chen Zhang
Committed by
GitHub
Aug 13, 2025
Browse files
[Bugfix][mamba] Fix type annotation of Mamba2Metadata (#22787)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
6b794c75
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
21 deletions
+26
-21
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+4
-4
vllm/v1/attention/backends/mamba_attn.py
vllm/v1/attention/backends/mamba_attn.py
+22
-17
No files found.
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
fceafaf5
...
...
@@ -473,12 +473,12 @@ class MambaMixer2(MambaBase, CustomOp):
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
has_initial_states_p
=
attn_metadata
.
has_initial_states
has_initial_states_p
=
attn_metadata
.
has_initial_states
_p
prep_initial_states
=
attn_metadata
.
prep_initial_states
chunk_size
=
attn_metadata
.
chunk_size
seq_idx_p
=
attn_metadata
.
seq_idx
chunk_indices_p
=
attn_metadata
.
chunk_indices
chunk_offsets_p
=
attn_metadata
.
chunk_offsets
seq_idx_p
=
attn_metadata
.
seq_idx
_p
chunk_indices_p
=
attn_metadata
.
chunk_indices
_p
chunk_offsets_p
=
attn_metadata
.
chunk_offsets
_p
else
:
conv_state
=
mamba_cache_params
.
conv_state
ssm_state
=
mamba_cache_params
.
ssm_state
...
...
vllm/v1/attention/backends/mamba_attn.py
View file @
fceafaf5
...
...
@@ -68,14 +68,19 @@ class Mamba2AttentionMetadata:
query_start_loc
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
has_initial_states
:
torch
.
Tensor
prep_initial_states
:
bool
chunk_size
:
int
seq_idx
:
torch
.
Tensor
chunk_indices
:
torch
.
Tensor
chunk_offsets
:
torch
.
Tensor
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill request.
has_initial_states_p
:
Optional
[
torch
.
Tensor
]
seq_idx_p
:
Optional
[
torch
.
Tensor
]
chunk_indices_p
:
Optional
[
torch
.
Tensor
]
chunk_offsets_p
:
Optional
[
torch
.
Tensor
]
state_indices_tensor
:
torch
.
Tensor
# shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict
:
Optional
[
dict
]
=
None
cu_seqlen
:
Optional
[
int
]
=
None
batch_ptr
:
Optional
[
torch
.
tensor
]
=
None
...
...
@@ -115,11 +120,11 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
seq_idx
=
None
chunk_indices
,
chunk_offsets
=
None
,
None
seq_idx
_p
=
None
chunk_indices
_p
,
chunk_offsets
_p
=
None
,
None
# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states
=
None
has_initial_states
_p
=
None
prep_initial_states
=
False
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
...
...
@@ -135,25 +140,25 @@ class Mamba2AttentionMetadataBuilder(
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
)
prep_initial_states
=
torch
.
any
(
has_initial_states_cpu
).
item
()
has_initial_states
=
has_initial_states_cpu
.
to
(
has_initial_states
_p
=
has_initial_states_cpu
.
to
(
query_start_loc
.
device
)
query_start_loc_p
=
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decode_tokens
seq_idx
=
torch
.
repeat_interleave
(
torch
.
arange
(
seq_idx
_p
=
torch
.
repeat_interleave
(
torch
.
arange
(
num_prefills
,
dtype
=
torch
.
int32
,
device
=
query_start_loc_p
.
device
),
query_start_loc_p
.
diff
(),
output_size
=
num_prefill_tokens
)
seq_idx
.
unsqueeze_
(
0
)
seq_idx
_p
.
unsqueeze_
(
0
)
# We compute metadata for chunked prefill once at the top level
# model forward and reuse them in mamba layers. If not needed,
# they will be ignored inside mamba kernels.
if
prep_initial_states
:
chunk_indices
,
chunk_offsets
=
(
chunk_indices
_p
,
chunk_offsets
_p
=
(
_query_start_loc_to_chunk_indices_offsets
(
query_start_loc_p
,
self
.
chunk_size
,
num_prefill_tokens
))
...
...
@@ -173,12 +178,12 @@ class Mamba2AttentionMetadataBuilder(
num_decode_tokens
=
num_decode_tokens
,
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
has_initial_states
=
has_initial_states
,
prep_initial_states
=
prep_initial_states
,
chunk_size
=
self
.
chunk_size
,
seq_idx
=
seq_idx
,
chunk_indices
=
chunk_indices
,
chunk_offsets
=
chunk_offsets
,
has_initial_states_p
=
has_initial_states_p
,
seq_idx_p
=
seq_idx_p
,
chunk_indices_p
=
chunk_indices_p
,
chunk_offsets_p
=
chunk_offsets_p
,
state_indices_tensor
=
state_indices_tensor
,
)
return
attn_metadata
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment