Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
072d7e53
Unverified
Commit
072d7e53
authored
Sep 18, 2025
by
Vadim Gimpelson
Committed by
GitHub
Sep 18, 2025
Browse files
[PERF] Add `conv1d` metadata to GDN attn (#25105)
Signed-off-by:
Vadim Gimpelson
<
vadim.gimpelson@gmail.com
>
parent
01a583fe
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
24 additions
and
8 deletions
+24
-8
vllm/model_executor/layers/mamba/mamba2_metadata.py
vllm/model_executor/layers/mamba/mamba2_metadata.py
+5
-3
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+9
-1
vllm/v1/attention/backends/gdn_attn.py
vllm/v1/attention/backends/gdn_attn.py
+6
-0
vllm/v1/attention/backends/mamba2_attn.py
vllm/v1/attention/backends/mamba2_attn.py
+2
-2
vllm/v1/attention/backends/short_conv_attn.py
vllm/v1/attention/backends/short_conv_attn.py
+2
-2
No files found.
vllm/model_executor/layers/mamba/mamba2_metadata.py
View file @
072d7e53
...
...
@@ -11,6 +11,7 @@ from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata
)
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadata
from
vllm.v1.attention.backends.mamba2_attn
import
(
Mamba2AttentionMetadata
,
_query_start_loc_to_chunk_indices_offsets
)
...
...
@@ -45,8 +46,8 @@ class Mamba2Metadata:
"""
nums_dict
:
Optional
[
dict
]
=
None
cu_seqlen
:
Optional
[
int
]
=
None
batch_ptr
:
Optional
[
torch
.
t
ensor
]
=
None
token_chunk_offset_ptr
:
Optional
[
torch
.
t
ensor
]
=
None
batch_ptr
:
Optional
[
torch
.
T
ensor
]
=
None
token_chunk_offset_ptr
:
Optional
[
torch
.
T
ensor
]
=
None
def
get_platform_metadata_classes
()
->
tuple
[
type
[
AttentionMetadata
],
...]:
...
...
@@ -117,7 +118,8 @@ def prepare_mamba2_metadata(
def
update_metadata
(
x
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
mamba2_metadata
:
Union
[
Mamba2Metadata
,
Mamba2AttentionMetadata
]):
Mamba2AttentionMetadata
,
GDNAttentionMetadata
]):
"""
this is triggered upon handling a new input at the first layer
"""
...
...
vllm/model_executor/models/qwen3_next.py
View file @
072d7e53
...
...
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
update_metadata
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
mamba_v2_sharded_weight_loader
)
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
...
...
@@ -414,6 +415,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
conv_metadata
=
attn_metadata
assert
isinstance
(
attn_metadata
,
GDNAttentionMetadata
)
has_initial_state
=
attn_metadata
.
has_initial_state
spec_query_start_loc
=
attn_metadata
.
spec_query_start_loc
...
...
@@ -475,10 +477,15 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# 2.2: process the remaining part
if
attn_metadata
.
num_prefills
>
0
:
mixed_qkv_non_spec_T
=
mixed_qkv_non_spec
.
transpose
(
0
,
1
)
if
conv_metadata
.
cu_seqlen
is
None
:
conv_metadata
=
update_metadata
(
mixed_qkv_non_spec_T
,
non_spec_query_start_loc
,
conv_metadata
)
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
mixed_qkv_non_spec
=
causal_conv1d_fn
(
mixed_qkv_non_spec
.
transpose
(
0
,
1
)
,
mixed_qkv_non_spec
_T
,
conv_weights
,
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
...
...
@@ -486,6 +493,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
has_initial_state
=
has_initial_state
,
cache_indices
=
non_spec_state_indices_tensor
,
query_start_loc
=
non_spec_query_start_loc
,
metadata
=
conv_metadata
,
).
transpose
(
0
,
1
)
elif
attn_metadata
.
num_decodes
>
0
:
mixed_qkv_non_spec
=
causal_conv1d_update
(
...
...
vllm/v1/attention/backends/gdn_attn.py
View file @
072d7e53
...
...
@@ -50,6 +50,12 @@ class GDNAttentionMetadata:
Tensor
]
=
None
# shape: [num_prefill_tokens + num_decode_tokens,]
num_accepted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
# shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict
:
Optional
[
dict
]
=
None
cu_seqlen
:
Optional
[
int
]
=
None
batch_ptr
:
Optional
[
torch
.
Tensor
]
=
None
token_chunk_offset_ptr
:
Optional
[
torch
.
Tensor
]
=
None
class
GDNAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
GDNAttentionMetadata
]):
...
...
vllm/v1/attention/backends/mamba2_attn.py
View file @
072d7e53
...
...
@@ -132,8 +132,8 @@ class Mamba2AttentionMetadata:
# The following attributes are for triton implementation of causal_conv1d
nums_dict
:
Optional
[
dict
]
=
None
cu_seqlen
:
Optional
[
int
]
=
None
batch_ptr
:
Optional
[
torch
.
t
ensor
]
=
None
token_chunk_offset_ptr
:
Optional
[
torch
.
t
ensor
]
=
None
batch_ptr
:
Optional
[
torch
.
T
ensor
]
=
None
token_chunk_offset_ptr
:
Optional
[
torch
.
T
ensor
]
=
None
class
Mamba2AttentionMetadataBuilder
(
...
...
vllm/v1/attention/backends/short_conv_attn.py
View file @
072d7e53
...
...
@@ -34,8 +34,8 @@ class ShortConvAttentionMetadata:
# For causal_conv1d
nums_dict
:
Optional
[
dict
]
=
None
cu_seqlen
:
Optional
[
int
]
=
None
batch_ptr
:
Optional
[
torch
.
t
ensor
]
=
None
token_chunk_offset_ptr
:
Optional
[
torch
.
t
ensor
]
=
None
batch_ptr
:
Optional
[
torch
.
T
ensor
]
=
None
token_chunk_offset_ptr
:
Optional
[
torch
.
T
ensor
]
=
None
class
ShortConvAttentionMetadataBuilder
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment