# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import sys
from unittest.mock import patch

import pytest
import torch

from vllm.config import (
    AttentionConfig,
    CacheConfig,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend


@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the cached backend selection before every test.

    The fixture is ``autouse`` so each test case starts from an empty
    ``_cached_get_attn_backend`` cache and cannot observe a backend that was
    resolved by an earlier test.
    """
    _cached_get_attn_backend.cache_clear()


# Backend names that the tests force per device, with MLA and non-MLA
# backends defined separately. Keys are the device labels used by
# generate_params(); values are AttentionBackendEnum member names.
DEVICE_MLA_BACKENDS = {
    "cuda": [
        "TRITON_MLA",
        "FLASHMLA",
        "FLASHINFER_MLA",
        "FLASH_ATTN_MLA",
        "CUTLASS_MLA",
    ],
    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
    "cpu": [],
}

DEVICE_REGULAR_ATTN_BACKENDS = {
    "cuda": ["FLASHINFER", "FLASH_ATTN"],
    "hip": ["ROCM_ATTN"],
    "cpu": ["CPU_ATTN"],
}

# KV-cache block sizes exercised for the MLA backends of each device.
# Non-MLA cases always use a block size of 16 (see generate_params()).
DEVICE_MLA_BLOCK_SIZES = {
    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
    "hip": [16, 1],  # HIP requires special handling for block_size=1
    # "cpu": [16]  # CPU uses fixed block size from test cases
    "cpu": [],  # FIXME(woosuk): Temporarily disable CPU tests
}


def generate_params():
    """Build the parametrization for ``test_backend_selection``.

    Returns:
        A list of ``pytest.param(device, name, use_mla, block_size)`` entries.
        MLA cases come first, then non-MLA cases. The device list is
        ``["hip", "cpu"]`` on ROCm and ``["cuda", "cpu"]`` otherwise. MLA
        backends are crossed with ``DEVICE_MLA_BLOCK_SIZES[device]``; non-MLA
        backends always use a block size of 16.
    """
    device_list = ["hip", "cpu"] if current_platform.is_rocm() else ["cuda", "cpu"]

    params = []
    for use_mla in (True, False):
        backend_table = DEVICE_MLA_BACKENDS if use_mla else DEVICE_REGULAR_ATTN_BACKENDS
        for device in device_list:
            # An empty block-size list (e.g. MLA on CPU) yields no cases.
            block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16]
            for name in backend_table[device]:
                for block_size in block_sizes:
                    params.append(
                        pytest.param(
                            device,
                            name,
                            use_mla,
                            block_size,
                            # str(use_mla)[0] is "T" or "F".
                            id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}",
                        )
                    )
    return params


@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_backend_selection(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
):
    """Test attention backend selection with valid device-backend pairs.

    The backend ``name`` is forced through ``AttentionConfig`` and the platform
    is patched to match ``device``. The selector must then either return the
    forced backend or, for the one known-invalid ROCm combination, raise
    ``ValueError``. CUDA MLA cases that the current GPU or block size cannot
    run are skipped.

    Args:
        device: One of ``"cpu"``, ``"hip"`` or ``"cuda"``.
        name: ``AttentionBackendEnum`` member name to force.
        use_mla: Whether the selector is asked for an MLA backend.
        block_size: KV-cache block size placed in ``CacheConfig``.
    """
    vllm_config = VllmConfig(
        attention_config=AttentionConfig(backend=AttentionBackendEnum[name]),
        cache_config=CacheConfig(block_size=block_size),
    )

    with set_current_vllm_config(vllm_config):
        if device == "cpu":
            with patch("vllm.platforms.current_platform", CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, None)
            assert backend.get_name() == "CPU_ATTN"

        elif device == "hip":
            with patch("vllm.platforms.current_platform", RocmPlatform()):
                if not use_mla:
                    backend = get_attn_backend(32, torch.float16, None, use_mla=use_mla)
                    assert backend.get_name() == "ROCM_ATTN"
                elif name == "TRITON_MLA" and block_size == 1:
                    # TRITON_MLA doesn't support block_size == 1, so forcing
                    # it with that block size must be rejected.
                    with pytest.raises(ValueError):
                        get_attn_backend(576, torch.float16, None, use_mla=use_mla)
                else:
                    # Every other generated backend/block_size combination is
                    # expected to resolve to the forced backend.
                    backend = get_attn_backend(
                        576, torch.float16, None, use_mla=use_mla
                    )
                    assert backend.get_name() == name

        elif device == "cuda":
            with patch("vllm.platforms.current_platform", CudaPlatform()):
                capability = torch.cuda.get_device_capability()
                if use_mla:
                    # CUDA MLA backend logic:
                    # - CUTLASS_MLA: only supported with block_size == 128
                    #   and Blackwell GPUs (SM 10.x), V1 only
                    # - FLASHINFER_MLA: only supported on Blackwell GPUs
                    #   (SM 10.x), V1 only
                    # - FLASHMLA: only supported with block_size == 64
                    # - FLASH_ATTN_MLA: V1 only
                    # - TRITON_MLA: fallback for other cases
                    #
                    # Each branch below only decides whether the case must be
                    # skipped; the selector call and assertion are shared.
                    if name == "CUTLASS_MLA":
                        # NOTE(review): the generated CUDA MLA block sizes are
                        # 16 and 64, so this case currently always skips.
                        if block_size != 128:
                            pytest.skip("CUTLASS_MLA only supports block_size 128")
                        if capability[0] != 10:
                            pytest.skip("CUTLASS MLA is not supported on this platform")
                    elif name == "FLASHINFER_MLA":
                        if capability[0] != 10:
                            pytest.skip(
                                "FlashInfer MLA is not supported on this platform"
                            )
                        if block_size not in [32, 64]:
                            pytest.skip(
                                "FlashInfer MLA only supports block_size 32 or 64"
                            )
                    elif name == "FLASHMLA":
                        if block_size != 64:
                            pytest.skip("FlashMLA only supports block_size 64")
                        # Imported lazily, only when this backend is tested.
                        from vllm.v1.attention.backends.mla.flashmla import (
                            is_flashmla_dense_supported,
                        )

                        is_supported, _ = is_flashmla_dense_supported()
                        if not is_supported:
                            pytest.skip("FlashMLA not supported on this platform")
                    elif name == "FLASH_ATTN_MLA":
                        # Imported lazily, only when this backend is tested.
                        from vllm.v1.attention.backends.fa_utils import (
                            flash_attn_supports_mla,
                        )

                        if not flash_attn_supports_mla():
                            pytest.skip(
                                "FlashAttention MLA not supported on this platform"
                            )

                    backend = get_attn_backend(
                        576, torch.float16, None, use_mla=use_mla
                    )
                    assert backend.get_name() == name
                else:
                    # Head size used for each non-MLA CUDA backend.
                    head_sizes = {"FLASHINFER": 64, "FLASH_ATTN": 32}
                    if name not in head_sizes:
                        # Fail loudly instead of passing without asserting
                        # anything when a new backend is added to the table.
                        pytest.fail(f"No expectation defined for CUDA backend {name!r}")
                    backend = get_attn_backend(
                        head_sizes[name], torch.float16, None, use_mla=use_mla
                    )
                    assert backend.get_name() == name


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(device: str):
    """Test attention backend selection with fp32.

    With a default ``VllmConfig`` (no backend forced) and ``torch.float32``,
    the selector is expected to return ``CPU_ATTN`` on the CPU platform and
    ``FLEX_ATTENTION`` on the CUDA platform.
    """
    # Platform class to patch in and the backend name expected for it.
    platform_cls, expected = {
        "cpu": (CpuPlatform, "CPU_ATTN"),
        "cuda": (CudaPlatform, "FLEX_ATTENTION"),
    }[device]

    # Use default config (no backend specified)
    with set_current_vllm_config(VllmConfig()):
        with patch("vllm.platforms.current_platform", platform_cls()):
            backend = get_attn_backend(16, torch.float32, None)
        assert backend.get_name() == expected


@pytest.mark.skip(
    reason=(
        "Skipping as current backend selector does not "
        "handle fallbacks when a backend is explicitly set."
    )
)
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    FLASH_ATTN is forced through ``AttentionConfig`` and the selector is then
    queried under a series of conditions in which a different backend is
    expected. The test is currently skipped (see the marker above); the
    scenarios are kept so they can be re-enabled later.
    """
    vllm_config = VllmConfig(
        attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN),
        cache_config=CacheConfig(block_size=16),
    )

    with set_current_vllm_config(vllm_config):
        # Unsupported CUDA arch. The patch is scoped so it is undone before
        # the remaining scenarios run.
        with monkeypatch.context() as scoped:
            scoped.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
            backend = get_attn_backend(16, torch.float16, None)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported data type
        backend = get_attn_backend(16, torch.float8_e4m3fn, None)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported kv cache data type
        backend = get_attn_backend(16, torch.float16, "fp8")
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported block size. The selector arguments repeat an earlier
        # call, so drop the cached result first; it may otherwise be reused.
        _cached_get_attn_backend.cache_clear()
        vllm_config.cache_config.block_size = 8
        backend = get_attn_backend(16, torch.float16, None)
        assert backend.get_name() != "FLASH_ATTN"
        vllm_config.cache_config.block_size = 16

        # flash-attn is not installed. A None entry in sys.modules makes
        # importing the module raise ImportError; the scoped patch restores
        # (or removes) the original entry on exit.
        _cached_get_attn_backend.cache_clear()
        with monkeypatch.context() as scoped:
            scoped.setitem(sys.modules, "vllm_flash_attn", None)
            backend = get_attn_backend(16, torch.float16, None)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None)
        assert backend.get_name() != "FLASH_ATTN"


def test_invalid_backend():
    """Test that invalid attention backend names raise ValueError."""
    # The enum lookup is evaluated before AttentionConfig is called, so the
    # ValueError is expected from the lookup itself.
    # NOTE(review): a plain enum.Enum name lookup raises KeyError; this relies
    # on AttentionBackendEnum raising ValueError instead - confirm against
    # the registry.
    with pytest.raises(ValueError):
        AttentionConfig(backend=AttentionBackendEnum["INVALID"])


@pytest.mark.parametrize("auto_value", ["auto", "AUTO", "Auto"])
def test_auto_backend_string(auto_value: str):
    """Check that any casing of "auto" is stored as ``backend=None``."""
    config = AttentionConfig(backend=auto_value)
    # None is the value that stands for automatic backend selection.
    resolved_backend = config.backend
    assert resolved_backend is None


def test_auto_backend_selection_behavior():
    """Test that 'auto' backend behaves same as None (automatic selection).

    Two configs are built, one with ``backend="auto"`` and one with
    ``backend=None``. Both must store ``None`` and both must lead the selector
    to the same backend on the CPU platform.
    """
    auto_config = AttentionConfig(backend="auto")
    none_config = AttentionConfig(backend=None)

    # Both should have backend=None
    assert auto_config.backend is None
    assert none_config.backend is None

    vllm_config_auto = VllmConfig(attention_config=auto_config)
    vllm_config_none = VllmConfig(attention_config=none_config)

    with (
        set_current_vllm_config(vllm_config_auto),
        patch("vllm.platforms.current_platform", CpuPlatform()),
    ):
        backend_auto = get_attn_backend(16, torch.float16, None)

    # Drop the cached result so the second lookup is resolved independently
    # instead of being served from the first call.
    _cached_get_attn_backend.cache_clear()

    with (
        set_current_vllm_config(vllm_config_none),
        patch("vllm.platforms.current_platform", CpuPlatform()),
    ):
        backend_none = get_attn_backend(16, torch.float16, None)

    # Both should select the same backend
    assert backend_auto.get_name() == backend_none.get_name()


@pytest.mark.parametrize(
    "backend_name,flash_attn_version,should_succeed",
    [
        ("FLASH_ATTN", 3, True),  # FA3 supports per-head quant scales
        ("FLASH_ATTN", 2, False),  # FA2 does not support per-head quant scales
        ("FLASHINFER", None, False),  # FlashInfer does not support
        ("FLEX_ATTENTION", None, False),  # Flex does not support
    ],
)
@pytest.mark.skipif(
    current_platform.is_rocm(),
    reason="Attention backend FA3 is not supported on ROCm. This test can't succeed.",
)
def test_per_head_quant_scales_backend_selection(
    backend_name: str, flash_attn_version: int | None, should_succeed: bool
):
    """Test backend selection when use_per_head_quant_scales=True.

    Args:
        backend_name: ``AttentionBackendEnum`` member name to force.
        flash_attn_version: FlashAttention version placed in the config, or
            ``None`` when the case does not set one.
        should_succeed: ``True`` if the selector must return the forced
            backend; ``False`` if it must raise ``ValueError`` whose message
            mentions ``backend_name``.
    """
    # Clear cache to ensure fresh backend selection
    _cached_get_attn_backend.cache_clear()

    vllm_config = VllmConfig(
        attention_config=AttentionConfig(
            backend=AttentionBackendEnum[backend_name],
            flash_attn_version=flash_attn_version,
        ),
        cache_config=CacheConfig(block_size=64),
    )

    # Identical selector arguments are used for the success and failure paths.
    selector_kwargs = {
        "head_size": 128,
        "dtype": torch.float16,
        "kv_cache_dtype": "fp8",
        "use_per_head_quant_scales": True,
    }

    with (
        set_current_vllm_config(vllm_config),
        patch("vllm.platforms.current_platform", CudaPlatform()),
    ):
        if backend_name == "FLASH_ATTN" and flash_attn_version == 3:
            # The FA3 case can only run on a CUDA device with SM 9.x.
            if not torch.cuda.is_available():
                pytest.skip("FA3 requires CUDA")
            capability = torch.cuda.get_device_capability()
            if capability[0] != 9:
                pytest.skip("FA3 is only supported on Hopper (SM 9.x) GPUs")

        if should_succeed:
            backend = get_attn_backend(**selector_kwargs)
            assert backend.get_name() == backend_name
        else:
            with pytest.raises(ValueError) as exc_info:
                get_attn_backend(**selector_kwargs)
            assert backend_name in str(exc_info.value)