Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5c4b6e66
Unverified
Commit
5c4b6e66
authored
Aug 25, 2025
by
Ayush Satyam
Committed by
GitHub
Aug 25, 2025
Browse files
[Attention] Unify mamba and attention backend selection (#23171)
Signed-off-by:
Ayush Satyam
<
ayushsatyam146@gmail.com
>
parent
d0a4a3f6
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
186 additions
and
72 deletions
+186
-72
tests/v1/attention/test_attention_backends_selection.py
tests/v1/attention/test_attention_backends_selection.py
+104
-0
tests/v1/attention/test_mamba_selectors.py
tests/v1/attention/test_mamba_selectors.py
+0
-25
vllm/attention/layer.py
vllm/attention/layer.py
+2
-1
vllm/model_executor/layers/attention_layer_base.py
vllm/model_executor/layers/attention_layer_base.py
+23
-0
vllm/model_executor/layers/mamba/abstract.py
vllm/model_executor/layers/mamba/abstract.py
+13
-2
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+9
-1
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+9
-1
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+9
-1
vllm/model_executor/models/minimax_text_01.py
vllm/model_executor/models/minimax_text_01.py
+9
-1
vllm/v1/attention/backends/mamba_selectors.py
vllm/v1/attention/backends/mamba_selectors.py
+0
-22
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+8
-18
No files found.
tests/v1/attention/test_attention_backends_selection.py
0 → 100644
View file @
5c4b6e66
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for attention backend selection on Mamba-like layers."""
from
types
import
SimpleNamespace
import
pytest
from
vllm.model_executor.layers.mamba.mamba_mixer
import
MambaMixer
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.short_conv
import
ShortConv
from
vllm.model_executor.models.minimax_text_01
import
(
MiniMaxText01LinearAttention
)
from
vllm.v1.attention.backends.linear_attn
import
LinearAttentionBackend
from
vllm.v1.attention.backends.mamba1_attn
import
Mamba1AttentionBackend
from
vllm.v1.attention.backends.mamba2_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.short_conv_attn
import
(
ShortConvAttentionBackend
)
@pytest.mark.parametrize(
    "layer_class, init_kwargs, expected_backend, expected_mamba_type",
    [
        (
            MambaMixer,
            {
                "hidden_size": 128,
                "ssm_state_size": 16,
                "conv_kernel_size": 4,
                "intermediate_size": 256,
                "time_step_rank": 8,
                "use_conv_bias": True,
                "use_bias": False,
                "use_rms_norm": True,
            },
            Mamba1AttentionBackend,
            "mamba1",
        ),
        (
            MambaMixer2,
            {
                "hidden_size": 128,
                "ssm_state_size": 16,
                "conv_kernel_size": 4,
                "intermediate_size": 256,
                "use_conv_bias": True,
                "use_bias": False,
                "n_groups": 1,
                "num_heads": 8,
                "head_dim": 32,
            },
            Mamba2AttentionBackend,
            "mamba2",
        ),
        (
            MiniMaxText01LinearAttention,
            {
                "hidden_size": 128,
                "hidden_inner_size": 256,
                "num_heads": 8,
                "head_dim": 32,
                "max_position": 2048,
                "block_size": 64,
                "num_hidden_layer": 12,
                "layer_idx": 0,
                "linear_layer_idx": 0,
            },
            LinearAttentionBackend,
            "linear_attention",
        ),
        (
            ShortConv,
            {
                "config": SimpleNamespace(conv_L_cache=32, conv_bias=True),
                "dim": 128,
                "layer_idx": 0,
            },
            ShortConvAttentionBackend,
            "short_conv",
        ),
    ],
)
def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs,
                                       expected_backend, expected_mamba_type):
    """Check each Mamba-like layer reports its own backend and type tag.

    Every case builds ``layer_class(**init_kwargs)`` and verifies that
    ``get_attn_backend()`` returns exactly ``expected_backend`` and that
    ``mamba_type`` equals ``expected_mamba_type``.

    NOTE(review): ``dist_init`` is only requested, never used directly;
    presumably it is a fixture that prepares the distributed state the
    layers need at construction time -- confirm against conftest.
    """
    instance = layer_class(**init_kwargs)

    # Identity (not equality): the layer must hand back the backend class
    # object itself.
    assert instance.get_attn_backend() is expected_backend
    assert instance.mamba_type == expected_mamba_type
@pytest.mark.parametrize(
    "layer_class,expected_backend,expected_mamba_type",
    [
        (MambaMixer, Mamba1AttentionBackend, "mamba1"),
        (MambaMixer2, Mamba2AttentionBackend, "mamba2"),
        (MiniMaxText01LinearAttention, LinearAttentionBackend,
         "linear_attention"),
        (ShortConv, ShortConvAttentionBackend, "short_conv"),
    ],
)
def test_mamba_layers_have_unified_interface(layer_class, expected_backend,
                                             expected_mamba_type):
    """Check the layer classes expose the shared selection interface.

    This is a class-level check only: no layer is instantiated, and
    ``expected_backend`` / ``expected_mamba_type`` are accepted from the
    parametrization but not inspected here.
    """
    exposes_backend_getter = hasattr(layer_class, 'get_attn_backend')
    assert exposes_backend_getter, (
        f"{layer_class.__name__} should have get_attn_backend method")

    exposes_mamba_type = hasattr(layer_class, 'mamba_type')
    assert exposes_mamba_type, (
        f"{layer_class.__name__} should have mamba_type property")
tests/v1/attention/test_mamba_selectors.py
deleted
100644 → 0
View file @
d0a4a3f6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""
import
pytest
from
vllm.v1.attention.backends.mamba2_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.mamba_selectors
import
get_mamba_attn_backend
@pytest.mark.parametrize(
    argnames=["mamba_type", "expected_backend"],
    argvalues=[("mamba2", Mamba2AttentionBackend)],
)
def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend):
    """Check the selector maps the "mamba2" tag to the Mamba2 backend class."""
    # Identity check: the selector must return the class object itself.
    assert get_mamba_attn_backend(mamba_type) is expected_backend
def test_get_mamba_attn_backend_unsupported():
    """Check unknown type tags are rejected with NotImplementedError.

    Both a non-registered name and the empty string are tried; each must
    raise with a message naming the offending tag.
    """
    for mamba_type in ("mamba", ""):
        # pytest.raises treats ``match`` as a regex searched in the message.
        expected_pattern = (
            f"Mamba Attention type {mamba_type} is not supported yet.")
        with pytest.raises(NotImplementedError, match=expected_pattern):
            get_mamba_attn_backend(mamba_type)
vllm/attention/layer.py
View file @
5c4b6e66
...
@@ -18,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
...
@@ -18,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group
)
is_v1_kv_transfer_group
)
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
...
@@ -54,7 +55,7 @@ def check_xformers_availability():
...
@@ -54,7 +55,7 @@ def check_xformers_availability():
return
USE_XFORMERS_OPS
return
USE_XFORMERS_OPS
class
Attention
(
nn
.
Module
):
class
Attention
(
nn
.
Module
,
AttentionLayerBase
):
"""Attention layer.
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
This class takes query, key, and value tensors as input. The input tensors
...
...
vllm/model_executor/layers/attention_layer_base.py
0 → 100644
View file @
5c4b6e66
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for attention-like layers."""
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
class AttentionLayerBase(ABC):
    """
    Common abstract interface for attention-like layers (Attention,
    Mamba, etc.) that support the v1 engine.

    Concrete layers implement ``get_attn_backend`` so that callers can ask
    any such layer for its backend class through one uniform method,
    regardless of the layer type.
    """

    @abstractmethod
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Return the attention backend class used by this layer.

        Subclasses must override this; the base class cannot be
        instantiated while it remains abstract.
        """
vllm/model_executor/layers/mamba/abstract.py
View file @
5c4b6e66
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
abc
import
abstractmethod
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
import
torch
import
torch
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
class
MambaBase
(
ABC
):
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
class
MambaBase
(
AttentionLayerBase
):
"""
"""
Base class for Mamba-like layers which support the v1 engine.
Base class for Mamba-like layers which support the v1 engine.
Inherit from this class if you implement a custom layer.
Inherit from this class if you implement a custom layer.
...
@@ -32,3 +38,8 @@ class MambaBase(ABC):
...
@@ -32,3 +38,8 @@ class MambaBase(ABC):
@
abstractmethod
@
abstractmethod
def
mamba_type
(
self
)
->
str
:
def
mamba_type
(
self
)
->
str
:
pass
pass
@
abstractmethod
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
"""Get the attention backend class for this Mamba layer."""
pass
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
5c4b6e66
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
NamedTuple
,
Optional
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
Optional
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -404,6 +407,11 @@ class MambaMixer(MambaBase, CustomOp):
...
@@ -404,6 +407,11 @@ class MambaMixer(MambaBase, CustomOp):
def
mamba_type
(
self
)
->
str
:
def
mamba_type
(
self
)
->
str
:
return
"mamba1"
return
"mamba1"
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
from
vllm.v1.attention.backends.mamba1_attn
import
(
Mamba1AttentionBackend
)
return
Mamba1AttentionBackend
def
_time_proj_bias
(
self
)
->
Optional
[
torch
.
Tensor
]:
def
_time_proj_bias
(
self
)
->
Optional
[
torch
.
Tensor
]:
if
hasattr
(
self
.
dt_proj
,
"bias"
)
and
self
.
dt_proj
.
bias
is
not
None
:
if
hasattr
(
self
.
dt_proj
,
"bias"
)
and
self
.
dt_proj
.
bias
is
not
None
:
return
self
.
dt_proj
.
bias
.
float
()
return
self
.
dt_proj
.
bias
.
float
()
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
5c4b6e66
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp):
def
mamba_type
(
self
)
->
str
:
def
mamba_type
(
self
)
->
str
:
return
"mamba2"
return
"mamba2"
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
from
vllm.v1.attention.backends.mamba2_attn
import
(
Mamba2AttentionBackend
)
return
Mamba2AttentionBackend
def
mamba_mixer2
(
def
mamba_mixer2
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/mamba/short_conv.py
View file @
5c4b6e66
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
typing
import
TYPE_CHECKING
,
Optional
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
import
torch
import
torch
...
@@ -232,6 +235,11 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -232,6 +235,11 @@ class ShortConv(MambaBase, CustomOp):
def
mamba_type
(
self
)
->
str
:
def
mamba_type
(
self
)
->
str
:
return
"short_conv"
return
"short_conv"
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
from
vllm.v1.attention.backends.short_conv_attn
import
(
ShortConvAttentionBackend
)
return
ShortConvAttentionBackend
def
short_conv
(
def
short_conv
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/minimax_text_01.py
View file @
5c4b6e66
...
@@ -4,7 +4,10 @@
...
@@ -4,7 +4,10 @@
import
copy
import
copy
import
math
import
math
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
typing
import
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
import
regex
as
re
import
regex
as
re
import
torch
import
torch
...
@@ -339,6 +342,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
...
@@ -339,6 +342,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def
mamba_type
(
self
)
->
str
:
def
mamba_type
(
self
)
->
str
:
return
"linear_attention"
return
"linear_attention"
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
from
vllm.v1.attention.backends.linear_attn
import
(
LinearAttentionBackend
)
return
LinearAttentionBackend
def
get_state_dtype
(
self
)
->
tuple
[
torch
.
dtype
]:
def
get_state_dtype
(
self
)
->
tuple
[
torch
.
dtype
]:
return
MambaStateDtypeCalculator
.
linear_attention_state_dtype
(
return
MambaStateDtypeCalculator
.
linear_attention_state_dtype
(
self
.
model_config
.
dtype
,
self
.
model_config
.
dtype
,
...
...
vllm/v1/attention/backends/mamba_selectors.py
deleted
100644 → 0
View file @
d0a4a3f6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.v1.attention.backends.linear_attn
import
LinearAttentionBackend
from
vllm.v1.attention.backends.mamba1_attn
import
Mamba1AttentionBackend
from
vllm.v1.attention.backends.mamba2_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.short_conv_attn
import
(
ShortConvAttentionBackend
)
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
    """Return the attention backend class registered for ``mamba_type``.

    Recognized tags are "mamba1", "mamba2", "linear_attention" and
    "short_conv".

    Raises:
        NotImplementedError: if ``mamba_type`` is none of the known tags.
    """
    # Ordered (tag, backend) pairs, compared with ``==`` in sequence.
    known_backends = (
        ("mamba1", Mamba1AttentionBackend),
        ("mamba2", Mamba2AttentionBackend),
        ("linear_attention", LinearAttentionBackend),
        ("short_conv", ShortConvAttentionBackend),
    )
    for tag, backend_cls in known_backends:
        if mamba_type == tag:
            return backend_cls

    raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
                              "supported yet.")
vllm/v1/worker/gpu_model_runner.py
View file @
5c4b6e66
...
@@ -35,7 +35,8 @@ from vllm.distributed.parallel_state import (
...
@@ -35,7 +35,8 @@ from vllm.distributed.parallel_state import (
from
vllm.forward_context
import
(
BatchDescriptor
,
DPMetadata
,
from
vllm.forward_context
import
(
BatchDescriptor
,
DPMetadata
,
set_forward_context
)
set_forward_context
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.model_executor.models.interfaces
import
(
is_mixture_of_experts
,
from
vllm.model_executor.models.interfaces
import
(
is_mixture_of_experts
,
...
@@ -55,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
...
@@ -55,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes
,
LazyLoader
,
cdiv
,
check_use_alibi
,
GiB_bytes
,
LazyLoader
,
cdiv
,
check_use_alibi
,
get_dtype_size
,
is_pin_memory_available
,
round_up
,
get_dtype_size
,
is_pin_memory_available
,
round_up
,
supports_dynamo
)
supports_dynamo
)
from
vllm.v1.attention.backends.mamba_selectors
import
get_mamba_attn_backend
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
make_kv_sharing_fast_prefill_attention_metadata
,
make_kv_sharing_fast_prefill_attention_metadata
,
...
@@ -2747,11 +2747,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2747,11 +2747,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"""
"""
assert
len
(
self
.
attn_groups
)
==
0
,
\
assert
len
(
self
.
attn_groups
)
==
0
,
\
"Attention backends are already initialized"
"Attention backends are already initialized"
attn_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
def
get_attn_backends_for_layers
(
def
get_attn_backends_for_layers
(
layer_names
:
list
[
str
]
layer_names
:
list
[
str
]
)
->
dict
[
type
[
AttentionBackend
],
list
[
str
]]:
)
->
dict
[
type
[
AttentionBackend
],
list
[
str
]]:
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
,
layer_names
)
attn_backends
=
{}
attn_backends
=
{}
attn_backend_layers
=
defaultdict
(
list
)
attn_backend_layers
=
defaultdict
(
list
)
# Dedupe based on full class name; this is a bit safer than using
# Dedupe based on full class name; this is a bit safer than using
...
@@ -2760,7 +2762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2760,7 +2762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# they are cached correctly, there will be different objects per
# they are cached correctly, there will be different objects per
# layer.
# layer.
for
layer_name
in
layer_names
:
for
layer_name
in
layer_names
:
attn_backend
=
attn_
layers
[
layer_name
].
get_attn_backend
()
attn_backend
=
layers
[
layer_name
].
get_attn_backend
()
key
=
attn_backend
.
full_cls_name
()
key
=
attn_backend
.
full_cls_name
()
attn_backends
[
key
]
=
attn_backend
attn_backends
[
key
]
=
attn_backend
attn_backend_layers
[
key
].
append
(
layer_name
)
attn_backend_layers
[
key
].
append
(
layer_name
)
...
@@ -2789,20 +2791,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2789,20 +2791,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for
kv_cache_group_spec
in
kv_cache_config
.
kv_cache_groups
:
for
kv_cache_group_spec
in
kv_cache_config
.
kv_cache_groups
:
kv_cache_spec
=
kv_cache_group_spec
.
kv_cache_spec
kv_cache_spec
=
kv_cache_group_spec
.
kv_cache_spec
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
attn_backends
=
get_attn_backends_for_layers
(
attn_backends
=
get_attn_backends_for_layers
(
kv_cache_group_spec
.
layer_names
)
kv_cache_group_spec
.
layer_names
)
# TODO(lucas): move `get_mamba_attn_backend` into the mamba
# layers like above
elif
isinstance
(
kv_cache_spec
,
MambaSpec
):
attn_backends
=
{
get_mamba_attn_backend
(
kv_cache_spec
.
mamba_type
):
kv_cache_group_spec
.
layer_names
}
else
:
raise
ValueError
(
f
"Unknown KV cache spec type:
{
type
(
kv_cache_spec
)
}
"
)
self
.
attn_groups
.
append
(
self
.
attn_groups
.
append
(
create_attn_groups
(
attn_backends
,
kv_cache_spec
))
create_attn_groups
(
attn_backends
,
kv_cache_spec
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment