test_attention_selector.py 12.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from unittest.mock import patch
5
6
7
8

import pytest
import torch

9
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
10
11
12
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
13
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
14
15


16
17
@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the cached backend lookup before every test in this module.

    Runs automatically (``autouse``) so that a backend cached by one test
    case cannot leak into the next.
    """
    _cached_get_attn_backend.cache_clear()


22
23
# MLA and non-MLA backends are defined separately. Each table is keyed by the
# device name that generate_params() iterates over ("cuda", "hip", "cpu").
DEVICE_MLA_BACKENDS = {
    "cuda": [
        "TRITON_MLA",
        "FLASHMLA",
        "FLASHINFER_MLA",
        "FLASH_ATTN_MLA",
        "CUTLASS_MLA",
    ],
    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
    "cpu": [],  # no MLA cases are generated for CPU
}

DEVICE_REGULAR_ATTN_BACKENDS = {
    "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
    "hip": ["ROCM_FLASH"],
    "cpu": ["TORCH_SDPA"],
}

# Block sizes exercised for the MLA cases; non-MLA cases always use 16.
# NOTE(review): 128 is not listed for "cuda", so the CUTLASS_MLA cases (which
# test_env skips unless block_size == 128) are always skipped - confirm that
# this is intended.
DEVICE_MLA_BLOCK_SIZES = {
    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
    "hip": [16, 1],  # HIP requires special handling for block_size=1
    # "cpu": [16]  # CPU uses fixed block size from test cases
    "cpu": [],  # FIXME(woosuk): Temporarily disable CPU tests
}


def generate_params():
    """Build the ``pytest.param`` list used to parametrize ``test_env``.

    MLA cases come first, then regular-attention cases; within each group the
    devices are visited in the order cuda, hip, cpu. MLA cases are crossed
    with the per-device block sizes in ``DEVICE_MLA_BLOCK_SIZES``; regular
    cases always use a block size of 16.

    Returns:
        A list of ``pytest.param(device, name, use_mla, block_size)`` objects,
        each with a readable ``id``.
    """
    cases = []
    for mla_enabled in (True, False):
        backend_table = (
            DEVICE_MLA_BACKENDS if mla_enabled else DEVICE_REGULAR_ATTN_BACKENDS
        )
        # First letter of the flag ("T" / "F") keeps the test ids short.
        mla_tag = str(mla_enabled)[0]
        for platform in ("cuda", "hip", "cpu"):
            sizes = DEVICE_MLA_BLOCK_SIZES[platform] if mla_enabled else [16]
            cases.extend(
                pytest.param(
                    platform,
                    backend_name,
                    mla_enabled,
                    size,
                    id=f"{platform}_{backend_name}_mla_{mla_tag}_blks{size}",
                )
                for backend_name in backend_table[platform]
                for size in sizes
            )
    return cases


73
def _expect_backend(
    expected: str, block_size: int, use_mla: bool, head_size: int = 16
):
    """Select a float16 backend and assert that it reports the name *expected*."""
    backend = get_attn_backend(
        head_size, torch.float16, None, block_size, use_mla=use_mla
    )
    assert backend.get_name() == expected


def _expect_rejection(name: str, block_size: int, use_mla: bool):
    """Assert that selection raises ValueError naming the forced backend."""
    with pytest.raises(ValueError) as exc_info:
        get_attn_backend(16, torch.float16, None, block_size, use_mla=use_mla)
    assert f"The selected backend, {name}" in str(exc_info.value)


def _check_rocm_selection(name: str, use_mla: bool, block_size: int):
    """Check backend selection for one ROCm ("hip") case."""
    if not use_mla:
        # Regular ROCm cases are expected to resolve to TRITON_ATTN whatever
        # backend name was requested through the env var.
        _expect_backend("TRITON_ATTN", block_size, use_mla)
        return

    # ROCm MLA backend logic:
    # - TRITON_MLA: supported when block_size != 1
    # - ROCM_AITER_MLA: supported when block_size == 1
    # If backend is forced but doesn't match block_size,
    # should raise ValueError
    mismatched = (name == "TRITON_MLA" and block_size == 1) or (
        name == "ROCM_AITER_MLA" and block_size != 1
    )
    if mismatched:
        _expect_rejection(name, block_size, use_mla)
    else:
        # Valid backend-block_size combination
        _expect_backend(name, block_size, use_mla)


def _check_cuda_selection(name: str, use_mla: bool, block_size: int):
    """Check backend selection for one CUDA case, skipping unsupported setups."""
    if not use_mla:
        # Head size used for each regular backend; names outside this table
        # are not checked.
        head_sizes = {"FLASHINFER": 16, "XFORMERS": 32, "FLASH_ATTN": 32}
        if name in head_sizes:
            _expect_backend(name, block_size, use_mla, head_size=head_sizes[name])
        return

    # CUDA MLA backend logic:
    # - CUTLASS_MLA: only supported with block_size == 128
    #   and Blackwell GPUs (SM 10.0), V1 only
    # - FLASHINFER_MLA: only supported on Blackwell GPUs
    #   (SM 10.0+), V1 only
    # - FLASHMLA: only supported with block_size == 64
    # - FLASH_ATTN_MLA: V1 only
    # - TRITON_MLA: fallback for other cases
    if name == "CUTLASS_MLA":
        if block_size != 128:
            pytest.skip("CUTLASS_MLA only supports block_size 128")
        _expect_backend("CUTLASS_MLA", block_size, use_mla)
    elif name == "FLASHINFER_MLA":
        if block_size not in [32, 64]:
            pytest.skip("FlashInfer MLA only supports block_size 32 or 64")
        _expect_backend("FLASHINFER_MLA", block_size, use_mla)
    elif name == "FLASHMLA":
        if block_size != 64:
            pytest.skip("FlashMLA only supports block_size 64")
        # Imported lazily, and only for the FLASHMLA cases that need it.
        from vllm.v1.attention.backends.mla.flashmla import is_flashmla_supported

        is_supported, _ = is_flashmla_supported()
        if not is_supported:
            pytest.skip("FlashMLA not supported on this platform")
        _expect_backend(name, block_size, use_mla)
    elif name == "FLASH_ATTN_MLA":
        _expect_backend("FLASH_ATTN_MLA", block_size, use_mla)
    else:
        # TRITON_MLA or other fallback
        _expect_backend("TRITON_MLA", block_size, use_mla)


@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_env(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with valid device-backend pairs.

    Args:
        device: Platform to emulate ("cpu", "hip" or "cuda"); the selector's
            ``current_platform`` is patched with the matching platform object.
        name: Backend name forced through ``STR_BACKEND_ENV_VAR``.
        use_mla: Whether the MLA code path is requested from the selector.
        block_size: KV-cache block size passed to ``get_attn_backend``.
        monkeypatch: pytest fixture used to scope the environment changes.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv(STR_BACKEND_ENV_VAR, name)
        # NOTE(review): this sets VLLM_MLA_DISABLE=1 exactly when use_mla is
        # True, which looks inverted. It is kept as-is because use_mla is also
        # passed explicitly to the selector - confirm whether the flag is
        # still read on this path.
        m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")

        if device == "cpu":
            with patch("vllm.attention.selector.current_platform", CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, None, block_size)
            assert backend.get_name() == "TORCH_SDPA"

        elif device == "hip":
            with patch("vllm.attention.selector.current_platform", RocmPlatform()):
                _check_rocm_selection(name, use_mla, block_size)

        elif device == "cuda":
            with patch("vllm.attention.selector.current_platform", CudaPlatform()):
                _check_cuda_selection(name, use_mla, block_size)
212

213

214
215
216
217
218
219
220
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(
    device: str,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check which backend is selected for float32 on CPU and CUDA.

    With no backend forced, the CPU platform is expected to yield TORCH_SDPA
    and the CUDA platform FLEX_ATTENTION.
    """
    # device -> (platform class to patch in, expected backend name)
    expectations = {
        "cpu": (CpuPlatform, "TORCH_SDPA"),
        "cuda": (CudaPlatform, "FLEX_ATTENTION"),
    }

    with monkeypatch.context() as env:
        env.setenv("VLLM_USE_V1", "1")

        case = expectations.get(device)
        if case is None:
            return
        platform_cls, expected_name = case

        with patch("vllm.attention.selector.current_platform", platform_cls()):
            selected = get_attn_backend(16, torch.float32, None, 16)
        assert selected.get_name() == expected_name
232
233


234
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    Currently skipped unconditionally (see the skip message below). The
    scenarios after the skip are kept for when fallback handling exists: each
    one forces FLASH_ATTN through the env var, sets up one unsupported
    condition and asserts that a different backend is selected.
    """
    import sys

    # TODO: When testing for v1, pipe in `use_v1` as an argument to
    # get_attn_backend

    pytest.skip(
        "Skipping as current backend selector does not "
        "handle fallbacks when a backend is set via env var."
    )

    def selected_name(
        head_size=16, dtype=torch.float16, kv_cache_dtype=None, block_size=16
    ):
        # Presumably get_attn_backend is served from _cached_get_attn_backend;
        # clear it so a scenario that repeats earlier arguments is evaluated
        # afresh instead of returning the earlier result.
        _cached_get_attn_backend.cache_clear()
        backend = get_attn_backend(head_size, dtype, kv_cache_dtype, block_size)
        return backend.get_name()

    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)

        # Unsupported CUDA arch. The nested context undoes the patch on exit,
        # so the following scenarios see the real device capability again.
        with monkeypatch.context() as arch:
            arch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
            assert selected_name() != STR_FLASH_ATTN_VAL

        # Unsupported data type
        assert selected_name(dtype=torch.float8_e4m3fn) != STR_FLASH_ATTN_VAL

        # Unsupported kv cache data type
        assert selected_name(kv_cache_dtype="fp8") != STR_FLASH_ATTN_VAL

        # Unsupported block size
        assert selected_name(block_size=8) != STR_FLASH_ATTN_VAL

        # flash-attn is not installed: a None entry in sys.modules makes
        # `import vllm_flash_attn` raise ImportError. The nested context
        # restores (or removes) the original entry on exit.
        with monkeypatch.context() as missing:
            missing.setitem(sys.modules, "vllm_flash_attn", None)
            assert selected_name() != STR_FLASH_ATTN_VAL

        # Unsupported head size
        assert selected_name(head_size=17) != STR_FLASH_ATTN_VAL
284
285


286
def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
    """Check that an unrecognised backend name is rejected with ValueError.

    The backend env var is set to ``STR_INVALID_VAL`` on a patched CUDA
    platform, and the error message must mention the invalid value.
    """
    with monkeypatch.context() as env:
        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
            env.setenv("VLLM_USE_V1", "1")
            env.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)

            # Selection itself must fail; no backend object is returned.
            with pytest.raises(ValueError) as raised:
                get_attn_backend(32, torch.float16, None, 16)

            message = str(raised.value)
            assert "Invalid value 'INVALID'" in message