test_attention_selector.py 16.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from unittest.mock import patch
5
6
7
8

import pytest
import torch

9
10
11
12
13
14
from vllm.config import (
    AttentionConfig,
    CacheConfig,
    VllmConfig,
    set_current_vllm_config,
)
15
from vllm.platforms import current_platform
16
from vllm.platforms.cpu import CpuPlatform
17
18
19
20
21
22
23
24
25
26
27
28
29

# CudaPlatform and RocmPlatform import their respective compiled C extensions
# at module level, raising ModuleNotFoundError on incompatible builds.
try:
    from vllm.platforms.cuda import CudaPlatform
except (ImportError, ModuleNotFoundError):
    CudaPlatform = None

try:
    from vllm.platforms.rocm import RocmPlatform
except (ImportError, ModuleNotFoundError):
    RocmPlatform = None

30
31
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend
32
33


34
35
@pytest.fixture(autouse=True)
def clear_cache():
    """Clear the attention-backend selection lru cache around every test.

    ``_cached_get_attn_backend`` is memoized (it exposes ``cache_clear``) and
    the tests in this module patch ``vllm.platforms.current_platform``.
    Clearing before the test keeps an earlier test's result from being reused;
    clearing again on teardown keeps results computed under a patched platform
    from leaking into tests that run after this one (including other modules).
    """
    _cached_get_attn_backend.cache_clear()
    yield
    _cached_get_attn_backend.cache_clear()


40
41
# MLA backends to exercise, keyed by device type. Regular (non-MLA) backends
# are listed in a separate table.
DEVICE_MLA_BACKENDS = dict(
    cuda=[
        "TRITON_MLA",
        "FLASHMLA",
        "FLASHINFER_MLA",
        "FLASH_ATTN_MLA",
        "CUTLASS_MLA",
    ],
    hip=["TRITON_MLA", "ROCM_AITER_MLA"],
    cpu=[],  # no MLA backends are tested on CPU
)

# Regular (non-MLA) attention backends to exercise, keyed by device type.
DEVICE_REGULAR_ATTN_BACKENDS = dict(
    cuda=["FLASHINFER", "FLASH_ATTN"],
    hip=["ROCM_ATTN"],
    cpu=["CPU_ATTN"],
)

# KV-cache block sizes tried for each MLA backend, keyed by device type.
# NOTE(review): test_backend_selection only runs CUTLASS_MLA at block_size 128
# (it skips otherwise), and 128 is not listed for cuda, so that case is
# currently always skipped.
DEVICE_MLA_BLOCK_SIZES = dict(
    cuda=[16, 64],  # CUDA supports both standard and extended block sizes
    hip=[16, 1],  # HIP requires special handling for block_size=1
    # FIXME(woosuk): Temporarily disable CPU tests. When re-enabled, CPU
    # should use the fixed block size from the test cases ([16]).
    cpu=[],
)


def generate_params():
    """Build the ``pytest.param`` matrix consumed by ``test_backend_selection``.

    Enumerates every (use_mla, device, backend, block_size) combination from
    the DEVICE_* tables. On ROCm the ``hip`` device replaces ``cuda``; regular
    (non-MLA) backends are only tried with block_size 16.
    """
    devices = ["hip", "cpu"] if current_platform.is_rocm() else ["cuda", "cpu"]
    collected = []
    for use_mla in (True, False):
        backend_table = (
            DEVICE_MLA_BACKENDS if use_mla else DEVICE_REGULAR_ATTN_BACKENDS
        )
        for device in devices:
            for name in backend_table[device]:
                sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16]
                collected.extend(
                    pytest.param(
                        device,
                        name,
                        use_mla,
                        block_size,
                        id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}",
                    )
                    for block_size in sizes
                )
    return collected


93
def _select_fp16_backend(head_size: int, use_mla: bool):
    """Run ``get_attn_backend`` for fp16 activations and no KV-cache dtype."""
    return get_attn_backend(head_size, torch.float16, None, use_mla=use_mla)


def _check_cpu_selection(name: str, use_mla: bool, block_size: int) -> None:
    """Expect the CPU platform to resolve to ``CPU_ATTN``."""
    with patch("vllm.platforms.current_platform", CpuPlatform()):
        selected = get_attn_backend(16, torch.float16, None)
    assert selected.get_name() == "CPU_ATTN"


def _check_hip_selection(name: str, use_mla: bool, block_size: int) -> None:
    """Check ROCm selection; skips when ``RocmPlatform`` could not be imported."""
    if RocmPlatform is None:
        pytest.skip("RocmPlatform not available")
    with patch("vllm.platforms.current_platform", RocmPlatform()):
        if not use_mla:
            assert _select_fp16_backend(32, use_mla).get_name() == "ROCM_ATTN"
            return
        # ROCm MLA backend logic:
        # - TRITON_MLA: supported when block_size != 1
        # - ROCM_AITER_MLA: supported when block_size == 1
        # Forcing TRITON_MLA with block_size == 1 must raise ValueError.
        if name == "TRITON_MLA" and block_size == 1:
            with pytest.raises(ValueError):
                _select_fp16_backend(576, use_mla)
            return
        assert _select_fp16_backend(576, use_mla).get_name() == name


def _expected_cuda_mla_backend(name: str, block_size: int, capability) -> str:
    """Apply the CUDA MLA skip rules for ``name`` and return the expected backend.

    Calls ``pytest.skip`` when the combination cannot run on this machine.
    ``capability`` is the value returned by ``torch.cuda.get_device_capability``.
    """
    if name == "CUTLASS_MLA":
        # Only block_size == 128 on Blackwell GPUs (SM 10.x).
        if block_size != 128:
            pytest.skip("CUTLASS_MLA only supports block_size 128")
        if capability[0] != 10:
            pytest.skip("CUTLASS MLA is not supported on this platform")
        return "CUTLASS_MLA"
    if name == "FLASHINFER_MLA":
        # Only Blackwell GPUs (SM 10.x) with block_size 32 or 64.
        if capability[0] != 10:
            pytest.skip("FlashInfer MLA is not supported on this platform")
        if block_size not in [32, 64]:
            pytest.skip("FlashInfer MLA only supports block_size 32 or 64")
        return "FLASHINFER_MLA"
    if name == "FLASHMLA":
        # Only block_size == 64, and only where the dense kernel is available.
        if block_size != 64:
            pytest.skip("FlashMLA only supports block_size 64")
        from vllm.v1.attention.backends.mla.flashmla import (
            is_flashmla_dense_supported,
        )

        is_supported, _ = is_flashmla_dense_supported()
        if not is_supported:
            pytest.skip("FlashMLA not supported on this platform")
        return name
    if name == "FLASH_ATTN_MLA":
        from vllm.v1.attention.backends.fa_utils import (
            flash_attn_supports_mla,
        )

        if not flash_attn_supports_mla():
            pytest.skip("FlashAttention MLA not supported on this platform")
        return "FLASH_ATTN_MLA"
    # TRITON_MLA (or any other name) is expected to resolve to TRITON_MLA.
    return "TRITON_MLA"


def _check_cuda_selection(name: str, use_mla: bool, block_size: int) -> None:
    """Check CUDA selection; skips when ``CudaPlatform`` could not be imported."""
    if CudaPlatform is None:
        pytest.skip("CudaPlatform not available")
    with patch("vllm.platforms.current_platform", CudaPlatform()):
        capability = torch.cuda.get_device_capability()
        if use_mla:
            expected = _expected_cuda_mla_backend(name, block_size, capability)
            assert _select_fp16_backend(576, use_mla).get_name() == expected
            return
        # Regular backends and the head size each one is queried with.
        head_size = {"FLASHINFER": 64, "FLASH_ATTN": 32}.get(name)
        if head_size is not None:
            assert _select_fp16_backend(head_size, use_mla).get_name() == name


