# test_attention_selector.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from unittest.mock import patch
5
6
7
8

import pytest
import torch

9
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
10
from vllm.platforms import current_platform
11
12
13
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
14
15
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend
16
17


18
19
@pytest.fixture(autouse=True)
def clear_cache():
    """Clear the attention-backend LRU cache before every test.

    ``_cached_get_attn_backend`` is the cached helper behind backend
    selection (it exposes ``cache_clear``). Clearing it ensures each test case
    performs a fresh selection instead of reusing a result computed under a
    different config or a differently patched platform.
    """
    _cached_get_attn_backend.cache_clear()


24
25
# Define MLA and non-MLA backends separately.
# Keys are the device labels used by ``test_backend_selection``; values are
# ``AttentionBackendEnum`` member names that are forced via AttentionConfig.
DEVICE_MLA_BACKENDS = {
    "cuda": [
        "TRITON_MLA",
        "FLASHMLA",
        "FLASHINFER_MLA",
        "FLASH_ATTN_MLA",
        "CUTLASS_MLA",
    ],
    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
    "cpu": [],
}

DEVICE_REGULAR_ATTN_BACKENDS = {
    "cuda": ["FLASHINFER", "FLASH_ATTN"],
    "hip": ["ROCM_ATTN"],
    "cpu": ["CPU_ATTN"],
}

# Block sizes exercised for MLA cases only; non-MLA cases always use 16.
DEVICE_MLA_BLOCK_SIZES = {
    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
    "hip": [16, 1],  # HIP requires special handling for block_size=1
    # "cpu": [16]  # CPU uses fixed block size from test cases
    "cpu": [],  # FIXME(woosuk): Temporarily disable CPU tests
}


def generate_params():
    """Build the ``pytest.param`` cases consumed by ``test_backend_selection``.

    Each case is ``(device, backend_name, use_mla, block_size)``.

    - ROCm builds test the "hip" and "cpu" devices; every other build tests
      "cuda" and "cpu".
    - MLA cases take their block sizes from ``DEVICE_MLA_BLOCK_SIZES``.
    - Non-MLA cases always use block size 16.

    Returns:
        A list of ``pytest.param`` objects with readable ids such as
        ``"cuda_FLASHMLA_mla_T_blks64"``.
    """
    is_rocm = current_platform.is_rocm()
    device_list = ["hip", "cpu"] if is_rocm else ["cuda", "cpu"]
    params = []
    for use_mla in [True, False]:
        for device in device_list:
            backends = (
                DEVICE_MLA_BACKENDS[device]
                if use_mla
                else DEVICE_REGULAR_ATTN_BACKENDS[device]
            )
            # Depends only on (device, use_mla), so look it up once per device.
            block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16]
            for name in backends:
                for block_size in block_sizes:
                    params.append(
                        pytest.param(
                            device,
                            name,
                            use_mla,
                            block_size,
                            # str(use_mla)[0] is "T" or "F".
                            id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}",
                        )
                    )
    return params


77
@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_backend_selection(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
):
    """Test attention backend selection with valid device-backend pairs.

    The backend ``name`` is forced through ``AttentionConfig``, and
    ``vllm.platforms.current_platform`` is patched to the platform matching
    ``device``. Depending on the combination, the test does one of three
    things:

    - asserts that ``get_attn_backend`` returns the forced backend;
    - asserts that it raises ``ValueError`` for a combination the platform
      rejects;
    - skips when the local hardware or block size cannot run that backend.
    """
    # Create AttentionConfig with the specified backend
    attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
    vllm_config = VllmConfig(attention_config=attention_config)

    with set_current_vllm_config(vllm_config):
        if device == "cpu":
            with patch("vllm.platforms.current_platform", CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, None, block_size)
            assert backend.get_name() == "CPU_ATTN"

        elif device == "hip":
            with patch("vllm.platforms.current_platform", RocmPlatform()):
                if use_mla:
                    # ROCm MLA backend logic:
                    # - TRITON_MLA: supported when block_size != 1
                    # - ROCM_AITER_MLA: supported when block_size == 1
                    # NOTE(review): only TRITON_MLA with block_size == 1 is
                    # asserted to raise ValueError here; ROCM_AITER_MLA is
                    # expected to be selected for every generated block size.
                    if name == "TRITON_MLA" and block_size == 1:
                        # TRITON_MLA doesn't support block_size == 1
                        with pytest.raises(ValueError):
                            get_attn_backend(
                                576, torch.float16, None, block_size, use_mla=use_mla
                            )
                    else:
                        # Valid backend-block_size combination
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        assert backend.get_name() == name
                else:
                    backend = get_attn_backend(
                        32, torch.float16, None, block_size, use_mla=use_mla
                    )
                    assert backend.get_name() == "ROCM_ATTN"

        elif device == "cuda":
            with patch("vllm.platforms.current_platform", CudaPlatform()):
                capability = torch.cuda.get_device_capability()
                if use_mla:
                    # CUDA MLA backend expectations encoded below:
                    # - CUTLASS_MLA: block_size == 128 and SM 10.x only
                    # - FLASHINFER_MLA: SM 10.x and block_size 32 or 64 only
                    # - FLASHMLA: block_size == 64 and dense FlashMLA support
                    # - FLASH_ATTN_MLA: requires flash_attn_supports_mla()
                    # - TRITON_MLA: fallback for other cases

                    if name == "CUTLASS_MLA":
                        # NOTE(review): with CUDA MLA block sizes of 16/64 this
                        # case is always skipped; add 128 to exercise it.
                        if block_size != 128:
                            # CUTLASS_MLA only supports block_size == 128
                            pytest.skip("CUTLASS_MLA only supports block_size 128")
                        if capability[0] != 10:
                            pytest.skip("CUTLASS MLA is not supported on this platform")
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        assert backend.get_name() == "CUTLASS_MLA"
                    elif name == "FLASHINFER_MLA":
                        if capability[0] != 10:
                            pytest.skip(
                                "FlashInfer MLA is not supported on this platform"
                            )
                        if block_size not in [32, 64]:
                            # FlashInfer MLA only supports block_size 32 or 64
                            pytest.skip(
                                "FlashInfer MLA only supports block_size 32 or 64"
                            )
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        assert backend.get_name() == "FLASHINFER_MLA"
                    elif name == "FLASHMLA":
                        if block_size != 64:
                            # FlashMLA only supports block_size == 64
                            pytest.skip("FlashMLA only supports block_size 64")
                        from vllm.v1.attention.backends.mla.flashmla import (
                            is_flashmla_dense_supported,
                        )

                        is_supported, _ = is_flashmla_dense_supported()
                        if not is_supported:
                            pytest.skip("FlashMLA not supported on this platform")
                        backend = get_attn_backend(
                            576,
                            torch.float16,
                            None,
                            block_size,
                            use_mla=use_mla,
                        )
                        assert backend.get_name() == name
                    elif name == "FLASH_ATTN_MLA":
                        from vllm.v1.attention.backends.fa_utils import (
                            flash_attn_supports_mla,
                        )

                        if not flash_attn_supports_mla():
                            pytest.skip(
                                "FlashAttention MLA not supported on this platform"
                            )
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        assert backend.get_name() == "FLASH_ATTN_MLA"
                    else:
                        # TRITON_MLA or other fallback
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        assert backend.get_name() == "TRITON_MLA"
                elif name == "FLASHINFER":
                    backend = get_attn_backend(
                        64, torch.float16, None, block_size, use_mla=use_mla
                    )
                    assert backend.get_name() == "FLASHINFER"
                elif name == "FLASH_ATTN":
                    backend = get_attn_backend(
                        32, torch.float16, None, block_size, use_mla=use_mla
                    )
                    assert backend.get_name() == "FLASH_ATTN"
                else:
                    # Without this, a newly listed CUDA backend would pass
                    # silently without any assertion.
                    pytest.fail(f"No CUDA expectation defined for backend {name!r}")

        else:
            pytest.fail(f"Unhandled device {device!r}")
216

217

218
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(device: str):
    """Test attention backend selection with fp32.

    No backend is forced, so this checks the platform's automatic choice for
    a ``torch.float32`` request:

    - CPU_ATTN when the patched platform is CPU;
    - FLEX_ATTENTION when it is CUDA.
    """
    # Use default config (no backend specified)
    vllm_config = VllmConfig()

    with set_current_vllm_config(vllm_config):
        if device == "cpu":
            with patch("vllm.platforms.current_platform", CpuPlatform()):
                backend = get_attn_backend(16, torch.float32, None, 16)
            assert backend.get_name() == "CPU_ATTN"

        elif device == "cuda":
            with patch("vllm.platforms.current_platform", CudaPlatform()):
                backend = get_attn_backend(16, torch.float32, None, 16)
            assert backend.get_name() == "FLEX_ATTENTION"
234
235


236
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    Each case sets up a condition FLASH_ATTN should reject and asserts that
    some other backend is chosen. The conditions are:

    - an unsupported CUDA arch;
    - an unsupported dtype;
    - an unsupported kv-cache dtype;
    - an unsupported block size;
    - a missing ``vllm_flash_attn`` module;
    - an unsupported head size.

    Currently skipped unconditionally, because the selector does not fall back
    when FLASH_ATTN is explicitly configured.
    """
    pytest.skip(
        "Skipping as current backend selector does not "
        "handle fallbacks when a backend is explicitly set."
    )

    import sys

    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
    vllm_config = VllmConfig(attention_config=attention_config)

    with set_current_vllm_config(vllm_config):
        # Unsupported CUDA arch (patch is undone when the context exits)
        with monkeypatch.context() as mp:
            mp.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
            backend = get_attn_backend(16, torch.float16, None, 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported data type
        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported kv cache data type
        backend = get_attn_backend(16, torch.float16, "fp8", 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported block size
        backend = get_attn_backend(16, torch.float16, None, 8)
        assert backend.get_name() != "FLASH_ATTN"

        # flash-attn is not installed.
        # The first case already cached a result for (16, float16, None, 16).
        # Clear it, or the lookup below would return that stale backend
        # instead of re-running selection with the module hidden.
        _cached_get_attn_backend.cache_clear()
        with monkeypatch.context() as mp:
            # setitem records the previous entry (or its absence) and
            # restores it when the context exits.
            mp.setitem(sys.modules, "vllm_flash_attn", None)
            backend = get_attn_backend(16, torch.float16, None, 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None, 16)
        assert backend.get_name() != "FLASH_ATTN"
284
285


286
def test_invalid_backend():
    """Test that invalid attention backend names raise ValueError."""
    # The argument expression AttentionBackendEnum["INVALID"] is evaluated
    # before AttentionConfig is called, so the enum lookup is what must raise.
    # NOTE(review): a plain enum.Enum raises KeyError for an unknown member
    # name. This therefore relies on AttentionBackendEnum customizing the
    # lookup to raise ValueError; confirm in the backend registry.
    with pytest.raises(ValueError):
        AttentionConfig(backend=AttentionBackendEnum["INVALID"])
293
294


295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
@pytest.mark.parametrize("auto_value", ["auto", "AUTO", "Auto"])
def test_auto_backend_string(auto_value: str):
    """Any capitalization of "auto" must leave the configured backend unset.

    An unset (None) backend is how AttentionConfig requests automatic
    backend selection.
    """
    cfg = AttentionConfig(backend=auto_value)
    resolved_backend = cfg.backend
    assert resolved_backend is None


def test_auto_backend_selection_behavior():
    """An explicit "auto" backend must select exactly what backend=None does.

    Both configs are run through get_attn_backend on a patched CPU platform.
    The selection cache is cleared between the two lookups so the second one
    is computed afresh rather than served from the first.
    """
    auto_cfg = AttentionConfig(backend="auto")
    default_cfg = AttentionConfig(backend=None)

    # Both spellings normalize to "no explicit backend".
    assert auto_cfg.backend is None
    assert default_cfg.backend is None

    candidates = [
        VllmConfig(attention_config=attn_cfg) for attn_cfg in (auto_cfg, default_cfg)
    ]

    chosen = []
    for idx, engine_cfg in enumerate(candidates):
        if idx > 0:
            # Force the second lookup to be recomputed, not served from cache.
            _cached_get_attn_backend.cache_clear()
        with (
            set_current_vllm_config(engine_cfg),
            patch("vllm.platforms.current_platform", CpuPlatform()),
        ):
            chosen.append(get_attn_backend(16, torch.float16, None, 16))

    from_auto, from_none = chosen
    assert from_auto.get_name() == from_none.get_name()


337
338
339
340
341
342
343
344
345
@pytest.mark.parametrize(
    "backend_name,flash_attn_version,should_succeed",
    [
        ("FLASH_ATTN", 3, True),  # FA3 supports per-head quant scales
        ("FLASH_ATTN", 2, False),  # FA2 does not support per-head quant scales
        ("FLASHINFER", None, False),  # FlashInfer does not support
        ("FLEX_ATTENTION", None, False),  # Flex does not support
    ],
)
@pytest.mark.skipif(
    current_platform.is_rocm(),
    reason="Attention backend FA3 is not supported on ROCm. This test can't succeed.",
)
def test_per_head_quant_scales_backend_selection(
    backend_name: str, flash_attn_version: int | None, should_succeed: bool
):
    """Test backend selection when use_per_head_quant_scales=True.

    When the forced backend supports per-head quantization scales, it must be
    returned. Otherwise ``get_attn_backend`` must raise ``ValueError`` with a
    message naming the rejected backend. The FA3 case is skipped unless CUDA
    is available on a Hopper (SM 9.x) GPU.

    The autouse ``clear_cache`` fixture guarantees a fresh selection.
    """
    attention_config = AttentionConfig(
        backend=AttentionBackendEnum[backend_name],
        flash_attn_version=flash_attn_version,
    )
    vllm_config = VllmConfig(attention_config=attention_config)

    # The same selection request is used for every case; only the forced
    # backend (and FlashAttention version) varies.
    selection_kwargs = {
        "head_size": 128,
        "dtype": torch.float16,
        "kv_cache_dtype": "fp8",
        "block_size": 64,
        "use_per_head_quant_scales": True,
    }

    with (
        set_current_vllm_config(vllm_config),
        patch("vllm.platforms.current_platform", CudaPlatform()),
    ):
        if backend_name == "FLASH_ATTN" and flash_attn_version == 3:
            if not torch.cuda.is_available():
                pytest.skip("FA3 requires CUDA")
            capability = torch.cuda.get_device_capability()
            if capability[0] != 9:
                pytest.skip("FA3 is only supported on Hopper (SM 9.x) GPUs")

        if should_succeed:
            backend = get_attn_backend(**selection_kwargs)
            assert backend.get_name() == backend_name
        else:
            with pytest.raises(ValueError) as exc_info:
                get_attn_backend(**selection_kwargs)
            assert backend_name in str(exc_info.value)