# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test:

6
* Tests for MMEncoderAttention layer
7
"""
8

9
import itertools
10
11
12
13
14
from unittest.mock import patch

import pytest
import torch

15
from vllm.model_executor.layers.attention import MMEncoderAttention
16
17
18
19
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
20
from vllm.utils.torch_utils import set_random_seed
21
22
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend
23
24
25
26


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear the attention-backend selector cache before every test case.

    The fixture is ``autouse``, so each test starts with an empty
    ``_cached_get_attn_backend`` cache and cannot pick up a backend that was
    cached while a previous test case was running.
    """
    _cached_get_attn_backend.cache_clear()


31
32
33
34
35
36
37
38
# Device names used to parametrize the platform test: "cpu" is always
# present, "cuda" / "hip" are added only when the current platform reports
# CUDA / ROCm respectively.
devices = ["cpu"]
devices += ["cuda"] if current_platform.is_cuda() else []
devices += ["hip"] if current_platform.is_rocm() else []


@pytest.mark.parametrize("device", devices)
def test_mha_attn_platform(default_vllm_config, device: str):
    """
    Test the attention selector between different platform and device.

    ``current_platform`` as seen by ``vllm.model_executor.models.vision`` is
    patched with the platform object matching ``device``, then the backend
    selected by ``MMEncoderAttention`` is checked:

    * cpu  -> TORCH_SDPA
    * hip  -> FLASH_ATTN
    * cuda -> FLASH_ATTN, both for head_size=64 (divisible by 32) and for
      head_size=72 (not divisible by 32)
    """
    flash_attn = AttentionBackendEnum.FLASH_ATTN
    if device == "cpu":
        platform_cls = CpuPlatform
        cases = [(64, AttentionBackendEnum.TORCH_SDPA)]
    elif device == "hip":
        platform_cls = RocmPlatform
        cases = [(64, flash_attn)]
    else:
        # CUDA: both head sizes should use vLLM's FlashAttention.
        platform_cls = CudaPlatform
        cases = [(64, flash_attn), (72, flash_attn)]

    # The default dtype is process-wide state; restore it afterwards so the
    # float16 setting does not leak into tests that run later.
    saved_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float16)
    try:
        for head_size, expected_backend in cases:
            with patch(
                "vllm.model_executor.models.vision.current_platform",
                platform_cls(),
            ):
                attn = MMEncoderAttention(16, head_size, scale=1)
                assert attn.attn_backend == expected_backend
    finally:
        torch.set_default_dtype(saved_dtype)
73

74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94

def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Reference (plain PyTorch) scaled dot product attention, without a mask.

    Args:
        query: [batch_size, seq_len, num_heads, head_size]
        key: [batch_size, seq_len, num_heads, head_size]
        value: [batch_size, seq_len, num_heads, head_size]
        scale: factor applied to the query-key scores before the softmax.

    Returns:
        Attention output in the same
        [batch_size, seq_len, num_heads, head_size] layout as the inputs.
    """
    # Put the head axis ahead of the sequence axis:
    # [batch_size, num_heads, seq_len, head_size].
    q_heads = query.transpose(1, 2)
    k_heads = key.transpose(1, 2)
    v_heads = value.transpose(1, 2)

    scores = scale * (q_heads @ k_heads.transpose(2, 3))
    probs = scores.softmax(dim=-1).to(v_heads.dtype)

    # Swap back to the sequence-major layout of the inputs.
    return (probs @ v_heads).transpose(1, 2)


# Parameter grids for the forward-pass tests.
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
# Each entry lists the per-sequence lengths packed into one varlen batch.
VAR_SEQ_LENS = [[2, 2], [2, 3, 4]]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
# float32 is only added when the current platform is not ROCm.
DTYPES = [torch.half, torch.bfloat16]
if not current_platform.is_rocm():
    DTYPES.append(torch.float)
CUDA_DEVICES = ["cuda"]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
    default_vllm_config,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    """
    Check ``MMEncoderAttention`` on dense batches against ``ref_attention``.

    Random q/k/v of shape [batch_size, seq_len, heads * head_size] are fed to
    the layer; the result must match the reference computed on the same
    tensors reshaped to [batch_size, seq_len, heads, head_size], with k/v
    heads repeated when there are fewer KV heads than query heads.
    """
    set_random_seed(0)

    # Default device/dtype are process-wide; remember them so they can be
    # restored and do not leak into tests that run later.
    saved_device = torch.get_default_device()
    saved_dtype = torch.get_default_dtype()
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    try:
        q = torch.randn(batch_size, seq_len, num_heads * head_size)
        k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
        v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
        scale = 1.0 / head_size**0.5
        attn = MMEncoderAttention(
            num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
        )
        output = attn(q, k, v)

        assert num_heads % num_kv_heads == 0
        num_queries_per_kv = num_heads // num_kv_heads

        # Split the flattened hidden dimension into (heads, head_size).
        q = q.reshape(batch_size, seq_len, num_heads, head_size)
        k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
        v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
        if num_queries_per_kv > 1:
            # Expand KV heads so every query head has a matching k/v head.
            k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
            v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

        ref_output = ref_attention(q, k, v, scale=scale).reshape(
            batch_size, seq_len, num_heads * head_size
        )
        torch.testing.assert_close(output, ref_output)
    finally:
        torch.set_default_dtype(saved_dtype)
        torch.set_default_device(saved_device)
157
158
159
160
161
162
163
164
165


@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_varlen_forward(
    default_vllm_config,
    var_seq_len: list[int],
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    """
    Check ``MMEncoderAttention`` on packed varlen input against
    ``ref_attention``.

    All sequences in ``var_seq_len`` are concatenated along dim 1 of a
    batch-size-1 tensor and passed to the layer together with ``cu_seqlens``
    (cumulative lengths starting at 0) and ``max_seqlen``. The reference is
    computed per sequence and concatenated back along dim 1.
    """
    set_random_seed(0)

    # Default device/dtype are process-wide; remember them so they can be
    # restored and do not leak into tests that run later.
    saved_device = torch.get_default_device()
    saved_dtype = torch.get_default_dtype()
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    try:
        total_len = sum(var_seq_len)
        q = torch.randn(1, total_len, num_heads, head_size)
        k = torch.randn(1, total_len, num_kv_heads, head_size)
        v = torch.randn(1, total_len, num_kv_heads, head_size)
        cu_seqlens = torch.tensor(
            [0] + list(itertools.accumulate(var_seq_len)), dtype=torch.int32
        )
        scale = 1.0 / head_size**0.5
        attn = MMEncoderAttention(
            num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
        )
        output = attn(
            q,
            k,
            v,
            cu_seqlens=cu_seqlens,
            max_seqlen=torch.tensor(max(var_seq_len)),
        )

        assert num_heads % num_kv_heads == 0
        num_queries_per_kv = num_heads // num_kv_heads
        if num_queries_per_kv > 1:
            # Expand KV heads so every query head has a matching k/v head.
            k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
            v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

        # Attend within each sequence separately, then stitch the pieces
        # back together in their original order.
        ref_chunks = [
            ref_attention(q_i, k_i, v_i, scale=scale)
            for q_i, k_i, v_i in zip(
                torch.split(q, var_seq_len, dim=1),
                torch.split(k, var_seq_len, dim=1),
                torch.split(v, var_seq_len, dim=1),
            )
        ]
        ref_output = torch.cat(ref_chunks, dim=1)
        torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
    finally:
        torch.set_default_dtype(saved_dtype)
        torch.set_default_device(saved_device)