Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5f42fc53
Unverified
Commit
5f42fc53
authored
Oct 03, 2025
by
Paul Pak
Committed by
GitHub
Oct 03, 2025
Browse files
[backends][short_conv] CUDA graph piecewise edits (#24215)
Signed-off-by:
Paul Pak
<
paulpak58@gmail.com
>
parent
8ee846c2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
21 deletions
+21
-21
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+1
-1
vllm/v1/attention/backends/short_conv_attn.py
vllm/v1/attention/backends/short_conv_attn.py
+20
-20
No files found.
vllm/model_executor/layers/mamba/short_conv.py
View file @
5f42fc53
...
@@ -115,7 +115,7 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -115,7 +115,7 @@ class ShortConv(MambaBase, CustomOp):
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
has_initial_states_p
=
attn_metadata
.
has_initial_states
has_initial_states_p
=
attn_metadata
.
has_initial_states
_p
BCx
,
_
=
self
.
in_proj
(
hidden_states
)
BCx
,
_
=
self
.
in_proj
(
hidden_states
)
...
...
vllm/v1/attention/backends/short_conv_attn.py
View file @
5f42fc53
...
@@ -6,12 +6,12 @@ from typing import Optional
...
@@ -6,12 +6,12 @@ from typing import Optional
import
torch
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.config
import
VllmConfig
from
vllm.v1.attention.backends.mamba_attn
import
(
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
BaseMambaAttentionMetadataBuilder
)
from
vllm.v1.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
compute_causal_conv1d_metadata
,
compute_causal_conv1d_metadata
,
split_decodes_and_prefills
)
split_decodes_and_prefills
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
MambaSpec
class
ShortConvAttentionBackend
(
AttentionBackend
):
class
ShortConvAttentionBackend
(
AttentionBackend
):
...
@@ -29,8 +29,8 @@ class ShortConvAttentionMetadata:
...
@@ -29,8 +29,8 @@ class ShortConvAttentionMetadata:
num_decode_tokens
:
int
num_decode_tokens
:
int
query_start_loc
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
has_initial_states
:
torch
.
Tensor
state_indices_tensor
:
torch
.
Tensor
state_indices_tensor
:
torch
.
Tensor
# shape: [batch,
]
has_initial_states_p
:
Optional
[
torch
.
Tensor
]
# For causal_conv1d
# For causal_conv1d
nums_dict
:
Optional
[
dict
]
=
None
nums_dict
:
Optional
[
dict
]
=
None
...
@@ -39,14 +39,7 @@ class ShortConvAttentionMetadata:
...
@@ -39,14 +39,7 @@ class ShortConvAttentionMetadata:
class
ShortConvAttentionMetadataBuilder
(
class
ShortConvAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
ShortConvAttentionMetadata
]):
BaseMambaAttentionMetadataBuilder
[
ShortConvAttentionMetadata
]):
reorder_batch_threshold
:
int
=
1
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
super
().
__init__
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
assert
isinstance
(
kv_cache_spec
,
MambaSpec
)
def
build
(
self
,
def
build
(
self
,
common_prefix_len
:
int
,
common_prefix_len
:
int
,
...
@@ -54,7 +47,6 @@ class ShortConvAttentionMetadataBuilder(
...
@@ -54,7 +47,6 @@ class ShortConvAttentionMetadataBuilder(
fast_build
:
bool
=
False
)
->
ShortConvAttentionMetadata
:
fast_build
:
bool
=
False
)
->
ShortConvAttentionMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_reqs
=
common_attn_metadata
.
num_reqs
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc
=
common_attn_metadata
.
query_start_loc
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
state_indices_tensor
=
common_attn_metadata
.
block_table_tensor
[:,
0
]
# for causal_conv1d
# for causal_conv1d
...
@@ -64,13 +56,13 @@ class ShortConvAttentionMetadataBuilder(
...
@@ -64,13 +56,13 @@ class ShortConvAttentionMetadataBuilder(
split_decodes_and_prefills
(
split_decodes_and_prefills
(
common_attn_metadata
,
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
))
decode_threshold
=
self
.
reorder_batch_threshold
))
has_initial_states
=
None
has_initial_states_p
=
None
if
num_prefills
>
0
:
if
num_prefills
>
0
:
#[batch,]
has_initial_states_cpu
=
(
has_initial_states_cpu
=
(
common_attn_metadata
.
common_attn_metadata
.
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
)
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
)
has_initial_states
=
has_initial_states_cpu
.
to
(
has_initial_states
_p
=
has_initial_states_cpu
.
to
(
query_start_loc
.
device
)
query_start_loc
.
device
)
query_start_loc_p
=
common_attn_metadata
.
query_start_loc
[
query_start_loc_p
=
common_attn_metadata
.
query_start_loc
[
...
@@ -79,14 +71,22 @@ class ShortConvAttentionMetadataBuilder(
...
@@ -79,14 +71,22 @@ class ShortConvAttentionMetadataBuilder(
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
\
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
\
compute_causal_conv1d_metadata
(
query_start_loc_p
)
compute_causal_conv1d_metadata
(
query_start_loc_p
)
elif
(
num_decodes
>
0
and
num_decodes
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
full_cuda_graph
):
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_decodes
)
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
state_indices_tensor
,
non_blocking
=
True
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_input_tokens
]
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
attn_metadata
=
ShortConvAttentionMetadata
(
attn_metadata
=
ShortConvAttentionMetadata
(
query_start_loc
=
query_start_loc
,
state_indices_tensor
=
state_indices_tensor
,
has_initial_states_p
=
has_initial_states_p
,
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decodes
=
num_decodes
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
query_start_loc
=
query_start_loc
,
has_initial_states
=
has_initial_states
,
state_indices_tensor
=
state_indices_tensor
,
nums_dict
=
nums_dict
,
nums_dict
=
nums_dict
,
batch_ptr
=
batch_ptr
,
batch_ptr
=
batch_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
token_chunk_offset_ptr
=
token_chunk_offset_ptr
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment