test_mha_attn.py 11 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
"""
Test:

* Tests for MMEncoderAttention layer
"""
8

9
import itertools
10
11
from unittest.mock import patch

12
import numpy as np
13
14
15
import pytest
import torch

16
17
from vllm.config import get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
18
from vllm.model_executor.layers.attention import MMEncoderAttention
19
20
21
22
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
23
from vllm.utils.torch_utils import set_default_torch_dtype, set_random_seed
24
25
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend
26
27
28
29


@pytest.fixture(autouse=True)
def clear_cache():
    """Empty the attention-backend selector's LRU cache before every test.

    Declared ``autouse`` so each test case resolves its backend afresh
    instead of reusing a result cached by an earlier case.
    """
    _cached_get_attn_backend.cache_clear()


34
35
36
37
38
39
40
41
# Devices to parametrize the platform-selection test over: CPU is always
# included; "cuda" / "hip" are added only when the current platform reports
# CUDA / ROCm respectively.
devices = [
    "cpu",
    *(["cuda"] if current_platform.is_cuda() else []),
    *(["hip"] if current_platform.is_rocm() else []),
]


@pytest.mark.parametrize("device", devices)
def test_mha_attn_platform(default_vllm_config, device: str):
    """
    Test the attention selector between different platform and device.

    For the given ``device`` the module-level ``current_platform`` seen by
    ``vllm.model_executor.models.vision`` is patched with the matching
    platform object, an ``MMEncoderAttention`` layer is built, and the
    backend it selected is checked.

    Args:
        default_vllm_config: Fixture defined outside this file (presumably
            installs a default vLLM config for the test - confirm).
        device: One of the entries of ``devices`` ("cpu", "cuda" or "hip").
    """
    platform_attr = "vllm.model_executor.models.vision.current_platform"

    # The default dtype is process-wide state; remember the previous value and
    # restore it afterwards so the fp16 default does not leak into other tests.
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float16)
    try:
        if device == "cpu":
            with patch(platform_attr, CpuPlatform()):
                attn = MMEncoderAttention(16, 64, scale=1)
                assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
        elif device == "hip":
            with patch(platform_attr, RocmPlatform()):
                attn = MMEncoderAttention(16, 64, scale=1)
                assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
        else:
            # Test CUDA with head_size=64 (divisible by 32)
            # - should use vLLM's FlashAttention
            with patch(platform_attr, CudaPlatform()):
                attn = MMEncoderAttention(16, 64, scale=1)
                assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN

            # Test CUDA with head_size=72 (not divisible by 32)
            # - should use vLLM's FlashAttention
            with patch(platform_attr, CudaPlatform()):
                attn = MMEncoderAttention(16, 72, scale=1)
                assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN

            # Test CUDA with head_size=72 and a float32 default dtype
            # - should select the Triton attention backend
            # NOTE(review): presumably because FlashAttention does not
            # handle fp32 inputs - confirm against the backend selector.
            with (
                patch(platform_attr, CudaPlatform()),
                set_default_torch_dtype(torch.float32),
            ):
                attn = MMEncoderAttention(16, 72, scale=1)
                assert attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
    finally:
        torch.set_default_dtype(previous_dtype)

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106

