# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""

from types import SimpleNamespace

import pytest

from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.short_conv import ShortConv
from vllm.model_executor.models.minimax_text_01 import MiniMaxText01LinearAttention
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend


@pytest.mark.parametrize(
    "layer_class, init_kwargs, expected_backend, expected_mamba_type",
    [
        (
            MambaMixer,
            dict(
                hidden_size=128,
                ssm_state_size=16,
                conv_kernel_size=4,
                intermediate_size=256,
                time_step_rank=8,
                use_conv_bias=True,
                use_bias=False,
                use_rms_norm=True,
            ),
            Mamba1AttentionBackend,
            "mamba1",
        ),
        (
            MambaMixer2,
            dict(
                hidden_size=128,
                ssm_state_size=16,
                conv_kernel_size=4,
                intermediate_size=256,
                use_conv_bias=True,
                use_bias=False,
                n_groups=1,
                num_heads=8,
                head_dim=32,
            ),
            Mamba2AttentionBackend,
            "mamba2",
        ),
        (
            MiniMaxText01LinearAttention,
            dict(
                hidden_size=128,
                hidden_inner_size=256,
                num_heads=8,
                head_dim=32,
                max_position=2048,
                block_size=64,
                num_hidden_layer=12,
                layer_idx=0,
                linear_layer_idx=0,
            ),
            LinearAttentionBackend,
            "linear_attention",
        ),
        (
            ShortConv,
            dict(
                # Minimal stand-in for a model config object. Presumably
                # ShortConv only reads these attributes during __init__ --
                # confirm if the constructor starts reading more fields.
                config=SimpleNamespace(conv_L_cache=32, conv_bias=True),
                dim=128,
                layer_idx=0,
            ),
            ShortConvAttentionBackend,
            "short_conv",
        ),
    ],
)
def test_mamba_layers_get_attn_backend(
    dist_init, layer_class, init_kwargs, expected_backend, expected_mamba_type
):
    """Test that Mamba-like layers return the correct attention backend.

    For each parametrized case the layer is constructed from ``init_kwargs``.
    ``get_attn_backend()`` must return exactly ``expected_backend`` (an
    identity check, not a subclass check). The instance's ``mamba_type`` must
    equal ``expected_mamba_type``.

    ``dist_init`` is a fixture defined outside this file and is not
    referenced in the body. Presumably it sets up the distributed/parallel
    state the layer constructors depend on -- confirm in conftest.
    """
    layer = layer_class(**init_kwargs)

    backend_class = layer.get_attn_backend()
    # `is` rather than `==`: the selector must hand back the class object itself.
    assert backend_class is expected_backend
    assert layer.mamba_type == expected_mamba_type


@pytest.mark.parametrize(
    "layer_class,expected_backend,expected_mamba_type",
    [
        (MambaMixer, Mamba1AttentionBackend, "mamba1"),
        (MambaMixer2, Mamba2AttentionBackend, "mamba2"),
        (MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"),
        (ShortConv, ShortConvAttentionBackend, "short_conv"),
    ],
)
def test_mamba_layers_have_unified_interface(
    layer_class, expected_backend, expected_mamba_type
):
    """Test that all Mamba layers have the unified get_attn_backend
    interface.

    Only class-level attributes are inspected. No layer is instantiated, and
    this test does not request the ``dist_init`` fixture. The
    ``expected_backend`` and ``expected_mamba_type`` parameters are supplied
    by the parametrization but are not checked here; the instantiating test
    in this module compares the actual values.
    """
    assert hasattr(layer_class, "get_attn_backend"), (
        f"{layer_class.__name__} should have get_attn_backend method"
    )
    # On the class object a property is itself an attribute (the property
    # descriptor), so hasattr works without creating an instance.
    assert hasattr(layer_class, "mamba_type"), (
        f"{layer_class.__name__} should have mamba_type property"
    )