# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from unittest.mock import patch
5
6
7
8

import pytest
import torch

9
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
10
11
12
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
13
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
14
15


16
17
18
19
20
21
22
@pytest.fixture(autouse=True)
def clear_cache():
    """Empty the selector's lru cache before every test.

    The fixture is autouse, so each test case starts without a cached
    backend from an earlier case.
    """
    _cached_get_attn_backend.cache_clear()


23
24
# Define MLA and non-MLA backends separately. Every table below is keyed by
# the device string that generate_params() iterates over.
DEVICE_MLA_BACKENDS = {
    "cuda": [
        "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
        "CUTLASS_MLA"
    ],
    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
    "cpu": [],
}

DEVICE_REGULAR_ATTN_BACKENDS = {
    "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
    "hip": ["ROCM_FLASH"],
    "cpu": ["TORCH_SDPA"],
}

# Block sizes exercised per device when MLA is enabled.
DEVICE_MLA_BLOCK_SIZES = {
    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
    "hip": [16, 1],  # HIP requires special handling for block_size=1
    # "cpu": [16]  # CPU uses fixed block size from test cases
    "cpu": [],  # FIXME(woosuk): Temporarily disable CPU tests
}


def generate_params():
    """Build the parametrization list consumed by ``test_env``.

    Yields one ``pytest.param(device, name, use_mla, block_size)`` per
    backend listed for the device. MLA cases are expanded over the device's
    entries in ``DEVICE_MLA_BLOCK_SIZES``; non-MLA cases use block size 16
    only. MLA cases come first, and devices are visited as cuda, hip, cpu.
    """

    def make_case(device, name, use_mla, block_size):
        # The id encodes every parameter so failures are easy to map back.
        return pytest.param(
            device,
            name,
            use_mla,
            block_size,
            id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}")

    params = []
    for use_mla in (True, False):
        backend_table = (DEVICE_MLA_BACKENDS
                         if use_mla else DEVICE_REGULAR_ATTN_BACKENDS)
        for device in ("cuda", "hip", "cpu"):
            params.extend(
                make_case(device, name, use_mla, block_size)
                for name in backend_table[device]
                for block_size in (
                    DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16]))
    return params


@pytest.mark.parametrize("device, name, use_mla, block_size",
                         generate_params())
def test_env(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with valid device-backend pairs.

    The backend ``name`` is forced through ``STR_BACKEND_ENV_VAR`` and the
    selector's ``current_platform`` is patched to the platform matching
    ``device``. Depending on the combination, the case is skipped, must
    raise ``ValueError``, or must resolve to a specific backend name.

    Args:
        device: One of "cpu", "hip" or "cuda".
        name: Backend name written to ``STR_BACKEND_ENV_VAR``.
        use_mla: Whether the selector is asked for an MLA backend.
        block_size: KV-cache block size passed to the selector.
        monkeypatch: pytest fixture used to scope the environment changes.
    """
    platform_attr = "vllm.attention.selector.current_platform"

    def select(head_size: int = 16):
        """Run the selector for this case with the given head size."""
        return get_attn_backend(head_size,
                                torch.float16,
                                None,
                                block_size,
                                use_mla=use_mla)

    def assert_rejected():
        """Check that the forced backend is refused with a ValueError."""
        with pytest.raises(ValueError) as exc_info:
            select()
        assert f"The selected backend, {name}" in str(exc_info.value)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv(STR_BACKEND_ENV_VAR, name)
        # VLLM_MLA_DISABLE is a *disable* switch, so it is turned on only for
        # the non-MLA cases.
        # NOTE(review): use_mla is also passed explicitly to the selector
        # below; whether the selector reads this variable is not visible
        # from this file.
        m.setenv("VLLM_MLA_DISABLE", "0" if use_mla else "1")

        if device == "cpu":
            with patch(platform_attr, CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, None, block_size)
            assert backend.get_name() == "TORCH_SDPA"

        elif device == "hip":
            with patch(platform_attr, RocmPlatform()):
                if use_mla:
                    # ROCm MLA backend logic:
                    # - TRITON_MLA: supported when block_size != 1
                    # - ROCM_AITER_MLA: supported when block_size == 1
                    # If backend is forced but doesn't match block_size,
                    # should raise ValueError
                    mismatched = (
                        (name == "TRITON_MLA" and block_size == 1)
                        or (name == "ROCM_AITER_MLA" and block_size != 1))
                    if mismatched:
                        assert_rejected()
                    else:
                        # Valid backend-block_size combination
                        assert select().get_name() == name
                else:
                    # Non-MLA selection on ROCm is expected to resolve to
                    # TRITON_ATTN rather than to the forced name.
                    assert select().get_name() == "TRITON_ATTN"

        elif device == "cuda":
            with patch(platform_attr, CudaPlatform()):
                if use_mla:
                    # CUDA MLA backend logic:
                    # - CUTLASS_MLA: only supported with block_size == 128
                    #   and Blackwell GPUs (SM 10.0), V1 only
                    # - FLASHINFER_MLA: only supported on Blackwell GPUs
                    #   (SM 10.0+), V1 only
                    # - FLASHMLA: only supported with block_size == 64
                    # - FLASH_ATTN_MLA: V1 only
                    # - TRITON_MLA: fallback for other cases
                    #
                    # pytest.skip() raises, so the checks below act as guard
                    # clauses: reaching the assertion means the combination
                    # is one the test expects to succeed.
                    if name == "CUTLASS_MLA" and block_size != 128:
                        pytest.skip(
                            "CUTLASS_MLA only supports block_size 128")
                    if name == "FLASHINFER_MLA" and block_size not in (32,
                                                                       64):
                        pytest.skip(
                            "FlashInfer MLA only supports block_size 32 "
                            "or 64")
                    if name == "FLASHMLA":
                        if block_size != 64:
                            pytest.skip("FlashMLA only supports block_size 64")
                        from vllm.v1.attention.backends.mla.flashmla import (  # noqa: E501
                            is_flashmla_supported)
                        is_supported, _ = is_flashmla_supported()
                        if not is_supported:
                            pytest.skip(
                                "FlashMLA not supported on this platform")

                    # These names must be honoured as requested; anything
                    # else is expected to fall back to TRITON_MLA.
                    honoured = ("CUTLASS_MLA", "FLASHINFER_MLA", "FLASHMLA",
                                "FLASH_ATTN_MLA")
                    expected = name if name in honoured else "TRITON_MLA"
                    assert select().get_name() == expected
                else:
                    # Head size each non-MLA backend is requested with.
                    head_sizes = {
                        "FLASHINFER": 16,
                        "XFORMERS": 32,
                        "FLASH_ATTN": 32,
                    }
                    if name not in head_sizes:
                        # Fail loudly instead of passing without asserting.
                        pytest.fail(
                            f"No expectation defined for CUDA backend {name}")
                    assert select(head_sizes[name]).get_name() == name

        else:
            # Fail loudly instead of passing without asserting.
            pytest.fail(f"No expectation defined for device {device}")
239

240

241
242
243
244
245
246
247
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(
    device: str,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with fp32.

    No backend is forced through the environment; the selector is asked for
    a float32 backend on the patched platform and must return the expected
    name for that device.
    """
    # Platform to patch in, and the backend name expected for fp32 on it.
    platform_cls, expected = {
        "cpu": (CpuPlatform, "TORCH_SDPA"),
        "cuda": (CudaPlatform, "FLEX_ATTENTION"),
    }[device]

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        with patch("vllm.attention.selector.current_platform",
                   platform_cls()):
            backend = get_attn_backend(16, torch.float32, None, 16)
        assert backend.get_name() == expected
261
262


263
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    Currently skipped unconditionally. The body below documents the intended
    checks: with FlashAttention forced through ``STR_BACKEND_ENV_VAR``, each
    unsupported configuration should make the selector pick something else.
    """
    # TODO: When testing for v1, pipe in `use_v1` as an argument to
    # get_attn_backend

    pytest.skip("Skipping as current backend selector does not "
                "handle fallbacks when a backend is set via env var.")

    import sys

    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)

        # Unsupported CUDA arch. The patch lives in its own context so it is
        # reverted before the remaining checks run.
        with monkeypatch.context() as arch:
            arch.setattr(torch.cuda,
                         "get_device_capability",
                         lambda _=None: (7, 5))
            backend = get_attn_backend(16, torch.float16, None, 16)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported data type
        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported kv cache data type
        backend = get_attn_backend(16, torch.float16, "fp8", 16)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported block size
        backend = get_attn_backend(16, torch.float16, None, 8)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # flash-attn is not installed. A None entry in sys.modules makes the
        # import fail; leaving the context restores the previous entry, or
        # removes the key if there was none.
        with monkeypatch.context() as mods:
            mods.setitem(sys.modules, 'vllm_flash_attn', None)
            backend = get_attn_backend(16, torch.float16, None, 16)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None, 16)
        assert backend.get_name() != STR_FLASH_ATTN_VAL
313
314


315
def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
    """Test that invalid attention backend names raise ValueError.

    ``STR_BACKEND_ENV_VAR`` is set to ``STR_INVALID_VAL`` with the selector's
    platform patched to CUDA; the selector must raise and the message must
    mention the invalid value.
    """
    with monkeypatch.context() as m, patch(
            "vllm.attention.selector.current_platform", CudaPlatform()):
        m.setenv("VLLM_USE_V1", "1")
        m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)

        # Should raise ValueError for invalid backend
        with pytest.raises(ValueError) as exc_info:
            get_attn_backend(32, torch.float16, None, 16)
        assert "Invalid value 'INVALID'" in str(exc_info.value)