def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Reference (pure PyTorch) scaled dot-product attention with no masking.

    Args:
        query, key, value: tensors laid out as
            [batch_size, seq_len, num_heads, head_size].
        scale: multiplier applied to the raw query-key scores before the
            softmax.

    Returns:
        The attention output in the same
        [batch_size, seq_len, num_heads, head_size] layout as the inputs.
    """
    # Move the head axis ahead of the sequence axis so that matmul contracts
    # over (seq_len, head_size) independently for every head.
    q_heads = query.transpose(1, 2)
    k_heads = key.transpose(1, 2)
    v_heads = value.transpose(1, 2)

    scores = scale * torch.matmul(q_heads, k_heads.transpose(2, 3))
    # Cast the probabilities to the value dtype before the second matmul.
    probs = torch.softmax(scores, dim=-1).to(v_heads.dtype)

    # Swap back to the caller's [batch, seq, heads, head_size] layout.
    return torch.matmul(probs, v_heads).transpose(1, 2)


# Parameter grids consumed by the forward-pass tests below.
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
# Each entry lists the per-sequence lengths packed into one varlen batch.
VAR_SEQ_LENS = [
    [2, 2],
    [2, 3, 4],
]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
# float32 is only added to the grid when not running on ROCm.
DTYPES = [torch.half, torch.bfloat16]
if not current_platform.is_rocm():
    DTYPES.append(torch.float)
CUDA_DEVICES = ["cuda"]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
    default_vllm_config,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    """Compare MMEncoderAttention's forward output with ``ref_attention``.

    Random q/k/v of shape ``[batch_size, seq_len, heads * head_size]`` are fed
    through the layer; the reference is computed on the same tensors reshaped
    to ``[batch_size, seq_len, heads, head_size]``, with K/V heads repeated
    when there are fewer KV heads than query heads.

    Args:
        default_vllm_config: Fixture defined outside this file (presumably
            installs a default vLLM config for the test - confirm).
        batch_size, seq_len, num_heads, num_kv_heads, head_size: Problem
            dimensions taken from the module-level parameter grids.
        dtype: Default torch dtype used for every tensor created in the test.
        device: Default torch device used for every tensor created in the test.
    """
    set_random_seed(0)

    # Default device/dtype are process-wide; save them so they can be put back
    # and do not leak into tests that run afterwards.
    previous_device = torch.get_default_device()
    previous_dtype = torch.get_default_dtype()
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    try:
        q = torch.randn(batch_size, seq_len, num_heads * head_size)
        k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
        v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
        scale = 1.0 / head_size**0.5
        attn = MMEncoderAttention(
            num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
        )
        output = attn(q, k, v)

        assert num_heads % num_kv_heads == 0
        num_queries_per_kv = num_heads // num_kv_heads
        q = q.reshape(batch_size, seq_len, num_heads, head_size)
        k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
        v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
        if num_queries_per_kv > 1:
            # Expand K/V so every query head has a matching KV head.
            k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
            v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

        ref_output = ref_attention(
            q,
            k,
            v,
            scale=scale,
        ).reshape(batch_size, seq_len, num_heads * head_size)

        # Explicit tolerances are used only for the Triton backend; every
        # other backend is compared with assert_close's dtype defaults.
        tol_kwargs = (
            {"rtol": 1e-3, "atol": 1e-3}
            if attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
            else {}
        )
        torch.testing.assert_close(output, ref_output, **tol_kwargs)
    finally:
        torch.set_default_dtype(previous_dtype)
        torch.set_default_device(previous_device)
174
175
176
177
178
179
180
181
182


@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_varlen_forward(
    default_vllm_config,
    var_seq_len: list[int],
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    """Compare MMEncoderAttention's varlen forward output with ``ref_attention``.

    All sequences in ``var_seq_len`` are packed along dim 1 of a batch-size-1
    tensor and passed to the layer together with cumulative sequence offsets.
    The reference runs ``ref_attention`` on each sequence slice separately and
    concatenates the results.

    Args:
        default_vllm_config: Fixture defined outside this file (presumably
            installs a default vLLM config for the test - confirm).
        var_seq_len: Lengths of the individual sequences in the packed batch.
        num_heads, num_kv_heads, head_size: Attention dimensions.
        dtype: Default torch dtype used for every tensor created in the test.
        device: Default torch device used for every tensor created in the test.
    """
    set_random_seed(0)

    # Default device/dtype are process-wide; save them so they can be put back
    # and do not leak into tests that run afterwards.
    previous_device = torch.get_default_device()
    previous_dtype = torch.get_default_dtype()
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    try:
        total_len = sum(var_seq_len)
        q = torch.randn(1, total_len, num_heads, head_size)
        k = torch.randn(1, total_len, num_kv_heads, head_size)
        v = torch.randn(1, total_len, num_kv_heads, head_size)
        # Offsets [0, l0, l0 + l1, ...] marking the sequence boundaries.
        cu_seqlens = torch.tensor(
            list(itertools.accumulate(var_seq_len, initial=0)), dtype=torch.int32
        )
        scale = 1.0 / head_size**0.5
        attn = MMEncoderAttention(
            num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
        )
        output = attn(
            q, k, v, cu_seqlens=cu_seqlens, max_seqlen=torch.tensor(max(var_seq_len))
        )

        assert num_heads % num_kv_heads == 0
        num_queries_per_kv = num_heads // num_kv_heads
        if num_queries_per_kv > 1:
            # Expand K/V so every query head has a matching KV head.
            k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
            v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

        # Attention must not cross sequence boundaries, so the reference is
        # evaluated one sequence at a time.
        per_seq_ref = [
            ref_attention(q_i, k_i, v_i, scale=scale)
            for q_i, k_i, v_i in zip(
                torch.split(q, var_seq_len, dim=1),
                torch.split(k, var_seq_len, dim=1),
                torch.split(v, var_seq_len, dim=1),
            )
        ]
        ref_output = torch.cat(per_seq_ref, dim=1)
        torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
    finally:
        torch.set_default_dtype(previous_dtype)
        torch.set_default_device(previous_device)
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336


@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
@pytest.mark.parametrize(
    "dtype",
    [torch.bfloat16, torch.half],
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_varlen_forward_flashinfer(
    default_vllm_config,
    var_seq_len: list[int],
    dtype: torch.dtype,
    device: str,
):
    """Test MMEncoderAttention varlen forward with FLASHINFER backend (head_size=72).

    Exercises the path that uses --mm-encoder-attn-backend=FLASHINFER with
    recomputed cu_seqlens, max_seqlen, and sequence_lengths as in qwen3_vl
    vision encoder.

    The test is skipped when the ``flashinfer`` package cannot be imported.

    Args:
        default_vllm_config: Fixture defined outside this file (presumably
            installs the vLLM config returned by ``get_current_vllm_config``
            - confirm).
        var_seq_len: Lengths of the individual sequences in the packed batch.
        dtype: Default torch dtype used for every tensor created in the test.
        device: Default torch device used for every tensor created in the test.
    """
    pytest.importorskip("flashinfer")

    num_heads = 16
    head_size = 72
    set_random_seed(0)

    # Default device/dtype are process-wide; save them so they can be put back
    # and do not leak into tests that run afterwards.
    previous_device = torch.get_default_device()
    previous_dtype = torch.get_default_dtype()
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    try:
        # Override vllm config so get_vit_attn_backend returns FLASHINFER
        # (simulates --mm-encoder-attn-backend=FLASHINFER).
        vllm_config = get_current_vllm_config()
        old_model_config = getattr(vllm_config, "model_config", None)
        # Stand-in object exposing only the `multimodal_config` attribute.
        minimal_model_config = type(
            "MinimalModelConfig",
            (),
            {
                "multimodal_config": MultiModalConfig(
                    mm_encoder_attn_backend=AttentionBackendEnum.FLASHINFER
                ),
            },
        )()
        vllm_config.model_config = minimal_model_config
        try:
            total_len = sum(var_seq_len)
            # Stride of second dim = 3 * num_heads * head_size (same as
            # qwen2_5_vl after qkv rearrange and unbind: qkv shape
            # (b, s, 3, head, head_dim)).
            qkv = torch.randn(1, total_len, 3, num_heads, head_size)
            q, k, v = qkv.unbind(dim=2)

            # Offsets [0, l0, l0 + l1, ...] marking the sequence boundaries.
            cu_seqlens_np = np.array(
                list(itertools.accumulate(var_seq_len, initial=0)), dtype=np.int32
            )
            hidden_size = num_heads * head_size
            tp_size = 1

            # Derive the backend-specific metadata through the layer's own
            # helpers. sequence_lengths and max_seqlen are computed from the
            # original offsets, before cu_seqlens is recomputed below.
            sequence_lengths_np = MMEncoderAttention.maybe_compute_sequence_lengths(
                AttentionBackendEnum.FLASHINFER, cu_seqlens_np
            )
            sequence_lengths = torch.from_numpy(sequence_lengths_np).to(
                device, dtype=torch.int32, non_blocking=True
            )

            max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
                AttentionBackendEnum.FLASHINFER, cu_seqlens_np
            )
            max_seqlen = torch.tensor(
                max_seqlen_val, device=device, dtype=torch.int32
            )

            cu_seqlens_np = MMEncoderAttention.maybe_recompute_cu_seqlens(
                AttentionBackendEnum.FLASHINFER,
                cu_seqlens_np,
                hidden_size,
                tp_size,
            )
            cu_seqlens = torch.from_numpy(cu_seqlens_np).to(
                device, dtype=torch.int32, non_blocking=True
            )

            scale = 1.0 / head_size**0.5
            attn = MMEncoderAttention(
                num_heads,
                head_size,
                scale=scale,
                num_kv_heads=num_heads,
            )
            assert attn.attn_backend == AttentionBackendEnum.FLASHINFER

            output = attn(
                q,
                k,
                v,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                sequence_lengths=sequence_lengths,
            )

            # Attention must not cross sequence boundaries, so the reference
            # is evaluated one sequence at a time.
            per_seq_ref = [
                ref_attention(q_i, k_i, v_i, scale=scale)
                for q_i, k_i, v_i in zip(
                    torch.split(q, var_seq_len, dim=1),
                    torch.split(k, var_seq_len, dim=1),
                    torch.split(v, var_seq_len, dim=1),
                )
            ]
            ref_output = torch.cat(per_seq_ref, dim=1)
            torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
        finally:
            vllm_config.model_config = old_model_config
    finally:
        torch.set_default_dtype(previous_dtype)
        torch.set_default_device(previous_device)