# test_attention_selector.py (15.1 KB) -- web-view header; "Newer"/"Older" are blame navigation labels
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from unittest.mock import patch
5
6
7
8

import pytest
import torch

9
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
10
11
12
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
13
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
14
15


16
17
18
19
20
21
22
@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the backend-selection cache before every test.

    The fixture is ``autouse`` so each test in this module starts with an
    empty ``_cached_get_attn_backend`` cache.
    """
    _cached_get_attn_backend.cache_clear()


23
24
# Define MLA and non-MLA backends separately.
# Maps a device key to the MLA backend names that generate_params() emits
# for it when use_mla is True. An empty list means no MLA case is generated
# for that device.
DEVICE_MLA_BACKENDS = {
    "cuda": [
        "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
        "CUTLASS_MLA"
    ],
    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
    "cpu": [],
}

# Non-MLA backend names per device key; generate_params() pairs each of
# these with a single block size of 16.
DEVICE_REGULAR_ATTN_BACKENDS = {
    "cuda": [
        "XFORMERS",
        "FLASHINFER",
    ],
    "hip": [
        "ROCM_FLASH",
    ],
    "cpu": [
        "TORCH_SDPA",
    ],
}

# Block sizes that generate_params() combines with every MLA backend of a
# device. Non-MLA cases do not consult this table.
DEVICE_MLA_BLOCK_SIZES = {
    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
    "hip": [16, 1],  # HIP requires special handling for block_size=1
    # FIXME(woosuk): Temporarily disable CPU tests. The previous entry was
    # `"cpu": [16]` (CPU uses fixed block size from test cases).
    "cpu": [],
}


def generate_params():
    """Build the ``pytest.param`` cases used to parametrize ``test_env``.

    Cases are produced for MLA first, then non-MLA; within each, devices are
    visited in the order cuda, hip, cpu. MLA cases combine every backend in
    ``DEVICE_MLA_BACKENDS`` with every size in ``DEVICE_MLA_BLOCK_SIZES``;
    non-MLA cases combine every backend in ``DEVICE_REGULAR_ATTN_BACKENDS``
    with block size 16 only.

    Returns:
        A list of ``pytest.param(device, name, use_mla, block_size)`` objects
        with a readable ``id`` for each case.
    """
    cases = []
    for mla_enabled in (True, False):
        backend_table = (DEVICE_MLA_BACKENDS
                         if mla_enabled else DEVICE_REGULAR_ATTN_BACKENDS)
        mla_flag = str(mla_enabled)[0]  # "T" or "F", used in the case id
        for device in ("cuda", "hip", "cpu"):
            if mla_enabled:
                sizes = DEVICE_MLA_BLOCK_SIZES[device]
            else:
                sizes = [16]
            cases.extend(
                pytest.param(
                    device,
                    backend_name,
                    mla_enabled,
                    size,
                    id=f"{device}_{backend_name}_mla_{mla_flag}_blks{size}",
                ) for backend_name in backend_table[device]
                for size in sizes)
    return cases


@pytest.mark.parametrize("device, name, use_mla, block_size",
                         generate_params())
def test_env(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with valid device-backend pairs.

    The requested backend ``name`` is exported through ``STR_BACKEND_ENV_VAR``
    and ``vllm.attention.selector.current_platform`` is patched with the
    platform object matching ``device``. The test then checks the name of the
    backend returned by ``get_attn_backend`` (or that a ``ValueError`` is
    raised for ROCm backend/block-size mismatches).

    Args:
        device: One of "cpu", "hip" or "cuda"; picks the patched platform.
        name: Backend name placed in the backend environment variable.
        use_mla: Forwarded to ``get_attn_backend`` for hip/cuda cases.
        block_size: KV-cache block size passed to ``get_attn_backend``.
        monkeypatch: pytest fixture used for scoped environment changes.
    """

    def select(head_size: int = 16):
        # All hip/cuda selections in this test differ only in head size.
        return get_attn_backend(head_size,
                                torch.float16,
                                None,
                                block_size,
                                False,
                                use_mla=use_mla)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv(STR_BACKEND_ENV_VAR, name)
        # NOTE(review): this sets VLLM_MLA_DISABLE to "1" exactly when
        # use_mla is True, which reads as inverted. use_mla is also passed
        # explicitly to get_attn_backend below -- confirm the intent.
        m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")

        if device == "cpu":
            with patch("vllm.attention.selector.current_platform",
                       CpuPlatform()):
                # The CPU case does not forward use_mla.
                backend = get_attn_backend(16, torch.float16, None, block_size,
                                           False)
            assert backend.get_name() == "TORCH_SDPA_VLLM_V1"

        elif device == "hip":
            with patch("vllm.attention.selector.current_platform",
                       RocmPlatform()):
                # ROCm MLA backend logic:
                # - TRITON_MLA: supported when block_size != 1
                # - ROCM_AITER_MLA: supported when block_size == 1
                # If backend is forced but doesn't match block_size,
                # should raise ValueError
                block_size_mismatch = (
                    (name == "TRITON_MLA" and block_size == 1)
                    or (name == "ROCM_AITER_MLA" and block_size != 1))

                if not use_mla:
                    assert select().get_name() == "TRITON_ATTN_VLLM_V1"
                elif block_size_mismatch:
                    with pytest.raises(ValueError) as exc_info:
                        select()
                    assert f"The selected backend, {name}" in str(
                        exc_info.value)
                else:
                    # Valid backend-block_size combination
                    assert select().get_name() == f"{name}_VLLM_V1"

        elif device == "cuda":
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
                if use_mla:
                    # CUDA MLA backend logic:
                    # - CUTLASS_MLA: only supported with block_size == 128
                    #   and Blackwell GPUs (SM 10.0), V1 only
                    # - FLASHINFER_MLA: only supported on Blackwell GPUs
                    #   (SM 10.0+), V1 only
                    # - FLASHMLA: only supported with block_size == 64
                    # - FLASH_ATTN_MLA: V1 only
                    # - TRITON_MLA: fallback for other cases
                    #
                    # pytest.skip() raises, so each guard below ends the
                    # test before any backend is selected.
                    if name == "CUTLASS_MLA":
                        # NOTE(review): the CUDA MLA block sizes visible in
                        # this module are 16 and 64, so this case appears to
                        # always skip -- confirm.
                        if block_size != 128:
                            pytest.skip(
                                "CUTLASS_MLA only supports block_size 128")
                        expected = "CUTLASS_MLA_VLLM_V1"
                    elif name == "FLASHINFER_MLA":
                        if block_size not in [32, 64]:
                            pytest.skip(
                                "FlashInfer MLA only supports block_size 32 "
                                "or 64")
                        expected = "FLASHINFER_MLA"
                    elif name == "FLASHMLA":
                        if block_size != 64:
                            pytest.skip("FlashMLA only supports block_size 64")
                        from vllm.attention.backends.flashmla import (
                            is_flashmla_supported)
                        is_supported, _ = is_flashmla_supported()
                        if not is_supported:
                            pytest.skip(
                                "FlashMLA not supported on this platform")
                        expected = f"{name}_VLLM_V1"
                    elif name == "FLASH_ATTN_MLA":
                        expected = "FLASH_ATTN_MLA"
                    else:
                        # TRITON_MLA or other fallback
                        expected = "TRITON_MLA_VLLM_V1"

                    assert select().get_name() == expected
                elif name == "FLASHINFER":
                    assert select().get_name() == "FLASHINFER_VLLM_V1"
                else:
                    # Head size 32 is expected to select FlashAttention ...
                    assert select(32).get_name() == "FLASH_ATTN_VLLM_V1"

                    # ... while head size 16 is expected to fall back.
                    assert select().get_name() == "FLEX_ATTENTION", (
                        "Should fallback to FlexAttention if head size is "
                        "not supported by FlashAttention")
