# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import sys
from unittest.mock import patch

import pytest
import torch

from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear the lru cache so each test case runs without cached results.

    Declared ``autouse`` so it runs before every test in this module
    without having to be requested explicitly.
    """
    _cached_get_attn_backend.cache_clear()


# Define MLA and non-MLA backends separately.
# Each table is keyed by the device string used in the test parameters.
DEVICE_MLA_BACKENDS = {
    "cuda": [
        "TRITON_MLA",
        "FLASHMLA",
        "FLASHINFER_MLA",
        "FLASH_ATTN_MLA",
        "CUTLASS_MLA",
    ],
    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
    "cpu": [],
}

DEVICE_REGULAR_ATTN_BACKENDS = {
    "cuda": ["FLASHINFER", "FLASH_ATTN"],
    "hip": ["ROCM_ATTN"],
    "cpu": ["CPU_ATTN"],
}

# Block sizes exercised for MLA backends on each device.
DEVICE_MLA_BLOCK_SIZES = {
    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
    "hip": [16, 1],  # HIP requires special handling for block_size=1
    # FIXME(woosuk): Temporarily disable CPU tests. The previous value was
    # [16] (CPU uses fixed block size from test cases).
    "cpu": [],
}


def generate_params():
    """Build the ``pytest.param`` list consumed by ``test_backend_selection``.

    The device list is ``["hip", "cpu"]`` when the current platform reports
    ROCm and ``["cuda", "cpu"]`` otherwise. For every ``use_mla`` value and
    device, one parameter set is produced per backend name and block size:
    MLA backends use the sizes in ``DEVICE_MLA_BLOCK_SIZES``, non-MLA
    backends always use block size 16.

    Returns:
        A list of ``pytest.param(device, name, use_mla, block_size)`` entries,
        each with a readable ``id``.
    """
    device_list = ["hip", "cpu"] if current_platform.is_rocm() else ["cuda", "cpu"]
    params = []
    for use_mla in [True, False]:
        for device in device_list:
            backends = (
                DEVICE_MLA_BACKENDS[device]
                if use_mla
                else DEVICE_REGULAR_ATTN_BACKENDS[device]
            )
            # Depends only on (device, use_mla), so look it up once per device
            # rather than once per backend.
            block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16]
            for name in backends:
                for block_size in block_sizes:
                    params.append(
                        pytest.param(
                            device,
                            name,
                            use_mla,
                            block_size,
                            id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}",
                        )
                    )
    return params


@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_backend_selection(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
):
    """Test attention backend selection with valid device-backend pairs.

    Args:
        device: "cpu", "hip" or "cuda"; selects which platform object is
            patched in as ``vllm.platforms.current_platform``.
        name: ``AttentionBackendEnum`` member name forced via
            ``AttentionConfig``.
        use_mla: Whether an MLA backend is being requested.
        block_size: Block size passed to ``get_attn_backend``.
    """
    # Create AttentionConfig with the specified backend
    attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
    vllm_config = VllmConfig(attention_config=attention_config)

    with set_current_vllm_config(vllm_config):
        if device == "cpu":
            with patch("vllm.platforms.current_platform", CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, None, block_size)
            assert backend.get_name() == "CPU_ATTN"

        elif device == "hip":
            with patch("vllm.platforms.current_platform", RocmPlatform()):
                if use_mla:
                    # ROCm MLA backend logic:
                    # - TRITON_MLA: supported when block_size != 1
                    # - ROCM_AITER_MLA: supported when block_size == 1
                    # If backend is forced but doesn't match block_size,
                    # should raise ValueError

                    if name == "TRITON_MLA" and block_size == 1:
                        # TRITON_MLA doesn't support block_size == 1
                        with pytest.raises(ValueError) as exc_info:
                            get_attn_backend(
                                16, torch.float16, None, block_size, use_mla=use_mla
                            )
                        assert f"The selected backend, {name}" in str(exc_info.value)
                    else:
                        # Valid backend-block_size combination
                        backend = get_attn_backend(
                            16, torch.float16, None, block_size, use_mla=use_mla
                        )
                        assert backend.get_name() == name
                else:
                    backend = get_attn_backend(
                        16, torch.float16, None, block_size, use_mla=use_mla
                    )
                    assert backend.get_name() == "ROCM_ATTN"

        elif device == "cuda":
            with patch("vllm.platforms.current_platform", CudaPlatform()):
                capability = torch.cuda.get_device_capability()
                if use_mla:
                    # CUDA MLA backend logic:
                    # - CUTLASS_MLA: only supported with block_size == 128
                    #   and Blackwell GPUs (SM 10.x), V1 only
                    # - FLASHINFER_MLA: only supported on Blackwell GPUs
                    #   (SM 10.x), V1 only
                    # - FLASHMLA: only supported with block_size == 64
                    # - FLASH_ATTN_MLA: V1 only
                    # - TRITON_MLA: fallback for other cases

                    if name == "CUTLASS_MLA":
                        # NOTE(review): DEVICE_MLA_BLOCK_SIZES["cuda"] currently
                        # holds only 16 and 64, so this branch always skips;
                        # confirm whether 128 should be added to the list.
                        if block_size != 128:
                            # CUTLASS_MLA only supports block_size == 128
                            pytest.skip("CUTLASS_MLA only supports block_size 128")
                        if capability[0] != 10:
                            pytest.skip("CUTLASS MLA is not supported on this platform")
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        assert backend.get_name() == "CUTLASS_MLA"
                    elif name == "FLASHINFER_MLA":
                        if capability[0] != 10:
                            pytest.skip(
                                "FlashInfer MLA is not supported on this platform"
                            )
                        if block_size not in [32, 64]:
                            # FlashInfer MLA only supports block_size 32 or 64
                            pytest.skip(
                                "FlashInfer MLA only supports block_size 32 or 64"
                            )
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        assert backend.get_name() == "FLASHINFER_MLA"
                    elif name == "FLASHMLA":
                        if block_size != 64:
                            # FlashMLA only supports block_size == 64
                            pytest.skip("FlashMLA only supports block_size 64")
                        # Imported here so that the import only happens for
                        # FLASHMLA parameter sets that were not skipped above.
                        from vllm.v1.attention.backends.mla.flashmla import (
                            is_flashmla_dense_supported,
                        )

                        is_supported, _ = is_flashmla_dense_supported()
                        if not is_supported:
                            pytest.skip("FlashMLA not supported on this platform")
                        backend = get_attn_backend(
                            576,
                            torch.float16,
                            None,
                            block_size,
                            use_mla=use_mla,
                        )
                        assert backend.get_name() == name
                    elif name == "FLASH_ATTN_MLA":
                        from vllm.v1.attention.backends.fa_utils import (
                            flash_attn_supports_mla,
                        )

                        if not flash_attn_supports_mla():
                            pytest.skip(
                                "FlashAttention MLA not supported on this platform"
                            )
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        assert backend.get_name() == "FLASH_ATTN_MLA"
                    else:
                        # TRITON_MLA or other fallback
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        assert backend.get_name() == "TRITON_MLA"
                elif name == "FLASHINFER":
                    backend = get_attn_backend(
                        64, torch.float16, None, block_size, use_mla=use_mla
                    )
                    assert backend.get_name() == "FLASHINFER"
                elif name == "FLASH_ATTN":
                    backend = get_attn_backend(
                        32, torch.float16, None, block_size, use_mla=use_mla
                    )
                    assert backend.get_name() == "FLASH_ATTN"


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(device: str):
    """Test attention backend selection with fp32.

    No backend is forced in the config, so the selector picks one itself:
    the test expects ``CPU_ATTN`` on the CPU platform and ``FLEX_ATTENTION``
    on the CUDA platform.
    """
    # Use default config (no backend specified)
    vllm_config = VllmConfig()

    with set_current_vllm_config(vllm_config):
        if device == "cpu":
            with patch("vllm.platforms.current_platform", CpuPlatform()):
                backend = get_attn_backend(16, torch.float32, None, 16)
            assert backend.get_name() == "CPU_ATTN"

        elif device == "cuda":
            with patch("vllm.platforms.current_platform", CudaPlatform()):
                backend = get_attn_backend(16, torch.float32, None, 16)
            assert backend.get_name() == "FLEX_ATTENTION"


# Skipped declaratively: the body used to start with an unconditional
# pytest.skip(...) call, which left everything after it unreachable.
@pytest.mark.skip(
    reason=(
        "Skipping as current backend selector does not "
        "handle fallbacks when a backend is explicitly set."
    )
)
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    With ``FLASH_ATTN`` forced in the config, each scenario below asserts
    that the selected backend is *not* ``FLASH_ATTN``.
    """
    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
    vllm_config = VllmConfig(attention_config=attention_config)

    with set_current_vllm_config(vllm_config):
        # Unsupported CUDA arch
        monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
        backend = get_attn_backend(16, torch.float16, None, 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Reset the monkeypatch for subsequent tests
        monkeypatch.undo()

        # Unsupported data type
        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported kv cache data type
        backend = get_attn_backend(16, torch.float16, "fp8", 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported block size
        backend = get_attn_backend(16, torch.float16, None, 8)
        assert backend.get_name() != "FLASH_ATTN"

        # flash-attn is not installed: a None entry in sys.modules makes
        # "import vllm_flash_attn" raise ImportError.
        original_module = sys.modules.get("vllm_flash_attn")
        monkeypatch.setitem(sys.modules, "vllm_flash_attn", None)
        backend = get_attn_backend(16, torch.float16, None, 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Restore the original module if it existed
        if original_module is not None:
            monkeypatch.setitem(sys.modules, "vllm_flash_attn", original_module)
        else:
            monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False)

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None, 16)
        assert backend.get_name() != "FLASH_ATTN"


def test_invalid_backend():
    """Test that invalid attention backend names raise ValueError."""
    with pytest.raises(ValueError):
        # Invalid backend name should raise ValueError when creating enum,
        # i.e. during the AttentionBackendEnum[...] lookup itself.
        AttentionConfig(backend=AttentionBackendEnum["INVALID"])


@pytest.mark.parametrize("auto_value", ["auto", "AUTO", "Auto"])
def test_auto_backend_string(auto_value: str):
    """Check that every casing of "auto" yields ``backend=None``."""
    config = AttentionConfig(backend=auto_value)
    # backend=None is what stands for automatic backend selection.
    resolved = config.backend
    assert resolved is None


def test_auto_backend_selection_behavior():
    """Check that ``backend="auto"`` and ``backend=None`` pick the same backend."""
    explicit_auto = AttentionConfig(backend="auto")
    implicit_default = AttentionConfig(backend=None)

    # Both spellings must collapse to backend=None.
    assert explicit_auto.backend is None
    assert implicit_default.backend is None

    wrapped_auto = VllmConfig(attention_config=explicit_auto)
    wrapped_default = VllmConfig(attention_config=implicit_default)

    def select_on_cpu(active_config):
        # Run one selection with the given config active and the CPU
        # platform patched in.
        with (
            set_current_vllm_config(active_config),
            patch("vllm.platforms.current_platform", CpuPlatform()),
        ):
            return get_attn_backend(16, torch.float16, None, 16)

    chosen_for_auto = select_on_cpu(wrapped_auto)

    # Drop the cached result so the second selection is computed afresh.
    _cached_get_attn_backend.cache_clear()

    chosen_for_default = select_on_cpu(wrapped_default)

    assert chosen_for_auto.get_name() == chosen_for_default.get_name()


@pytest.mark.parametrize(
    "backend_name,flash_attn_version,should_succeed",
    [
        ("FLASH_ATTN", 3, True),  # FA3 supports per-head quant scales
        ("FLASH_ATTN", 2, False),  # FA2 does not support per-head quant scales
        ("FLASHINFER", None, False),  # FlashInfer does not support
        ("FLEX_ATTENTION", None, False),  # Flex does not support
    ],
)
def test_per_head_quant_scales_backend_selection(
    backend_name: str, flash_attn_version: int | None, should_succeed: bool
):
    """Check backend selection with ``use_per_head_quant_scales=True``.

    When ``should_succeed`` is true the forced backend must be returned;
    otherwise ``get_attn_backend`` must raise a ``ValueError`` whose message
    contains the backend name.
    """
    # Start from an empty selection cache.
    _cached_get_attn_backend.cache_clear()

    vllm_config = VllmConfig(
        attention_config=AttentionConfig(
            backend=AttentionBackendEnum[backend_name],
            flash_attn_version=flash_attn_version,
        )
    )

    # The same arguments are used for the success and failure paths.
    selection_kwargs = {
        "head_size": 128,
        "dtype": torch.float16,
        "kv_cache_dtype": "fp8",
        "block_size": 64,
        "use_per_head_quant_scales": True,
    }
    wants_fa3 = backend_name == "FLASH_ATTN" and flash_attn_version == 3

    with (
        set_current_vllm_config(vllm_config),
        patch("vllm.platforms.current_platform", CudaPlatform()),
    ):
        if wants_fa3:
            if not torch.cuda.is_available():
                pytest.skip("FA3 requires CUDA")
            major = torch.cuda.get_device_capability()[0]
            if major != 9:
                pytest.skip("FA3 is only supported on Hopper (SM 9.x) GPUs")

        if not should_succeed:
            with pytest.raises(ValueError) as exc_info:
                get_attn_backend(**selection_kwargs)
            assert backend_name in str(exc_info.value)
        else:
            backend = get_attn_backend(**selection_kwargs)
            assert backend.get_name() == backend_name