@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_backend_selection(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
):
    """Force each (device, backend, block size) pair and check what is selected.

    The backend named by ``name`` and ``block_size`` are set through the
    config, then the device-specific checker patches the platform and
    asserts the resolved backend (or skips when the combination is
    unsupported here).
    """
    vllm_config = VllmConfig(
        attention_config=AttentionConfig(backend=AttentionBackendEnum[name]),
        cache_config=CacheConfig(block_size=block_size),
    )
    device_check = {
        "cpu": _check_cpu_selection,
        "hip": _check_hip_selection,
        "cuda": _check_cuda_selection,
    }.get(device)

    with set_current_vllm_config(vllm_config):
        if device_check is not None:
            device_check(name, use_mla, block_size)
230

231

232
@pytest.mark.parametrize("device", ["cpu", "cuda", "hip"])
def test_fp32_fallback(device: str):
    """Check automatic backend selection for fp32 activations at head_size 16.

    No backend is configured. CPU is expected to pick ``CPU_ATTN``, CUDA
    ``FLEX_ATTENTION``, and ROCm to reject the request with ValueError.
    """
    vllm_config = VllmConfig()  # default config: no backend forced

    with set_current_vllm_config(vllm_config):
        if device == "hip":
            if RocmPlatform is None:
                pytest.skip("RocmPlatform not available")
            # ROCm backends do not support head_size=16 (minimum is 32).
            # No known HuggingFace transformer model uses head_size=16.
            # Revisit if a real model with this head size is identified
            # and accuracy-tested.
            with patch("vllm.platforms.current_platform", RocmPlatform()):
                with pytest.raises(ValueError, match="No valid attention backend"):
                    get_attn_backend(16, torch.float32, None)
            return

        if device == "cpu":
            platform_cls, expected = CpuPlatform, "CPU_ATTN"
        elif device == "cuda":
            if CudaPlatform is None:
                pytest.skip("CudaPlatform not available")
            platform_cls, expected = CudaPlatform, "FLEX_ATTENTION"
        else:
            return

        with patch("vllm.platforms.current_platform", platform_cls()):
            fallback = get_attn_backend(16, torch.float32, None)
        assert fallback.get_name() == expected

264

265
@pytest.mark.skip(
    reason="Skipping as current backend selector does not "
    "handle fallbacks when a backend is explicitly set."
)
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    Forces ``FLASH_ATTN`` in the config and violates one FlashAttention
    precondition at a time: CUDA arch, activation dtype, KV-cache dtype,
    block size, missing ``vllm_flash_attn`` module, and head size. Each case
    asserts that a different backend is selected.

    The test is disabled with a skip marker rather than a ``pytest.skip()``
    call at the top of the body, so the reason is visible at collection time.
    """
    import sys

    def select_backend(head_size, dtype, kv_cache_dtype):
        # NOTE(review): ``_cached_get_attn_backend`` is memoized. Several
        # scenarios below pass identical arguments and differ only in patched
        # or config state, which presumably is not part of the cache key.
        # Without clearing, they would reuse an earlier result and pass
        # vacuously. TODO confirm the cache key.
        _cached_get_attn_backend.cache_clear()
        return get_attn_backend(head_size, dtype, kv_cache_dtype)

    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
    cache_config = CacheConfig(block_size=16)
    vllm_config = VllmConfig(
        attention_config=attention_config, cache_config=cache_config
    )

    with set_current_vllm_config(vllm_config):
        # Unsupported CUDA arch (SM 7.5). The patch is scoped to this check
        # and undone when the context exits.
        with monkeypatch.context() as mp:
            mp.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
            backend = select_backend(16, torch.float16, None)
            assert backend.get_name() != "FLASH_ATTN"

        # Unsupported data type
        backend = select_backend(16, torch.float8_e4m3fn, None)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported kv cache data type
        backend = select_backend(16, torch.float16, "fp8")
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported block size; restore 16 for the remaining checks.
        vllm_config.cache_config.block_size = 8
        backend = select_backend(16, torch.float16, None)
        assert backend.get_name() != "FLASH_ATTN"
        vllm_config.cache_config.block_size = 16

        # flash-attn is not installed. A ``None`` entry in ``sys.modules``
        # makes importing ``vllm_flash_attn`` fail. The original entry (or
        # its absence) is restored when the context exits.
        with monkeypatch.context() as mp:
            mp.setitem(sys.modules, "vllm_flash_attn", None)
            backend = select_backend(16, torch.float16, None)
            assert backend.get_name() != "FLASH_ATTN"

        # Unsupported head size
        backend = select_backend(17, torch.float16, None)
        assert backend.get_name() != "FLASH_ATTN"
318
319


320
def test_invalid_backend():
    """An unknown attention backend name must raise ValueError."""
    # Arguments are evaluated before the call, so the ValueError is expected
    # from the ``AttentionBackendEnum["INVALID"]`` lookup itself.
    with pytest.raises(ValueError):
        AttentionConfig(backend=AttentionBackendEnum["INVALID"])
327
328


329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
@pytest.mark.parametrize("auto_value", ["auto", "AUTO", "Auto"])
def test_auto_backend_string(auto_value: str):
    """Every casing of the string "auto" maps to backend=None (auto-select)."""
    config = AttentionConfig(backend=auto_value)
    assert config.backend is None


def test_auto_backend_selection_behavior():
    """backend="auto" selects the same backend as backend=None on CPU."""
    auto_config = AttentionConfig(backend="auto")
    none_config = AttentionConfig(backend=None)

    # "auto" must normalize to None, just like the default.
    for attn_config in (auto_config, none_config):
        assert attn_config.backend is None

    vllm_configs = [
        VllmConfig(attention_config=attn_config)
        for attn_config in (auto_config, none_config)
    ]

    selected = []
    for index, vllm_config in enumerate(vllm_configs):
        if index > 0:
            # Drop the memoized result so the second config re-runs selection.
            _cached_get_attn_backend.cache_clear()
        with (
            set_current_vllm_config(vllm_config),
            patch("vllm.platforms.current_platform", CpuPlatform()),
        ):
            selected.append(get_attn_backend(16, torch.float16, None))

    backend_auto, backend_none = selected
    assert backend_auto.get_name() == backend_none.get_name()


371
372
373
374
375
376
377
378
379
@pytest.mark.parametrize(
    "backend_name,flash_attn_version,should_succeed",
    [
        ("FLASH_ATTN", 3, True),  # FA3 supports per-head quant scales
        ("FLASH_ATTN", 2, False),  # FA2 does not support per-head quant scales
        ("FLASHINFER", None, False),  # FlashInfer does not support
        ("FLEX_ATTENTION", None, False),  # Flex does not support
    ],
)
@pytest.mark.skipif(
    current_platform.is_rocm(),
    reason="Attention backend FA3 is not supported on ROCm. This test can't succeed.",
)
def test_per_head_quant_scales_backend_selection(
    backend_name: str, flash_attn_version: int | None, should_succeed: bool
):
    """Force a backend with an fp8 KV cache and per-head quant scales.

    Backends that support per-head scales must be selected as requested.
    Others must raise a ValueError whose message names the backend.
    """
    # Start from an empty selection cache.
    _cached_get_attn_backend.cache_clear()

    vllm_config = VllmConfig(
        attention_config=AttentionConfig(
            backend=AttentionBackendEnum[backend_name],
            flash_attn_version=flash_attn_version,
        ),
        cache_config=CacheConfig(block_size=64),
    )

    if CudaPlatform is None:
        pytest.skip("CudaPlatform not available")

    selection_kwargs = {
        "head_size": 128,
        "dtype": torch.float16,
        "kv_cache_dtype": "fp8",
        "use_per_head_quant_scales": True,
    }
    needs_hopper = backend_name == "FLASH_ATTN" and flash_attn_version == 3

    with (
        set_current_vllm_config(vllm_config),
        patch("vllm.platforms.current_platform", CudaPlatform()),
    ):
        if needs_hopper:
            if not torch.cuda.is_available():
                pytest.skip("FA3 requires CUDA")
            if torch.cuda.get_device_capability()[0] != 9:
                pytest.skip("FA3 is only supported on Hopper (SM 9.x) GPUs")

        if not should_succeed:
            with pytest.raises(ValueError) as exc_info:
                get_attn_backend(**selection_kwargs)
            assert backend_name in str(exc_info.value)
            return

        backend = get_attn_backend(**selection_kwargs)
        assert backend.get_name() == backend_name