Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1d93f116
Unverified
Commit
1d93f116
authored
Dec 02, 2025
by
Matthew Bonanni
Committed by
GitHub
Dec 02, 2025
Browse files
[Attention][CUDAGraph] Remove CG padding from attention backends (#29352)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
2d613de9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
20 additions
and
46 deletions
+20
-46
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+8
-9
vllm/v1/attention/backends/gdn_attn.py
vllm/v1/attention/backends/gdn_attn.py
+5
-17
vllm/v1/attention/backends/mamba1_attn.py
vllm/v1/attention/backends/mamba1_attn.py
+3
-9
vllm/v1/attention/backends/mamba2_attn.py
vllm/v1/attention/backends/mamba2_attn.py
+3
-9
vllm/v1/attention/backends/short_conv_attn.py
vllm/v1/attention/backends/short_conv_attn.py
+1
-2
No files found.
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
1d93f116
...
...
@@ -252,7 +252,6 @@ class MambaMixer(MambaBase, CustomOp):
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
num_padded_decodes
=
attn_metadata
.
num_padded_decodes
# 1. Gated MLP's linear projection
projected_states
=
self
.
in_proj
(
hidden_states
)[
0
].
transpose
(
-
2
,
-
1
)
...
...
@@ -281,7 +280,7 @@ class MambaMixer(MambaBase, CustomOp):
state_indices_tensor
,
num_prefill_tokens
,
num_prefills
,
num_
padded_
decodes
,
num_decode
_token
s
,
)
hidden_states_BC_p
=
prefill_decode_split
.
hidden_states_BC_p
hidden_states_BC_d
=
prefill_decode_split
.
hidden_states_BC_d
...
...
@@ -470,24 +469,24 @@ def split_batch_to_prefill_and_decode(
state_indices_tensor
:
torch
.
Tensor
,
num_prefill_tokens
:
int
,
num_prefills
:
int
,
num_
padded_
decodes
:
int
,
num_decode
_token
s
:
int
,
)
->
PrefillDecodeSplit
:
num_actual_tokens
=
num_prefill_tokens
+
num_
padded_
decodes
num_actual_tokens
=
num_prefill_tokens
+
num_decode
_token
s
# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d
,
hidden_states_BC_p
=
torch
.
split
(
hidden_states_BC
[...,
:
num_actual_tokens
],
[
num_
padded_
decodes
,
num_prefill_tokens
],
[
num_decode
_token
s
,
num_prefill_tokens
],
dim
=-
1
,
)
gate_d
,
gate_p
=
torch
.
split
(
gate
[...,
:
num_actual_tokens
],
[
num_
padded_
decodes
,
num_prefill_tokens
],
dim
=-
1
gate
[...,
:
num_actual_tokens
],
[
num_decode
_token
s
,
num_prefill_tokens
],
dim
=-
1
)
# num_
padded_
decodes accounts for CUDA graph padding when applicable
# num_decode
_token
s accounts for CUDA graph padding when applicable
state_indices_tensor_d
,
state_indices_tensor_p
=
torch
.
split
(
state_indices_tensor
[:
num_
padded_
decodes
+
num_prefills
],
[
num_
padded_
decodes
,
num_prefills
],
state_indices_tensor
[:
num_decode
_token
s
+
num_prefills
],
[
num_decode
_token
s
,
num_prefills
],
dim
=
0
,
)
...
...
vllm/v1/attention/backends/gdn_attn.py
View file @
1d93f116
...
...
@@ -254,17 +254,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
)
else
:
has_initial_state
=
None
num_actual_tokens
=
(
num_prefill_tokens
+
num_decode_tokens
+
num_spec_decode_tokens
)
# prepare tensors for cudagraph
#
# With speculative decoding, the xgrammar backend may rollback tokens
# and causing some sequences has less draft tokens than self.num_spec.
#
# In above cases, the max possible batch size for n tokens, can be
# min(n, cudagraph_max_bs).
# Prepare tensors for cudagraph
# Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
batch_size
=
m
.
num_actual_tokens
if
(
self
.
use_full_cuda_graph
and
num_prefills
==
0
...
...
@@ -272,9 +266,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
and
num_spec_decodes
<=
self
.
decode_cudagraph_max_bs
and
num_spec_decode_tokens
<=
self
.
decode_cudagraph_max_bs
):
num_actual_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
m
.
num_actual_tokens
)
batch_size
=
min
(
self
.
decode_cudagraph_max_bs
,
num_actual_tokens
)
self
.
spec_state_indices_tensor
[:
num_spec_decodes
].
copy_
(
spec_state_indices_tensor
,
non_blocking
=
True
)
...
...
@@ -319,9 +310,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
and
num_spec_decodes
==
0
and
num_decodes
<=
self
.
decode_cudagraph_max_bs
):
num_actual_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
m
.
num_actual_tokens
)
batch_size
=
num_actual_tokens
self
.
non_spec_state_indices_tensor
[:
num_decodes
].
copy_
(
non_spec_state_indices_tensor
,
non_blocking
=
True
)
...
...
@@ -344,7 +332,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
num_decode_tokens
=
num_decode_tokens
,
num_spec_decodes
=
num_spec_decodes
,
num_spec_decode_tokens
=
num_spec_decode_tokens
,
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
m
.
num_actual_tokens
,
has_initial_state
=
has_initial_state
,
spec_query_start_loc
=
spec_query_start_loc
,
non_spec_query_start_loc
=
non_spec_query_start_loc
,
...
...
vllm/v1/attention/backends/mamba1_attn.py
View file @
1d93f116
...
...
@@ -31,7 +31,6 @@ class Mamba1AttentionMetadata:
num_prefill_tokens
:
int
num_decodes
:
int
num_decode_tokens
:
int
num_padded_decodes
:
int
block_idx_last_scheduled_token
:
torch
.
Tensor
# shape: [batch,]
block_idx_first_scheduled_token_p
:
torch
.
Tensor
# shape: [batch,]
...
...
@@ -68,7 +67,6 @@ class Mamba1AttentionMetadataBuilder(
has_initial_states_p
=
None
query_start_loc_p
=
None
padded_decodes
=
num_decodes
num_computed_tokens
,
num_computed_tokens_p
=
None
,
None
block_idx_first_scheduled_token
=
None
block_idx_first_scheduled_token_p
=
None
...
...
@@ -125,11 +123,10 @@ class Mamba1AttentionMetadataBuilder(
and
num_decodes
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
padded_decodes
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_decodes
)
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
state_indices_tensor
,
non_blocking
=
True
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
padded
_decodes
]
state_indices_tensor
=
self
.
state_indices_tensor
[:
num
_decode
_token
s
]
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
...
...
@@ -137,17 +134,15 @@ class Mamba1AttentionMetadataBuilder(
block_idx_last_scheduled_token
,
non_blocking
=
True
)
block_idx_last_scheduled_token
=
self
.
block_idx_last_scheduled_token
[
:
padded
_decodes
:
num
_decode
_token
s
]
block_idx_last_scheduled_token
[
num_decodes
:]
=
0
self
.
block_idx_last_computed_token
[:
num_decodes
].
copy_
(
block_idx_last_computed_token
,
non_blocking
=
True
)
block_idx_last_computed_token
=
self
.
block_idx_last_computed_token
[
:
padded
_decodes
:
num
_decode
_token
s
]
block_idx_last_computed_token
[
num_decodes
:]
=
0
return
Mamba1AttentionMetadata
(
query_start_loc_p
=
query_start_loc_p
,
...
...
@@ -157,7 +152,6 @@ class Mamba1AttentionMetadataBuilder(
num_prefill_tokens
=
num_prefill_tokens
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
num_padded_decodes
=
padded_decodes
,
block_idx_last_scheduled_token
=
block_idx_last_scheduled_token
,
block_idx_first_scheduled_token_p
=
block_idx_first_scheduled_token_p
,
block_idx_last_computed_token
=
block_idx_last_computed_token
,
...
...
vllm/v1/attention/backends/mamba2_attn.py
View file @
1d93f116
...
...
@@ -10,7 +10,6 @@ from vllm.config import VllmConfig
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backends.mamba_attn
import
BaseMambaAttentionMetadataBuilder
from
vllm.v1.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionMetadata
,
compute_causal_conv1d_metadata
,
split_decodes_and_prefills
,
...
...
@@ -304,30 +303,25 @@ class Mamba2AttentionMetadataBuilder(
num_decodes
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
# Pad state tensor for CUDA graph
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_decodes
)
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
state_indices_tensor
,
non_blocking
=
True
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_input_tokens
]
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_decode_tokens
]
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
self
.
block_idx_last_scheduled_token
[:
num_decodes
].
copy_
(
block_idx_last_scheduled_token
,
non_blocking
=
True
)
block_idx_last_scheduled_token
=
self
.
block_idx_last_scheduled_token
[
:
num_
input
_tokens
:
num_
decode
_tokens
]
block_idx_last_scheduled_token
[
num_decodes
:]
=
0
self
.
block_idx_last_computed_token
[:
num_decodes
].
copy_
(
block_idx_last_computed_token
,
non_blocking
=
True
)
block_idx_last_computed_token
=
self
.
block_idx_last_computed_token
[
:
num_
input
_tokens
:
num_
decode
_tokens
]
block_idx_last_computed_token
[
num_decodes
:]
=
0
attn_metadata
=
Mamba2AttentionMetadata
(
num_prefills
=
num_prefills
,
...
...
vllm/v1/attention/backends/short_conv_attn.py
View file @
1d93f116
...
...
@@ -83,11 +83,10 @@ class ShortConvAttentionMetadataBuilder(
and
num_decodes
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_decodes
)
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
state_indices_tensor
,
non_blocking
=
True
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_
input
_tokens
]
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_
decode
_tokens
]
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
attn_metadata
=
ShortConvAttentionMetadata
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment