test_attention_selector.py 16.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from unittest.mock import patch
5
6
7
8

import pytest
import torch

9
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
10
11
12
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
13
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
14
15


16
17
18
19
20
21
22
@pytest.fixture(autouse=True)
def clear_cache():
    """Drop every memoized backend lookup before a test case starts.

    Declared ``autouse`` so each test in this module begins with an empty
    lru cache and cannot observe a backend selected by an earlier case.
    """
    reset_backend_cache = _cached_get_attn_backend.cache_clear
    reset_backend_cache()


23
24
# Define MLA and non-MLA backends separately.
# Each table maps a device key ("cuda", "hip", "cpu") to the backend names
# that generate_params() turns into test cases for that device.
DEVICE_MLA_BACKENDS = {
    "cuda": [
        "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
        "CUTLASS_MLA"
    ],
    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
    "cpu": [],
}

DEVICE_REGULAR_ATTN_BACKENDS = {
    "cuda": ["XFORMERS", "FLASHINFER"],
    "hip": ["ROCM_FLASH"],
    "cpu": ["TORCH_SDPA"],
}

# Block sizes exercised for MLA backends, per device. Non-MLA cases do not
# consult this table; generate_params() uses a fixed block size for them.
DEVICE_MLA_BLOCK_SIZES = {
    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
    "hip": [16, 1],  # HIP requires special handling for block_size=1
    # "cpu": [16]  # CPU uses fixed block size from test cases
    "cpu": []  # FIXME(woosuk): Temporarily disable CPU tests
}


def generate_params():
    """Build the ``pytest.param`` list consumed by ``test_env``.

    Cases are emitted MLA-first, then per device ("cuda", "hip", "cpu"),
    then per backend name, then per block size. MLA cases take their block
    sizes from ``DEVICE_MLA_BLOCK_SIZES``; non-MLA cases always use 16.

    Returns:
        A list of ``pytest.param(device, name, use_mla, block_size)``
        entries, each with a readable ``id``.
    """
    cases = []
    for mla_enabled in (True, False):
        backend_table = (DEVICE_MLA_BACKENDS
                         if mla_enabled else DEVICE_REGULAR_ATTN_BACKENDS)
        for dev in ("cuda", "hip", "cpu"):
            sizes = DEVICE_MLA_BLOCK_SIZES[dev] if mla_enabled else [16]
            cases.extend(
                pytest.param(
                    dev,
                    backend_name,
                    mla_enabled,
                    blk,
                    id=
                    f"{dev}_{backend_name}_mla_{str(mla_enabled)[0]}_blks{blk}"
                ) for backend_name in backend_table[dev] for blk in sizes)
    return cases


@pytest.mark.parametrize("device, name, use_mla, block_size",
                         generate_params())