253

254

255
256
257
258
259
260
261
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(
    device: str,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with fp32.

    With ``VLLM_USE_V1=1`` and ``torch.float32``, the CPU platform is
    expected to yield ``TORCH_SDPA_VLLM_V1`` and the CUDA platform
    ``FLEX_ATTENTION``.

    Args:
        device: "cpu" or "cuda"; picks the platform object patched into
            ``vllm.attention.selector.current_platform``.
        monkeypatch: pytest fixture used for the scoped environment change.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        if device == "cpu":
            with patch("vllm.attention.selector.current_platform",
                       CpuPlatform()):
                backend = get_attn_backend(16, torch.float32, None, 16, False)
            assert backend.get_name() == "TORCH_SDPA_VLLM_V1"

        elif device == "cuda":
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
                backend = get_attn_backend(16, torch.float32, None, 16, False)
            assert backend.get_name() == "FLEX_ATTENTION"
275
276


277
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    With the backend environment variable forced to ``STR_FLASH_ATTN_VAL``,
    each scenario below asserts that the selected backend's name is not
    ``STR_FLASH_ATTN_VAL``.

    Temporary patches (CUDA capability, ``sys.modules`` entry) are applied in
    their own ``monkeypatch.context()`` so they are reverted on block exit.
    This avoids calling ``undo()`` on the shared fixture, which would revert
    every patch registered on it.
    """
    # TODO: When testing for v1, pipe in `use_v1` as an argument to
    # get_attn_backend
    import sys

    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)

        # Unsupported CUDA arch
        with monkeypatch.context() as arch:
            arch.setattr(torch.cuda,
                         "get_device_capability",
                         lambda _=None: (7, 5))
            backend = get_attn_backend(16, torch.float16, None, 16, False)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported data type
        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported kv cache data type
        backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported block size
        backend = get_attn_backend(16, torch.float16, None, 8, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # flash-attn is not installed: a None entry in sys.modules makes
        # `import vllm_flash_attn` raise ImportError. Leaving the context
        # restores the previous entry, or removes the key if it was absent.
        with monkeypatch.context() as no_flash_attn:
            no_flash_attn.setitem(sys.modules, 'vllm_flash_attn', None)
            # NOTE(review): these arguments repeat the first call in this
            # test; if selection is memoised on its arguments this may be
            # answered from cache rather than re-evaluated -- confirm.
            backend = get_attn_backend(16, torch.float16, None, 16, False)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Attention-free models should bypass env and use PlaceholderAttention
        backend = get_attn_backend(16, torch.float16, None, 16, True)
        assert backend.get_name() != STR_FLASH_ATTN_VAL
328
329


330
def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
    """Test that invalid attention backend names raise ValueError.

    The backend environment variable is set to ``STR_INVALID_VAL`` while the
    selector's ``current_platform`` is patched with ``CudaPlatform()``; the
    resulting error message must contain ``Invalid value 'INVALID'``.
    """
    with monkeypatch.context() as m, patch(
            "vllm.attention.selector.current_platform", CudaPlatform()):
        m.setenv("VLLM_USE_V1", "1")
        m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)

        # Should raise ValueError for invalid backend. `match` does a regex
        # search on the message; the pattern has no metacharacters, so this
        # is a plain substring check.
        with pytest.raises(ValueError, match="Invalid value 'INVALID'"):
            get_attn_backend(32, torch.float16, None, 16, False)