# test_attention_selector.py (12.1 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from unittest.mock import patch
5
6
7
8

import pytest
import torch

9
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
10
11
12
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
13
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
14
15


16
17
@pytest.fixture(autouse=True)
def clear_cache():
    """Clear lru cache to ensure each test case runs without caching.

    Declared with ``autouse=True``, so it runs before every test in this
    module without being requested explicitly.
    """
    _cached_get_attn_backend.cache_clear()


22
23
# Define MLA and non-MLA backends separately.
# Each table maps a device key ("cuda", "hip", "cpu") to the values that
# generate_params() expands into test_env parameter sets.

# Backend names requested for each device when use_mla is True.
DEVICE_MLA_BACKENDS = {
    "cuda": [
        "TRITON_MLA",
        "FLASHMLA",
        "FLASHINFER_MLA",
        "FLASH_ATTN_MLA",
        "CUTLASS_MLA",
    ],
    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
    "cpu": [],
}

# Backend names requested for each device when use_mla is False.
DEVICE_REGULAR_ATTN_BACKENDS = {
    "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
    "hip": ["ROCM_ATTN"],
    "cpu": ["TORCH_SDPA"],
}

# Block sizes swept for MLA runs only; non-MLA runs always use 16.
# NOTE(review): test_env skips CUTLASS_MLA unless block_size == 128, which is
# not listed for "cuda", so that backend is never actually exercised.
DEVICE_MLA_BLOCK_SIZES = {
    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
    "hip": [16, 1],  # HIP requires special handling for block_size=1
    # "cpu": [16]  # CPU uses fixed block size from test cases
    "cpu": [],  # FIXME(woosuk): Temporarily disable CPU tests
}


def generate_params():
    """Build the parameter sets used to parametrize ``test_env``.

    MLA combinations are generated first, then non-MLA ones; within each,
    devices are visited in the order cuda, hip, cpu.  MLA runs sweep the
    block sizes in ``DEVICE_MLA_BLOCK_SIZES``; non-MLA runs use block size 16
    only.  A device with no backends (or, for MLA, no block sizes) yields no
    parameters.

    Returns:
        A list of ``pytest.param(device, name, use_mla, block_size)`` entries,
        each with an id of the form
        ``{device}_{name}_mla_{T|F}_blks{block_size}``.
    """
    params = []
    for use_mla in [True, False]:
        # The backend table depends only on use_mla, so pick it once.
        backend_table = (
            DEVICE_MLA_BACKENDS if use_mla else DEVICE_REGULAR_ATTN_BACKENDS
        )
        for device in ["cuda", "hip", "cpu"]:
            # Block sizes do not depend on the backend name.
            block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16]
            for name in backend_table[device]:
                params.extend(
                    pytest.param(
                        device,
                        name,
                        use_mla,
                        block_size,
                        id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}",
                    )
                    for block_size in block_sizes
                )
    return params


73
@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_env(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with valid device-backend pairs.

    The backend ``name`` is requested through the ``STR_BACKEND_ENV_VAR``
    environment variable while ``vllm.platforms.current_platform`` is patched
    to the platform matching ``device``.  The name of the backend returned by
    ``get_attn_backend`` is then checked.  CUDA MLA combinations that the
    test knows to be unsupported are skipped; on ROCm, forcing ``TRITON_MLA``
    with ``block_size == 1`` is expected to raise ``ValueError``.

    Args:
        device: One of "cpu", "hip" or "cuda".
        name: Backend name to force via the environment variable.
        use_mla: Whether the MLA code path is requested.
        block_size: Block size passed to ``get_attn_backend``.
        monkeypatch: pytest fixture used to scope the environment changes.
    """
    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, name)
        # NOTE(review): this sets VLLM_MLA_DISABLE to "1" when use_mla is
        # True, which reads as inverted -- confirm whether the selector
        # consults this variable before changing it.
        m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")

        if device == "cpu":
            with patch("vllm.platforms.current_platform", CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, None, block_size)
            assert backend.get_name() == "TORCH_SDPA"

        elif device == "hip":
            with patch("vllm.platforms.current_platform", RocmPlatform()):
                if use_mla:
                    # ROCm MLA backend logic:
                    # - TRITON_MLA: supported when block_size != 1
                    # - ROCM_AITER_MLA: supported when block_size == 1
                    # If backend is forced but doesn't match block_size,
                    # should raise ValueError

                    if name == "TRITON_MLA" and block_size == 1:
                        # TRITON_MLA doesn't support block_size == 1
                        with pytest.raises(ValueError) as exc_info:
                            get_attn_backend(
                                16, torch.float16, None, block_size, use_mla=use_mla
                            )
                        assert f"The selected backend, {name}" in str(exc_info.value)
                    else:
                        # Valid backend-block_size combination
                        backend = get_attn_backend(
                            16, torch.float16, None, block_size, use_mla=use_mla
                        )
                        expected = name
                        assert backend.get_name() == expected
                else:
                    backend = get_attn_backend(
                        16, torch.float16, None, block_size, use_mla=use_mla
                    )
                    expected = "ROCM_ATTN"
                    assert backend.get_name() == expected

        elif device == "cuda":
            with patch("vllm.platforms.current_platform", CudaPlatform()):
                # Queried from the real device; used below to skip backends
                # that the test restricts to compute capability 10.x.
                capability = torch.cuda.get_device_capability()
                if use_mla:
                    # CUDA MLA backend logic:
                    # - CUTLASS_MLA: only supported with block_size == 128
                    #   and Blackwell GPUs (SM 10.x), V1 only
                    # - FLASHINFER_MLA: only supported on Blackwell GPUs
                    #   (SM 10.x), V1 only
                    # - FLASHMLA: only supported with block_size == 64
                    # - FLASH_ATTN_MLA: V1 only
                    # - TRITON_MLA: fallback for other cases

                    if name == "CUTLASS_MLA":
                        if block_size != 128:
                            # CUTLASS_MLA only supports block_size == 128
                            pytest.skip("CUTLASS_MLA only supports block_size 128")
                        if capability[0] != 10:
                            pytest.skip("CUTLASS MLA is not supported on this platform")
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        expected = "CUTLASS_MLA"
                        assert backend.get_name() == expected
                    elif name == "FLASHINFER_MLA":
                        if capability[0] != 10:
                            pytest.skip(
                                "FlashInfer MLA is not supported on this platform"
                            )
                        if block_size not in [32, 64]:
                            # FlashInfer MLA only supports block_size 32 or 64
                            pytest.skip(
                                "FlashInfer MLA only supports block_size 32 or 64"
                            )
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        expected = "FLASHINFER_MLA"
                        assert backend.get_name() == expected
                    elif name == "FLASHMLA":
                        if block_size != 64:
                            # FlashMLA only supports block_size == 64
                            pytest.skip("FlashMLA only supports block_size 64")
                        # Imported lazily, only when this backend is tested.
                        from vllm.v1.attention.backends.mla.flashmla import (
                            is_flashmla_dense_supported,
                        )

                        is_supported, _ = is_flashmla_dense_supported()
                        if not is_supported:
                            pytest.skip("FlashMLA not supported on this platform")
                        backend = get_attn_backend(
                            576,
                            torch.float16,
                            None,
                            block_size,
                            use_mla=use_mla,
                        )
                        expected = name
                        assert backend.get_name() == expected
                    elif name == "FLASH_ATTN_MLA":
                        # Imported lazily, only when this backend is tested.
                        from vllm.attention.utils.fa_utils import (
                            flash_attn_supports_mla,
                        )

                        if not flash_attn_supports_mla():
                            pytest.skip(
                                "FlashAttention MLA not supported on this platform"
                            )
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        expected = "FLASH_ATTN_MLA"
                        assert backend.get_name() == expected
                    else:
                        # TRITON_MLA or other fallback
                        backend = get_attn_backend(
                            576, torch.float16, None, block_size, use_mla=use_mla
                        )
                        expected = "TRITON_MLA"
                        assert backend.get_name() == expected
                elif name == "FLASHINFER":
                    backend = get_attn_backend(
                        64, torch.float16, None, block_size, use_mla=use_mla
                    )
                    expected = "FLASHINFER"
                    assert backend.get_name() == expected
                elif name == "XFORMERS":
                    backend = get_attn_backend(
                        32, torch.float16, None, block_size, use_mla=use_mla
                    )
                    expected = "XFORMERS"
                    assert backend.get_name() == expected
                elif name == "FLASH_ATTN":
                    backend = get_attn_backend(
                        32, torch.float16, None, block_size, use_mla=use_mla
                    )
                    expected = "FLASH_ATTN"
                    assert backend.get_name() == expected
219

220

221
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(device: str):
    """Test attention backend selection with fp32.

    With ``vllm.platforms.current_platform`` patched to the platform for
    ``device`` and no backend forced by this test, a float32 request is
    expected to select ``TORCH_SDPA`` on CPU and ``FLEX_ATTENTION`` on CUDA.

    Args:
        device: Either "cpu" or "cuda".
    """
    if device == "cpu":
        with patch("vllm.platforms.current_platform", CpuPlatform()):
            backend = get_attn_backend(16, torch.float32, None, 16)
        assert backend.get_name() == "TORCH_SDPA"

    elif device == "cuda":
        with patch("vllm.platforms.current_platform", CudaPlatform()):
            backend = get_attn_backend(16, torch.float32, None, 16)
        assert backend.get_name() == "FLEX_ATTENTION"
233
234


235
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    Currently skipped unconditionally by the ``pytest.skip`` call below; the
    remaining body is kept for when the selector handles fallbacks again.
    Each case forces the FlashAttention backend through the environment
    variable and asserts that a different backend is selected when the
    request is unsupported.

    Args:
        monkeypatch: pytest fixture used to create scoped patch contexts.
    """
    pytest.skip(
        "Skipping as current backend selector does not "
        "handle fallbacks when a backend is set via env var."
    )

    # Needed only for the "flash-attn is not installed" case below.
    import sys

    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)

        # Unsupported CUDA arch
        # A nested context reverts only this patch on exit; calling undo()
        # on the shared fixture would revert every patch made through it.
        with monkeypatch.context() as arch_patch:
            arch_patch.setattr(
                torch.cuda, "get_device_capability", lambda _=None: (7, 5)
            )
            backend = get_attn_backend(16, torch.float16, None, 16)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported data type
        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported kv cache data type
        backend = get_attn_backend(16, torch.float16, "fp8", 16)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported block size
        backend = get_attn_backend(16, torch.float16, None, 8)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

        # flash-attn is not installed
        # A None entry in sys.modules makes importing that module fail.  The
        # nested context restores the original entry (or removes the key if
        # there was none) on exit.
        with monkeypatch.context() as missing_module:
            missing_module.setitem(sys.modules, "vllm_flash_attn", None)
            backend = get_attn_backend(16, torch.float16, None, 16)
            assert backend.get_name() != STR_FLASH_ATTN_VAL

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None, 16)
        assert backend.get_name() != STR_FLASH_ATTN_VAL
282
283


284
def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
    """Test that invalid attention backend names raise ValueError.

    With the CUDA platform patched in and ``STR_BACKEND_ENV_VAR`` set to
    ``STR_INVALID_VAL``, ``get_attn_backend`` must raise a ``ValueError``
    whose message contains ``Invalid value 'INVALID'``.

    Args:
        monkeypatch: pytest fixture used to scope the environment change.
    """
    with (
        monkeypatch.context() as m,
        patch("vllm.platforms.current_platform", CudaPlatform()),
    ):
        m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)

        # Should raise ValueError for invalid backend
        with pytest.raises(ValueError) as exc_info:
            get_attn_backend(32, torch.float16, None, 16)
        assert "Invalid value 'INVALID'" in str(exc_info.value)