@pytest.mark.parametrize("use_v1", [True, False])
def test_env(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
    use_v1: bool,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with valid device-backend pairs.

    The backend is forced through ``STR_BACKEND_ENV_VAR`` and the platform
    seen by the selector is patched to match ``device``. Depending on the
    combination, the test either skips, expects a ``ValueError``, or checks
    the name of the backend returned by ``get_attn_backend``.

    Args:
        device: Platform to emulate: "cpu", "hip" or "cuda".
        name: Backend name written to ``STR_BACKEND_ENV_VAR``.
        use_mla: Whether the case targets an MLA backend.
        block_size: KV-cache block size passed to the selector.
        use_v1: Value written to ``VLLM_USE_V1``.
        monkeypatch: pytest fixture used to scope the environment changes.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
        m.setenv(STR_BACKEND_ENV_VAR, name)
        # NOTE(review): this value looks inverted (MLA is flagged as
        # disabled exactly when use_mla is True). Left unchanged because
        # use_mla is also passed explicitly to get_attn_backend below;
        # confirm whether the selector reads this variable at all.
        m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")

        if name == "FLASHINFER" and not use_v1:
            pytest.skip("FlashInfer backend is only available on V1 engine")

        def select_backend(head_size: int = 16):
            # All HIP/CUDA cases issue this same call; only the head size
            # varies (32 is used once, for the FlashAttention check).
            return get_attn_backend(head_size,
                                    torch.float16,
                                    None,
                                    block_size,
                                    False,
                                    use_mla=use_mla)

        if device == "cpu":
            if not use_v1:
                pytest.skip("CPU backend only supports V1")

            with patch("vllm.attention.selector.current_platform",
                       CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, None, block_size,
                                           False)
            assert backend.get_name() == "TORCH_SDPA_VLLM_V1"

        elif device == "hip":
            with patch("vllm.attention.selector.current_platform",
                       RocmPlatform()):
                if use_mla:
                    # ROCm MLA backend logic:
                    # - TRITON_MLA: supported when block_size != 1
                    # - ROCM_AITER_MLA: supported when block_size == 1
                    # If backend is forced but doesn't match block_size,
                    # should raise ValueError
                    invalid_combination = (
                        (name == "TRITON_MLA" and block_size == 1)
                        or (name == "ROCM_AITER_MLA" and block_size != 1))

                    if invalid_combination:
                        with pytest.raises(ValueError) as exc_info:
                            select_backend()
                        assert f"The selected backend, {name}" in str(
                            exc_info.value)
                    else:
                        # Valid backend-block_size combination
                        expected = f"{name}_VLLM_V1" if use_v1 else name
                        assert select_backend().get_name() == expected
                else:
                    expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
                    assert select_backend().get_name() == expected

        elif device == "cuda":
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
                if use_mla:
                    # CUDA MLA backend logic:
                    # - CUTLASS_MLA: only supported with block_size == 128
                    #   and Blackwell GPUs (SM 10.0), V1 only
                    # - FLASHINFER_MLA: only supported on Blackwell GPUs
                    #   (SM 10.0+), V1 only
                    # - FLASHMLA: only supported with block_size == 64
                    # - FLASH_ATTN_MLA: V1 only
                    # - TRITON_MLA: fallback for other cases
                    #
                    # pytest.skip() raises, so each guard below ends the
                    # test before the selector is called.
                    if name == "CUTLASS_MLA":
                        if not use_v1:
                            pytest.skip(
                                "CUTLASS_MLA only supported on V1 engine")
                        if block_size != 128:
                            pytest.skip(
                                "CUTLASS_MLA only supports block_size 128")
                        expected = "CUTLASS_MLA_VLLM_V1"
                    elif name == "FLASHINFER_MLA":
                        if not use_v1:
                            pytest.skip(
                                "FlashInfer MLA only supported on V1 engine")
                        if block_size not in [32, 64]:
                            pytest.skip(
                                "FlashInfer MLA only supports block_size 32 "
                                "or 64")
                        expected = "FLASHINFER_MLA"
                    elif name == "FLASHMLA":
                        if block_size != 64:
                            pytest.skip("FlashMLA only supports block_size 64")
                        # Imported lazily, only when this case is reached.
                        from vllm.attention.backends.flashmla import (
                            is_flashmla_supported)
                        is_supported, _ = is_flashmla_supported()
                        if not is_supported:
                            pytest.skip(
                                "FlashMLA not supported on this platform")
                        expected = f"{name}_VLLM_V1" if use_v1 else name
                    elif name == "FLASH_ATTN_MLA":
                        if not use_v1:
                            pytest.skip(
                                "FlashAttention MLA only supported on V1 engine"
                            )
                        expected = "FLASH_ATTN_MLA"
                    else:
                        # TRITON_MLA or other fallback
                        expected = ("TRITON_MLA_VLLM_V1"
                                    if use_v1 else "TRITON_MLA")

                    assert select_backend().get_name() == expected
                elif name == "FLASHINFER":
                    expected = "FLASHINFER_VLLM_V1" if use_v1 else name
                    assert select_backend().get_name() == expected
                else:
                    expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
                    assert select_backend(32).get_name() == expected

                    if use_v1:
                        backend = select_backend()
                        assert backend.get_name() == "FLEX_ATTENTION", (
                            "Should fallback to FlexAttention if head size is "
                            "not supported by FlashAttention")

278

279
280
281
282
283
284
285
286
287
288
289
290
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("use_v1", [True, False])
def test_fp32_fallback(
    device: str,
    use_v1: bool,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with fp32.

    Args:
        device: Platform to emulate: "cpu" or "cuda".
        use_v1: Value written to ``VLLM_USE_V1``.
        monkeypatch: pytest fixture used to scope the environment change.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")

        if device == "cpu":
            if not use_v1:
                pytest.skip("CPU backend only supports V1")

            with patch("vllm.attention.selector.current_platform",
                       CpuPlatform()):
                backend = get_attn_backend(16, torch.float32, None, 16, False)
            assert backend.get_name() == "TORCH_SDPA_VLLM_V1"

        elif device == "cuda":
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
                backend = get_attn_backend(16, torch.float32, None, 16, False)
            # Pick the expected name first. Writing
            # `assert (name == A if use_v1 else B)` parses as
            # `(name == A) if use_v1 else B`, which asserts the non-empty
            # string B when use_v1 is False and therefore can never fail.
            expected = "FLEX_ATTENTION" if use_v1 else "XFORMERS"
            assert backend.get_name() == expected


307
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    With ``STR_BACKEND_ENV_VAR`` forcing ``STR_FLASH_ATTN_VAL``, each check
    below calls ``get_attn_backend`` under one condition and asserts that
    the returned backend is not named ``STR_FLASH_ATTN_VAL``.

    Args:
        monkeypatch: pytest fixture; temporary patches are applied through
            nested ``monkeypatch.context()`` scopes so each one is undone
            as soon as its check finishes.
    """
    # TODO: When testing for v1, pipe in `use_v1` as an argument to
    # get_attn_backend
    import sys

    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)

        # Unsupported CUDA arch
        # A nested context restores get_device_capability on exit, so the
        # function-scoped fixture never needs an explicit undo().
        with monkeypatch.context() as arch_patch:
            arch_patch.setattr(torch.cuda,
                               "get_device_capability",
                               lambda _=None: (7, 5))
            backend = get_attn_backend(16, torch.float16, None, 16, False)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported data type
        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported kv cache data type
        backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported block size
        backend = get_attn_backend(16, torch.float16, None, 8, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # flash-attn is not installed
        # A None entry in sys.modules makes `import vllm_flash_attn` fail.
        # Leaving the context restores the previous entry, or removes the
        # key if there was none.
        # NOTE(review): this call repeats the arguments of the first check;
        # if get_attn_backend memoizes on its arguments the result may be
        # served from the cache rather than re-evaluated. Confirm.
        with monkeypatch.context() as module_patch:
            module_patch.setitem(sys.modules, 'vllm_flash_attn', None)
            backend = get_attn_backend(16, torch.float16, None, 16, False)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Attention-free models should bypass env and use PlaceholderAttention
        backend = get_attn_backend(16, torch.float16, None, 16, True)
        assert backend.get_name() != STR_FLASH_ATTN_VAL
358
359


360
@pytest.mark.parametrize("use_v1", [True, False])
def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch):
    """Test that invalid attention backend names raise ValueError.

    Args:
        use_v1: Value written to ``VLLM_USE_V1``.
        monkeypatch: pytest fixture used to scope the environment changes.
    """
    with monkeypatch.context() as m, patch(
            "vllm.attention.selector.current_platform", CudaPlatform()):
        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
        m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)

        # Should raise ValueError for invalid backend
        with pytest.raises(ValueError) as exc_info:
            get_attn_backend(32, torch.float16, None, 16, False)
        # Checked after the raises block has exited, once exc_info is set.
        assert "Invalid value 'INVALID'" in str(exc_info.value)