# test_attention_selector.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from unittest.mock import patch
5
6
7
8

import pytest
import torch

9
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
10
from vllm.platforms import current_platform
11
12
13
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
14
15


16
17
@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the attention-backend selector's LRU cache before every test.

    ``_cached_get_attn_backend`` keeps earlier results, so without this a
    backend resolved in one test could be reused by the next one.
    """
    reset_selector_cache = _cached_get_attn_backend.cache_clear
    reset_selector_cache()


22
23
# Backends exercised by test_env, keyed by device type. MLA and non-MLA
# backends are listed separately because they are requested with different
# ``use_mla`` values and head sizes.
DEVICE_MLA_BACKENDS = dict(
    cuda=[
        "TRITON_MLA",
        "FLASHMLA",
        "FLASHINFER_MLA",
        "FLASH_ATTN_MLA",
        "CUTLASS_MLA",
    ],
    hip=["TRITON_MLA", "ROCM_AITER_MLA"],
    cpu=[],
)

DEVICE_REGULAR_ATTN_BACKENDS = dict(
    cuda=["FLASHINFER", "FLASH_ATTN"],
    hip=["ROCM_ATTN"],
    cpu=["CPU_ATTN"],
)

# Block sizes tried for every MLA backend of a device (non-MLA cases always
# use 16).
# NOTE(review): test_env skips CUTLASS_MLA unless block_size == 128, which is
# not listed for "cuda", so that backend is currently never exercised.
DEVICE_MLA_BLOCK_SIZES = dict(
    cuda=[16, 64],  # both the standard and an extended block size
    hip=[16, 1],  # block_size=1 exercises the ROCm special case
    # FIXME(woosuk): CPU MLA cases are temporarily disabled (was: [16]).
    cpu=[],
)


def generate_params():
    """Build the ``pytest.param`` cases consumed by ``test_env``.

    MLA cases come first, then regular-attention cases. Within each group the
    GPU device ("hip" on ROCm, "cuda" otherwise) precedes "cpu". MLA backends
    are crossed with every block size listed for their device; regular
    backends always use block size 16. Case ids look like
    ``cuda_FLASHMLA_mla_T_blks64``.
    """
    gpu_device = "hip" if current_platform.is_rocm() else "cuda"
    cases = []
    for mla_flag in (True, False):
        backend_table = (
            DEVICE_MLA_BACKENDS if mla_flag else DEVICE_REGULAR_ATTN_BACKENDS
        )
        for dev in (gpu_device, "cpu"):
            for backend_name in backend_table[dev]:
                sizes = DEVICE_MLA_BLOCK_SIZES[dev] if mla_flag else [16]
                for size in sizes:
                    mla_tag = str(mla_flag)[0]
                    case_id = f"{dev}_{backend_name}_mla_{mla_tag}_blks{size}"
                    cases.append(
                        pytest.param(dev, backend_name, mla_flag, size, id=case_id)
                    )
    return cases


75
@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_env(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with valid device-backend pairs.

    ``name`` is forced through ``VLLM_ATTENTION_BACKEND``. The per-device
    expectations are checked by the ``_check_*`` helpers that follow this
    test.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", name)
        # MLA must stay enabled exactly when an MLA backend is under test.
        m.setenv("VLLM_MLA_DISABLE", "0" if use_mla else "1")

        if device == "cpu":
            _check_cpu_backend(block_size)
        elif device == "hip":
            _check_rocm_backend(name, use_mla, block_size)
        elif device == "cuda":
            _check_cuda_backend(name, use_mla, block_size)
        else:
            # generate_params only yields cuda/hip/cpu; never pass silently.
            pytest.fail(f"Unexpected device in test params: {device!r}")


def _check_cpu_backend(block_size: int) -> None:
    """Expect the CPU platform to resolve to CPU_ATTN."""
    with patch("vllm.platforms.current_platform", CpuPlatform()):
        backend = get_attn_backend(16, torch.float16, None, block_size)
    assert backend.get_name() == "CPU_ATTN"


def _check_rocm_backend(name: str, use_mla: bool, block_size: int) -> None:
    """Check backend selection on the ROCm platform.

    Non-MLA requests are expected to resolve to ROCM_ATTN. For MLA, forcing
    TRITON_MLA with block_size == 1 must raise ValueError. Every other forced
    MLA backend/block-size pair from the params is expected to be selected
    as requested.
    """
    with patch("vllm.platforms.current_platform", RocmPlatform()):
        if not use_mla:
            backend = get_attn_backend(
                16, torch.float16, None, block_size, use_mla=use_mla
            )
            assert backend.get_name() == "ROCM_ATTN"
            return

        if name == "TRITON_MLA" and block_size == 1:
            # TRITON_MLA does not support block_size == 1.
            with pytest.raises(ValueError) as exc_info:
                get_attn_backend(16, torch.float16, None, block_size, use_mla=use_mla)
            assert f"The selected backend, {name}" in str(exc_info.value)
            return

        backend = get_attn_backend(
            16, torch.float16, None, block_size, use_mla=use_mla
        )
        assert backend.get_name() == name


def _check_cuda_backend(name: str, use_mla: bool, block_size: int) -> None:
    """Check backend selection on the CUDA platform (MLA and non-MLA)."""
    with patch("vllm.platforms.current_platform", CudaPlatform()):
        if use_mla:
            _check_cuda_mla_backend(name, block_size)
        elif name == "FLASHINFER":
            backend = get_attn_backend(
                64, torch.float16, None, block_size, use_mla=use_mla
            )
            assert backend.get_name() == "FLASHINFER"
        elif name == "FLASH_ATTN":
            backend = get_attn_backend(
                32, torch.float16, None, block_size, use_mla=use_mla
            )
            assert backend.get_name() == "FLASH_ATTN"


def _check_cuda_mla_backend(name: str, block_size: int) -> None:
    """Check a forced CUDA MLA backend, skipping unsupported configurations.

    Support rules encoded by the skips below:
    - CUTLASS_MLA: block_size == 128 and compute capability major 10.
    - FLASHINFER_MLA: compute capability major 10 and block_size 32 or 64.
    - FLASHMLA: block_size == 64 and ``is_flashmla_dense_supported()``.
    - FLASH_ATTN_MLA: ``flash_attn_supports_mla()``.
    - Anything else (TRITON_MLA in the params) is expected to give TRITON_MLA.
    """
    capability = torch.cuda.get_device_capability()

    if name == "CUTLASS_MLA":
        if block_size != 128:
            pytest.skip("CUTLASS_MLA only supports block_size 128")
        if capability[0] != 10:
            pytest.skip("CUTLASS MLA is not supported on this platform")
    elif name == "FLASHINFER_MLA":
        if capability[0] != 10:
            pytest.skip("FlashInfer MLA is not supported on this platform")
        if block_size not in [32, 64]:
            pytest.skip("FlashInfer MLA only supports block_size 32 or 64")
    elif name == "FLASHMLA":
        if block_size != 64:
            pytest.skip("FlashMLA only supports block_size 64")
        from vllm.v1.attention.backends.mla.flashmla import (
            is_flashmla_dense_supported,
        )

        is_supported, _ = is_flashmla_dense_supported()
        if not is_supported:
            pytest.skip("FlashMLA not supported on this platform")
    elif name == "FLASH_ATTN_MLA":
        from vllm.attention.utils.fa_utils import flash_attn_supports_mla

        if not flash_attn_supports_mla():
            pytest.skip("FlashAttention MLA not supported on this platform")

    # These backends are expected under their own name once the skips above
    # pass. Any other requested MLA backend falls back to TRITON_MLA.
    expected = (
        name
        if name in ("CUTLASS_MLA", "FLASHINFER_MLA", "FLASHMLA", "FLASH_ATTN_MLA")
        else "TRITON_MLA"
    )
    # MLA backends are queried with head size 576.
    backend = get_attn_backend(576, torch.float16, None, block_size, use_mla=True)
    assert backend.get_name() == expected
215

216

217
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(device: str):
    """Check which backend is selected for float32 inputs.

    The CPU platform is expected to pick CPU_ATTN. The CUDA platform is
    expected to fall back to FLEX_ATTENTION.
    """
    expectations = {
        "cpu": (CpuPlatform, "CPU_ATTN"),
        "cuda": (CudaPlatform, "FLEX_ATTENTION"),
    }
    if device not in expectations:
        return

    platform_cls, expected_name = expectations[device]
    with patch("vllm.platforms.current_platform", platform_cls()):
        backend = get_attn_backend(16, torch.float32, None, 16)
    assert backend.get_name() == expected_name
229
230


231
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    Each check forces FLASH_ATTN through ``VLLM_ATTENTION_BACKEND`` with one
    unsupported condition and asserts that a different backend is chosen.
    The test is currently skipped unconditionally: per the skip reason, the
    selector does not fall back when a backend is set via the env var.
    """
    pytest.skip(
        "Skipping as current backend selector does not "
        "handle fallbacks when a backend is set via env var."
    )

    import sys

    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")

        # Unsupported CUDA arch. The override is scoped so it is reverted
        # before the remaining checks run.
        with monkeypatch.context() as arch_patch:
            arch_patch.setattr(
                torch.cuda, "get_device_capability", lambda _=None: (7, 5)
            )
            backend = get_attn_backend(16, torch.float16, None, 16)
            assert backend.get_name() != "FLASH_ATTN"

        # Unsupported data type
        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported kv cache data type
        backend = get_attn_backend(16, torch.float16, "fp8", 16)
        assert backend.get_name() != "FLASH_ATTN"

        # Unsupported block size
        backend = get_attn_backend(16, torch.float16, None, 8)
        assert backend.get_name() != "FLASH_ATTN"

        # flash-attn is not installed: a None entry in sys.modules makes
        # importing that name fail. The context restores (or removes) the
        # entry on exit.
        # NOTE(review): confirm "vllm_flash_attn" is the module name the
        # selector actually imports.
        with monkeypatch.context() as module_patch:
            module_patch.setitem(sys.modules, "vllm_flash_attn", None)
            backend = get_attn_backend(16, torch.float16, None, 16)
            assert backend.get_name() != "FLASH_ATTN"

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None, 16)
        assert backend.get_name() != "FLASH_ATTN"
278
279


280
def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
    """An unknown VLLM_ATTENTION_BACKEND value must be rejected with ValueError.

    Runs against a patched CUDA platform and checks that the error message
    names the offending value.
    """
    with monkeypatch.context() as m:
        with patch("vllm.platforms.current_platform", CudaPlatform()):
            m.setenv("VLLM_ATTENTION_BACKEND", "INVALID")

            with pytest.raises(ValueError) as excinfo:
                get_attn_backend(32, torch.float16, None, 16)
            error_text = str(excinfo.value)
            assert "Invalid value 'INVALID'" in error_text