Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5c4b6e66
Unverified
Commit
5c4b6e66
authored
Aug 25, 2025
by
Ayush Satyam
Committed by
GitHub
Aug 25, 2025
Browse files
[Attention] Unify mamba and attention backend selection (#23171)
Signed-off-by:
Ayush Satyam
<
ayushsatyam146@gmail.com
>
parent
d0a4a3f6
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
186 additions
and
72 deletions
+186
-72
tests/v1/attention/test_attention_backends_selection.py
tests/v1/attention/test_attention_backends_selection.py
+104
-0
tests/v1/attention/test_mamba_selectors.py
tests/v1/attention/test_mamba_selectors.py
+0
-25
vllm/attention/layer.py
vllm/attention/layer.py
+2
-1
vllm/model_executor/layers/attention_layer_base.py
vllm/model_executor/layers/attention_layer_base.py
+23
-0
vllm/model_executor/layers/mamba/abstract.py
vllm/model_executor/layers/mamba/abstract.py
+13
-2
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+9
-1
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+9
-1
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+9
-1
vllm/model_executor/models/minimax_text_01.py
vllm/model_executor/models/minimax_text_01.py
+9
-1
vllm/v1/attention/backends/mamba_selectors.py
vllm/v1/attention/backends/mamba_selectors.py
+0
-22
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+8
-18
No files found.
tests/v1/attention/test_attention_backends_selection.py
0 → 100644
View file @
5c4b6e66
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""
from
types
import
SimpleNamespace
import
pytest
from
vllm.model_executor.layers.mamba.mamba_mixer
import
MambaMixer
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.short_conv
import
ShortConv
from
vllm.model_executor.models.minimax_text_01
import
(
MiniMaxText01LinearAttention
)
from
vllm.v1.attention.backends.linear_attn
import
LinearAttentionBackend
from
vllm.v1.attention.backends.mamba1_attn
import
Mamba1AttentionBackend
from
vllm.v1.attention.backends.mamba2_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.short_conv_attn
import
(
ShortConvAttentionBackend
)
@
pytest
.
mark
.
parametrize
(
"layer_class, init_kwargs, expected_backend, expected_mamba_type"
,
[
(
MambaMixer
,
dict
(
hidden_size
=
128
,
ssm_state_size
=
16
,
conv_kernel_size
=
4
,
intermediate_size
=
256
,
time_step_rank
=
8
,
use_conv_bias
=
True
,
use_bias
=
False
,
use_rms_norm
=
True
,
),
Mamba1AttentionBackend
,
"mamba1"
,
),
(
MambaMixer2
,
dict
(
hidden_size
=
128
,
ssm_state_size
=
16
,
conv_kernel_size
=
4
,
intermediate_size
=
256
,
use_conv_bias
=
True
,
use_bias
=
False
,
n_groups
=
1
,
num_heads
=
8
,
head_dim
=
32
,
),
Mamba2AttentionBackend
,
"mamba2"
,
),
(
MiniMaxText01LinearAttention
,
dict
(
hidden_size
=
128
,
hidden_inner_size
=
256
,
num_heads
=
8
,
head_dim
=
32
,
max_position
=
2048
,
block_size
=
64
,
num_hidden_layer
=
12
,
layer_idx
=
0
,
linear_layer_idx
=
0
,
),
LinearAttentionBackend
,
"linear_attention"
,
),
(
ShortConv
,
dict
(
config
=
SimpleNamespace
(
conv_L_cache
=
32
,
conv_bias
=
True
),
dim
=
128
,
layer_idx
=
0
,
),
ShortConvAttentionBackend
,
"short_conv"
,
),
])
def
test_mamba_layers_get_attn_backend
(
dist_init
,
layer_class
,
init_kwargs
,
expected_backend
,
expected_mamba_type
):
"""Test that Mamba-like layers return the correct attention backend."""
layer
=
layer_class
(
**
init_kwargs
)
backend_class
=
layer
.
get_attn_backend
()
assert
backend_class
is
expected_backend
assert
layer
.
mamba_type
==
expected_mamba_type
@
pytest
.
mark
.
parametrize
(
"layer_class,expected_backend,expected_mamba_type"
,
[
(
MambaMixer
,
Mamba1AttentionBackend
,
"mamba1"
),
(
MambaMixer2
,
Mamba2AttentionBackend
,
"mamba2"
),
(
MiniMaxText01LinearAttention
,
LinearAttentionBackend
,
"linear_attention"
),
(
ShortConv
,
ShortConvAttentionBackend
,
"short_conv"
),
])
def
test_mamba_layers_have_unified_interface
(
layer_class
,
expected_backend
,
expected_mamba_type
):
"""Test that all Mamba layers have the unified get_attn_backend
interface."""
assert
hasattr
(
layer_class
,
'get_attn_backend'
),
(
f
"
{
layer_class
.
__name__
}
should have get_attn_backend method"
)
assert
hasattr
(
layer_class
,
'mamba_type'
),
(
f
"
{
layer_class
.
__name__
}
should have mamba_type property"
)
tests/v1/attention/test_mamba_selectors.py
deleted
100644 → 0
View file @
d0a4a3f6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""
import
pytest
from
vllm.v1.attention.backends.mamba2_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.mamba_selectors
import
get_mamba_attn_backend
@
pytest
.
mark
.
parametrize
(
argnames
=
[
"mamba_type"
,
"expected_backend"
],
argvalues
=
[(
"mamba2"
,
Mamba2AttentionBackend
)])
def
test_get_mamba_attn_backend_mamba2
(
mamba_type
,
expected_backend
):
backend_class
=
get_mamba_attn_backend
(
mamba_type
)
assert
backend_class
is
expected_backend
def
test_get_mamba_attn_backend_unsupported
():
unsupported_types
=
[
"mamba"
,
""
]
for
mamba_type
in
unsupported_types
:
err_message
=
f
"Mamba Attention type
{
mamba_type
}
is not supported yet."
with
pytest
.
raises
(
NotImplementedError
,
match
=
err_message
):
get_mamba_attn_backend
(
mamba_type
)
vllm/attention/layer.py
View file @
5c4b6e66
...
...
@@ -18,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group
)
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
...
...
@@ -54,7 +55,7 @@ def check_xformers_availability():
return
USE_XFORMERS_OPS
class
Attention
(
nn
.
Module
):
class
Attention
(
nn
.
Module
,
AttentionLayerBase
):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
...
...
vllm/model_executor/layers/attention_layer_base.py
0 → 100644
View file @
5c4b6e66
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for attention-like layers."""
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
class
AttentionLayerBase
(
ABC
):
"""
Base class for attention-like layers (Attention, Mamba, etc.)
that support the v1 engine.
This provides a common interface for getting attention backends
from different layer types.
"""
@
abstractmethod
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
"""Get the attention backend class for this layer."""
pass
vllm/model_executor/layers/mamba/abstract.py
View file @
5c4b6e66
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
abc
import
abstractmethod
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
import
torch
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
class
MambaBase
(
ABC
):
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
class
MambaBase
(
AttentionLayerBase
):
"""
Base class for Mamba-like layers which support the v1 engine.
Inherit from this class if you implement a custom layer.
...
...
@@ -32,3 +38,8 @@ class MambaBase(ABC):
@
abstractmethod
def
mamba_type
(
self
)
->
str
:
pass
@
abstractmethod
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
"""Get the attention backend class for this Mamba layer."""
pass
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
5c4b6e66
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
NamedTuple
,
Optional
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
Optional
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
import
torch
from
torch
import
nn
...
...
@@ -404,6 +407,11 @@ class MambaMixer(MambaBase, CustomOp):
def
mamba_type
(
self
)
->
str
:
return
"mamba1"
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
from
vllm.v1.attention.backends.mamba1_attn
import
(
Mamba1AttentionBackend
)
return
Mamba1AttentionBackend
def
_time_proj_bias
(
self
)
->
Optional
[
torch
.
Tensor
]:
if
hasattr
(
self
.
dt_proj
,
"bias"
)
and
self
.
dt_proj
.
bias
is
not
None
:
return
self
.
dt_proj
.
bias
.
float
()
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
5c4b6e66
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
import
torch
from
torch
import
nn
...
...
@@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp):
def
mamba_type
(
self
)
->
str
:
return
"mamba2"
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
from
vllm.v1.attention.backends.mamba2_attn
import
(
Mamba2AttentionBackend
)
return
Mamba2AttentionBackend
def
mamba_mixer2
(
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/mamba/short_conv.py
View file @
5c4b6e66
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
typing
import
TYPE_CHECKING
,
Optional
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
import
torch
...
...
@@ -232,6 +235,11 @@ class ShortConv(MambaBase, CustomOp):
def
mamba_type
(
self
)
->
str
:
return
"short_conv"
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
from
vllm.v1.attention.backends.short_conv_attn
import
(
ShortConvAttentionBackend
)
return
ShortConvAttentionBackend
def
short_conv
(
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/minimax_text_01.py
View file @
5c4b6e66
...
...
@@ -4,7 +4,10 @@
import
copy
import
math
from
collections.abc
import
Iterable
from
typing
import
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
import
regex
as
re
import
torch
...
...
@@ -339,6 +342,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def
mamba_type
(
self
)
->
str
:
return
"linear_attention"
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
from
vllm.v1.attention.backends.linear_attn
import
(
LinearAttentionBackend
)
return
LinearAttentionBackend
def
get_state_dtype
(
self
)
->
tuple
[
torch
.
dtype
]:
return
MambaStateDtypeCalculator
.
linear_attention_state_dtype
(
self
.
model_config
.
dtype
,
...
...
vllm/v1/attention/backends/mamba_selectors.py
deleted
100644 → 0
View file @
d0a4a3f6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.v1.attention.backends.linear_attn
import
LinearAttentionBackend
from
vllm.v1.attention.backends.mamba1_attn
import
Mamba1AttentionBackend
from
vllm.v1.attention.backends.mamba2_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.short_conv_attn
import
(
ShortConvAttentionBackend
)
def
get_mamba_attn_backend
(
mamba_type
:
str
)
->
type
[
AttentionBackend
]:
if
mamba_type
==
"mamba1"
:
return
Mamba1AttentionBackend
if
mamba_type
==
"mamba2"
:
return
Mamba2AttentionBackend
if
mamba_type
==
"linear_attention"
:
return
LinearAttentionBackend
if
mamba_type
==
"short_conv"
:
return
ShortConvAttentionBackend
raise
NotImplementedError
(
f
"Mamba Attention type
{
mamba_type
}
is not "
"supported yet."
)
vllm/v1/worker/gpu_model_runner.py
View file @
5c4b6e66
...
...
@@ -35,7 +35,8 @@ from vllm.distributed.parallel_state import (
from
vllm.forward_context
import
(
BatchDescriptor
,
DPMetadata
,
set_forward_context
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.model_executor.models.interfaces
import
(
is_mixture_of_experts
,
...
...
@@ -55,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes
,
LazyLoader
,
cdiv
,
check_use_alibi
,
get_dtype_size
,
is_pin_memory_available
,
round_up
,
supports_dynamo
)
from
vllm.v1.attention.backends.mamba_selectors
import
get_mamba_attn_backend
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
make_kv_sharing_fast_prefill_attention_metadata
,
...
...
@@ -2747,11 +2747,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"""
assert
len
(
self
.
attn_groups
)
==
0
,
\
"Attention backends are already initialized"
attn_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
def
get_attn_backends_for_layers
(
layer_names
:
list
[
str
]
)
->
dict
[
type
[
AttentionBackend
],
list
[
str
]]:
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
,
layer_names
)
attn_backends
=
{}
attn_backend_layers
=
defaultdict
(
list
)
# Dedupe based on full class name; this is a bit safer than using
...
...
@@ -2760,7 +2762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# they are cached correctly, there will be different objects per
# layer.
for
layer_name
in
layer_names
:
attn_backend
=
attn_
layers
[
layer_name
].
get_attn_backend
()
attn_backend
=
layers
[
layer_name
].
get_attn_backend
()
key
=
attn_backend
.
full_cls_name
()
attn_backends
[
key
]
=
attn_backend
attn_backend_layers
[
key
].
append
(
layer_name
)
...
...
@@ -2789,20 +2791,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for
kv_cache_group_spec
in
kv_cache_config
.
kv_cache_groups
:
kv_cache_spec
=
kv_cache_group_spec
.
kv_cache_spec
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
attn_backends
=
get_attn_backends_for_layers
(
kv_cache_group_spec
.
layer_names
)
# TODO(lucas): move `get_mamba_attn_backend` into the mamba
# layers like above
elif
isinstance
(
kv_cache_spec
,
MambaSpec
):
attn_backends
=
{
get_mamba_attn_backend
(
kv_cache_spec
.
mamba_type
):
kv_cache_group_spec
.
layer_names
}
else
:
raise
ValueError
(
f
"Unknown KV cache spec type:
{
type
(
kv_cache_spec
)
}
"
)
attn_backends
=
get_attn_backends_for_layers
(
kv_cache_group_spec
.
layer_names
)
self
.
attn_groups
.
append
(
create_attn_groups
(
attn_backends
,
kv_cache_spec
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment