# test_attention_selector.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from unittest.mock import patch
5
6
7
8

import pytest
import torch

9
from vllm.attention.backends.registry import AttentionBackendEnum
10
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
11
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
12
from vllm.platforms import current_platform
13
14
15
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
16
17


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear the backend-selection LRU cache before every test.

    ``_cached_get_attn_backend`` is memoized. Without this, a result
    cached by an earlier test (under a different patched platform or
    config) could leak into a later test.
    """
    _cached_get_attn_backend.cache_clear()


# Define MLA and non-MLA backends separately, keyed by device type.
DEVICE_MLA_BACKENDS = {
    "cuda": [
        "TRITON_MLA",
        "FLASHMLA",
        "FLASHINFER_MLA",
        "FLASH_ATTN_MLA",
        "CUTLASS_MLA",
    ],
    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
    "cpu": [],
}

DEVICE_REGULAR_ATTN_BACKENDS = {
    "cuda": ["FLASHINFER", "FLASH_ATTN"],
    "hip": ["ROCM_ATTN"],
    "cpu": ["CPU_ATTN"],
}

# Block sizes exercised for MLA backends on each device.
# NOTE(review): 128 is not listed for "cuda", so the CUTLASS_MLA case in
# test_backend_selection (which skips unless block_size == 128) never runs.
# Confirm whether that is intentional.
DEVICE_MLA_BLOCK_SIZES = {
    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
    "hip": [16, 1],  # HIP requires special handling for block_size=1
    # "cpu": [16]  # CPU uses fixed block size from test cases
    "cpu": [],  # FIXME(woosuk): Temporarily disable CPU tests
}


def generate_params():
    """Build the ``(device, name, use_mla, block_size)`` parameter grid.

    The accelerator device is ``"hip"`` on ROCm and ``"cuda"`` otherwise.
    ``"cpu"`` is always included after it. MLA backends are crossed with the
    device's entries in ``DEVICE_MLA_BLOCK_SIZES``. Regular attention
    backends are only tested with ``block_size=16``.

    Returns:
        A list of ``pytest.param`` objects with ids such as
        ``cuda_FLASHMLA_mla_T_blks64``.
    """
    is_rocm = current_platform.is_rocm()
    params = []
    device_list = ["cuda", "cpu"] if not is_rocm else ["hip", "cpu"]
    for use_mla in [True, False]:
        for device in device_list:
            backends = (
                DEVICE_MLA_BACKENDS[device]
                if use_mla
                else DEVICE_REGULAR_ATTN_BACKENDS[device]
            )
            # Block sizes depend only on (device, use_mla), not on the
            # backend name, so compute them once per device.
            block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16]
            for name in backends:
                for block_size in block_sizes:
                    params.append(
                        pytest.param(
                            device,
                            name,
                            use_mla,
                            block_size,
                            id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}",
                        )
                    )
    return params


@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_backend_selection(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
):
    """Test attention backend selection with valid device-backend pairs.

    The backend ``name`` is forced via ``AttentionConfig``, and
    ``vllm.platforms.current_platform`` is patched to match ``device``.
    Depending on the case, the test does one of three things:
    - asserts that ``get_attn_backend`` returns the expected backend name;
    - asserts that it raises ``ValueError`` for a rejected combination;
    - skips when the current hardware cannot run the backend.
    """
    # Create AttentionConfig with the specified backend
    attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
    vllm_config = VllmConfig(attention_config=attention_config)

    with set_current_vllm_config(vllm_config):
        if device == "cpu":
            with patch("vllm.platforms.current_platform", CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, None, block_size)
            assert backend.get_name() == "CPU_ATTN"

        elif device == "hip":
            with patch("vllm.platforms.current_platform", RocmPlatform()):
                if use_mla:
                    # ROCm MLA backend logic:
                    # - TRITON_MLA: supported when block_size != 1
                    # - ROCM_AITER_MLA: supported when block_size == 1
                    # If backend is forced but doesn't match block_size,
                    # should raise ValueError
                    # NOTE(review): only TRITON_MLA at block_size == 1 is
                    # asserted to raise. ROCM_AITER_MLA at block_size == 16
                    # is expected to succeed below, which disagrees with
                    # the comment above. Confirm which one is right.

                    if name == "TRITON_MLA" and block_size == 1:
                        # TRITON_MLA doesn't support block_size == 1
                        with pytest.raises(ValueError) as exc_info:
                            get_attn_backend(
                                16, torch.float16, None, block_size, use_mla=use_mla
                            )
                        assert f"The selected backend, {name}" in str(exc_info.value)
                    else:
                        # Valid backend-block_size combination
                        backend = get_attn_backend(
                            16, torch.float16, None, block_size, use_mla=use_mla
                        )
                        expected = name
                        assert backend.get_name() == expected
                else:
                    backend = get_attn_backend(
                        16, torch.float16, None, block_size, use_mla=use_mla
                    )
                    expected = "ROCM_ATTN"
                    assert backend.get_name() == expected

        elif device == "cuda":
            with patch("vllm.platforms.current_platform", CudaPlatform()):
                capability = torch.cuda.get_device_capability()
                if use_mla:
                    # CUDA MLA backend logic:
                    # - CUTLASS_MLA: only supported with block_size == 128
                    #   and Blackwell GPUs (SM 10.x), V1 only
                    # - FLASHINFER_MLA: only supported on Blackwell GPUs
                    #   (SM 10.x), V1 only
                    # - FLASHMLA: only supported with block_size == 64
                    # - FLASH_ATTN_MLA: V1 only
                    # - TRITON_MLA: fallback for other cases
                    # MLA cases use head size 576 rather than 16.

                    if name == "CUTLASS_MLA":
                        if block_size != 128:
                            # CUTLASS_MLA only supports block_size == 128
                            pytest.skip("CUTLASS_MLA only supports block_size 128")
                        if capability[0] != 10:
                            pytest.skip("CUTLASS MLA is not supported on this platform")
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        expected = "CUTLASS_MLA"
                        assert backend.get_name() == expected
                    elif name == "FLASHINFER_MLA":
                        if capability[0] != 10:
                            pytest.skip(
                                "FlashInfer MLA is not supported on this platform"
                            )
                        if block_size not in [32, 64]:
                            # FlashInfer MLA only supports block_size 32 or 64
                            pytest.skip(
                                "FlashInfer MLA only supports block_size 32 or 64"
                            )
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        expected = "FLASHINFER_MLA"
                        assert backend.get_name() == expected
                    elif name == "FLASHMLA":
                        if block_size != 64:
                            # FlashMLA only supports block_size == 64
                            pytest.skip("FlashMLA only supports block_size 64")
                        from vllm.v1.attention.backends.mla.flashmla import (
                            is_flashmla_dense_supported,
                        )

                        is_supported, _ = is_flashmla_dense_supported()
                        if not is_supported:
                            pytest.skip("FlashMLA not supported on this platform")
                        backend = get_attn_backend(
                            576,
                            torch.float16,
                            None,
                            block_size,
                            use_mla=use_mla,
                        )
                        expected = name
                        assert backend.get_name() == expected
                    elif name == "FLASH_ATTN_MLA":
                        from vllm.attention.utils.fa_utils import (
                            flash_attn_supports_mla,
                        )

                        if not flash_attn_supports_mla():
                            pytest.skip(
                                "FlashAttention MLA not supported on this platform"
                            )
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        expected = "FLASH_ATTN_MLA"
                        assert backend.get_name() == expected
                    else:
                        # TRITON_MLA or other fallback
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        expected = "TRITON_MLA"
                        assert backend.get_name() == expected
                elif name == "FLASHINFER":
                    backend = get_attn_backend(
                        64, torch.float16, None, block_size, use_mla=use_mla
                    )
                    expected = "FLASHINFER"
                    assert backend.get_name() == expected
                elif name == "FLASH_ATTN":
                    backend = get_attn_backend(
                        32, torch.float16, None, block_size, use_mla=use_mla
                    )
                    expected = "FLASH_ATTN"
                    assert backend.get_name() == expected

@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(device: str):
    """Test backend auto-selection when the model dtype is float32.

    No backend is forced. The CPU platform is expected to select
    ``CPU_ATTN``, and the CUDA platform is expected to fall back to
    ``FLEX_ATTENTION``.
    """
    # Use default config (no backend specified)
    vllm_config = VllmConfig()

    with set_current_vllm_config(vllm_config):
        if device == "cpu":
            with patch("vllm.platforms.current_platform", CpuPlatform()):
                backend = get_attn_backend(16, torch.float32, None, 16)
            assert backend.get_name() == "CPU_ATTN"

        elif device == "cuda":
            with patch("vllm.platforms.current_platform", CudaPlatform()):
                backend = get_attn_backend(16, torch.float32, None, 16)
            assert backend.get_name() == "FLEX_ATTENTION"


def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    Currently skipped unconditionally: everything after ``pytest.skip`` is
    unreachable and is kept only so the checks can be re-enabled later.
    Each check forces ``FLASH_ATTN`` under an unsupported condition and
    expects a different backend to be returned.
    """
    pytest.skip(
        "Skipping as current backend selector does not "
        "handle fallbacks when a backend is explicitly set."
    )

    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
    vllm_config = VllmConfig(attention_config=attention_config)

    with set_current_vllm_config(vllm_config):
        # Unsupported CUDA arch
        monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
        backend = get_attn_backend(16, torch.float16, None, 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Reset the monkeypatch for subsequent tests
        monkeypatch.undo()

        # Unsupported data type
        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported kv cache data type
        backend = get_attn_backend(16, torch.float16, "fp8", 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported block size
        backend = get_attn_backend(16, torch.float16, None, 8)
        assert backend.get_name() != "FLASH_ATTN"

        # flash-attn is not installed
        # NOTE(review): this hides the top-level "vllm_flash_attn" module
        # name. Confirm that is the import path the selector actually uses.
        import sys

        original_module = sys.modules.get("vllm_flash_attn")
        monkeypatch.setitem(sys.modules, "vllm_flash_attn", None)
        backend = get_attn_backend(16, torch.float16, None, 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Restore the original module if it existed
        if original_module is not None:
            monkeypatch.setitem(sys.modules, "vllm_flash_attn", original_module)
        else:
            monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False)

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None, 16)
        assert backend.get_name() != "FLASH_ATTN"


def test_invalid_backend():
    """Test that invalid attention backend names raise ValueError."""
    # Looking up an unknown member name on AttentionBackendEnum is expected
    # to raise ValueError. A plain ``Enum`` lookup would raise KeyError, so
    # this also pins the enum's custom lookup behavior.
    with pytest.raises(ValueError):
        AttentionConfig(backend=AttentionBackendEnum["INVALID"])