# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
from unittest.mock import patch
5
6
7
8

import pytest
import torch

9
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
10
11
12
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
13
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
14
15


@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the attention-backend selection cache before every test.

    ``_cached_get_attn_backend`` is an lru cache; emptying it up front means
    each test case performs a fresh selection under its own environment and
    platform patches instead of reusing a result from an earlier test.
    """
    reset_selection_cache = _cached_get_attn_backend.cache_clear
    reset_selection_cache()
# Define MLA and non-MLA backends separately.
# MLA backend names exercised per device when use_mla is True.
DEVICE_MLA_BACKENDS = {
    "cuda": [
        "TRITON_MLA",
        "FLASHMLA",
        "FLASHINFER_MLA",
        "FLASH_ATTN_MLA",
        "CUTLASS_MLA",
    ],
    "hip": [
        "TRITON_MLA",
        "ROCM_AITER_MLA",
    ],
    "cpu": [],
}

# Regular (non-MLA) backend names exercised per device when use_mla is False.
DEVICE_REGULAR_ATTN_BACKENDS = {
    "cuda": [
        "XFORMERS",
        "FLASHINFER",
    ],
    "hip": ["ROCM_FLASH"],
    "cpu": ["TORCH_SDPA"],
}

# Block sizes paired with each MLA backend of a device by generate_params
# (regular backends are always paired with block size 16 there).
DEVICE_MLA_BLOCK_SIZES = {
    # CUDA supports both standard and extended block sizes.
    "cuda": [16, 64],
    # HIP requires special handling for block_size=1.
    "hip": [16, 1],
    # "cpu": [16]  # CPU uses fixed block size from test cases
    # FIXME(woosuk): Temporarily disable CPU tests
    "cpu": [],
}


def generate_params():
    """Return the ``pytest.param`` cases that parametrize ``test_env``.

    Cases are ordered MLA first, then by device (cuda, hip, cpu), backend
    name and block size. MLA backends take their block sizes from
    ``DEVICE_MLA_BLOCK_SIZES``; regular backends always use block size 16.
    Each case id reads ``<device>_<backend>_mla_<T|F>_blks<block_size>``.
    """
    cases = []
    for use_mla in (True, False):
        mla_flag = str(use_mla)[0]
        for device in ("cuda", "hip", "cpu"):
            if use_mla:
                backend_names = DEVICE_MLA_BACKENDS[device]
            else:
                backend_names = DEVICE_REGULAR_ATTN_BACKENDS[device]
            for backend_name in backend_names:
                sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16]
                cases.extend(
                    pytest.param(device,
                                 backend_name,
                                 use_mla,
                                 size,
                                 id=f"{device}_{backend_name}_mla_{mla_flag}"
                                 f"_blks{size}") for size in sizes)
    return cases


@pytest.mark.parametrize("device, name, use_mla, block_size",
                         generate_params())
def test_env(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with valid device-backend pairs.

    ``name`` is forced through ``STR_BACKEND_ENV_VAR`` and
    ``vllm.attention.selector.current_platform`` is patched to the platform
    for ``device``. The test then checks the name of the backend returned by
    ``get_attn_backend``. For a forced ROCm MLA backend that does not match
    ``block_size``, it instead checks that ``ValueError`` is raised.
    Combinations a backend cannot serve are skipped.
    """

    def select_backend(head_size: int = 16):
        # Every non-CPU lookup below uses fp16, no kv-cache dtype override
        # and the parametrized block size; only the head size varies.
        return get_attn_backend(head_size,
                                torch.float16,
                                None,
                                block_size,
                                use_mla=use_mla)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv(STR_BACKEND_ENV_VAR, name)
        # Keep the MLA kill-switch consistent with the case being tested:
        # MLA cases must not disable MLA; non-MLA cases do.
        m.setenv("VLLM_MLA_DISABLE", "0" if use_mla else "1")

        if device == "cpu":
            with patch("vllm.attention.selector.current_platform",
                       CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, None, block_size)
            assert backend.get_name() == "TORCH_SDPA_VLLM_V1"

        elif device == "hip":
            with patch("vllm.attention.selector.current_platform",
                       RocmPlatform()):
                if use_mla:
                    # ROCm MLA backend logic:
                    # - TRITON_MLA: supported when block_size != 1
                    # - ROCM_AITER_MLA: supported when block_size == 1
                    # A forced backend that does not match block_size
                    # should raise ValueError.
                    unsupported = (
                        (name == "TRITON_MLA" and block_size == 1)
                        or (name == "ROCM_AITER_MLA" and block_size != 1))
                    if unsupported:
                        with pytest.raises(ValueError) as exc_info:
                            select_backend()
                        assert f"The selected backend, {name}" in str(
                            exc_info.value)
                    else:
                        # Valid backend-block_size combination
                        backend = select_backend()
                        assert backend.get_name() == f"{name}_VLLM_V1"
                else:
                    backend = select_backend()
                    assert backend.get_name() == "TRITON_ATTN_VLLM_V1"

        elif device == "cuda":
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
                if use_mla:
                    # CUDA MLA backend logic:
                    # - CUTLASS_MLA: only supported with block_size == 128
                    #   and Blackwell GPUs (SM 10.0), V1 only
                    # - FLASHINFER_MLA: only supported on Blackwell GPUs
                    #   (SM 10.0+), V1 only
                    # - FLASHMLA: only supported with block_size == 64
                    # - FLASH_ATTN_MLA: V1 only
                    # - TRITON_MLA: fallback for other cases
                    if name == "CUTLASS_MLA":
                        if block_size != 128:
                            # CUTLASS_MLA only supports block_size == 128
                            pytest.skip(
                                "CUTLASS_MLA only supports block_size 128")
                        backend = select_backend()
                        assert backend.get_name() == "CUTLASS_MLA_VLLM_V1"
                    elif name == "FLASHINFER_MLA":
                        if block_size not in [32, 64]:
                            # FlashInfer MLA only supports block_size 32 or 64
                            pytest.skip(
                                "FlashInfer MLA only supports block_size 32 "
                                "or 64")
                        backend = select_backend()
                        assert backend.get_name() == "FLASHINFER_MLA"
                    elif name == "FLASHMLA":
                        if block_size != 64:
                            # FlashMLA only supports block_size == 64
                            pytest.skip("FlashMLA only supports block_size 64")
                        from vllm.v1.attention.backends.mla.flashmla import (  # noqa: E501
                            is_flashmla_supported)
                        is_supported, _ = is_flashmla_supported()
                        if not is_supported:
                            pytest.skip(
                                "FlashMLA not supported on this platform")
                        backend = select_backend()
                        assert backend.get_name() == f"{name}_VLLM_V1"
                    elif name == "FLASH_ATTN_MLA":
                        backend = select_backend()
                        assert backend.get_name() == "FLASH_ATTN_MLA"
                    else:
                        # TRITON_MLA or other fallback
                        backend = select_backend()
                        assert backend.get_name() == "TRITON_MLA_VLLM_V1"
                elif name == "FLASHINFER":
                    backend = select_backend()
                    assert backend.get_name() == "FLASHINFER_VLLM_V1"
                else:
                    backend = select_backend(32)
                    assert backend.get_name() == "FLASH_ATTN_VLLM_V1"

                    backend = select_backend(16)
                    assert backend.get_name() == "FLEX_ATTENTION", (
                        "Should fallback to FlexAttention if head size is "
                        "not supported by FlashAttention")
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(
    device: str,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with fp32.

    With the selector's platform patched for ``device``, a float32 request
    (head size 16, block size 16) must resolve to TORCH_SDPA_VLLM_V1 on CPU
    and to FLEX_ATTENTION on CUDA.
    """
    # device -> (platform class patched into the selector, expected name)
    expectations = {
        "cpu": (CpuPlatform, "TORCH_SDPA_VLLM_V1"),
        "cuda": (CudaPlatform, "FLEX_ATTENTION"),
    }
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Devices without an expectation are not exercised.
        if device not in expectations:
            return
        platform_cls, expected_name = expectations[device]
        with patch("vllm.attention.selector.current_platform",
                   platform_cls()):
            backend = get_attn_backend(16, torch.float32, None, 16)
        assert backend.get_name() == expected_name
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    With the FlashAttention backend forced through ``STR_BACKEND_ENV_VAR``,
    ``get_attn_backend`` must not return ``STR_FLASH_ATTN_VAL`` in any of
    these cases: an unsupported CUDA arch, dtype, kv-cache dtype, block size
    or head size, or when ``vllm_flash_attn`` cannot be imported.
    Temporary patches live in nested ``monkeypatch.context()`` blocks, so
    each one is undone before the next check.
    """
    import sys

    # TODO: When testing for v1, pipe in `use_v1` as an argument to
    # get_attn_backend

    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)

        # Unsupported CUDA arch
        with monkeypatch.context() as arch_patch:
            arch_patch.setattr(torch.cuda,
                               "get_device_capability",
                               lambda _=None: (7, 5))
            backend = get_attn_backend(16, torch.float16, None, 16)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported data type
        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported kv cache data type
        backend = get_attn_backend(16, torch.float16, "fp8", 16)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported block size
        backend = get_attn_backend(16, torch.float16, None, 8)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # flash-attn is not installed
        # These arguments repeat the "unsupported CUDA arch" call above, so
        # clear the selection cache first. Otherwise a cached result from that
        # call could be returned and the missing-module path would never run
        # (assumes the cache is keyed on these arguments; see
        # vllm.attention.selector).
        _cached_get_attn_backend.cache_clear()
        with monkeypatch.context() as module_patch:
            # A None entry in sys.modules makes importing that name fail.
            # The context restores the original entry (or removes it) on exit.
            module_patch.setitem(sys.modules, 'vllm_flash_attn', None)
            backend = get_attn_backend(16, torch.float16, None, 16)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None, 16)
        assert backend.get_name() != STR_FLASH_ATTN_VAL
def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
    """Test that invalid attention backend names raise ValueError.

    ``STR_BACKEND_ENV_VAR`` is set to ``STR_INVALID_VAL`` with the CUDA
    platform patched into the selector. The resulting error message must
    name the offending value.
    """
    with monkeypatch.context() as m, patch(
            "vllm.attention.selector.current_platform", CudaPlatform()):
        m.setenv("VLLM_USE_V1", "1")
        m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)

        # Should raise ValueError for invalid backend
        with pytest.raises(ValueError) as exc_info:
            get_attn_backend(32, torch.float16, None, 16)
        # Build the expected text from the constant that was injected, so the
        # check cannot drift from STR_INVALID_VAL.
        assert f"Invalid value '{STR_INVALID_VAL}'" in str(exc_info.value)