test_mha_attn.py 6.08 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
"""
Test:

* Tests for MultiHeadAttention layer
"""
8

9
10
11
12
13
import contextlib
from unittest.mock import patch

import pytest
import torch

14
from vllm.attention.backends.registry import _Backend
15
from vllm.attention.layer import MultiHeadAttention
16
from vllm.attention.selector import _cached_get_attn_backend
17
18
19
20
21
22
23
24
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear cached attention-backend state so each test case runs uncached.

    Applied automatically to every test in this module (``autouse=True``).
    """
    # Drop memoized results of the attention backend selector.
    _cached_get_attn_backend.cache_clear()

    # Clear the xformers availability cache stored on the attention layer
    # module by resetting it to None.
    import vllm.attention.layer as layer_module

    layer_module.USE_XFORMERS_OPS = None


@contextlib.contextmanager
def _patch_current_platform(platform):
    """Temporarily replace ``current_platform`` with *platform*.

    The name is patched in both ``vllm.attention.layer`` and
    ``vllm.model_executor.models.vision`` with the same instance. Both
    patches are undone when the ``with`` block exits.
    """
    with (
        patch("vllm.attention.layer.current_platform", platform),
        patch("vllm.model_executor.models.vision.current_platform", platform),
    ):
        yield


@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_mha_attn_platform(device: str):
    """
    Test the attention selector between different platform and device.

    ``MultiHeadAttention`` is constructed while ``current_platform`` is
    patched to the platform under test, and the chosen ``attn_backend`` is
    checked. The global torch default dtype is set to float16 for the
    duration of the test and restored afterwards, so it does not leak into
    other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float16)
    try:
        if device == "cpu":
            with _patch_current_platform(CpuPlatform()):
                attn = MultiHeadAttention(16, 64, scale=1)
                assert attn.attn_backend == _Backend.TORCH_SDPA
        elif device == "hip":
            with _patch_current_platform(RocmPlatform()):
                attn = MultiHeadAttention(16, 64, scale=1)
                assert attn.attn_backend == _Backend.TORCH_SDPA
        else:
            # Test CUDA with head_size=64 (divisible by 32)
            # - should use vLLM's FlashAttention
            with _patch_current_platform(CudaPlatform()):
                attn = MultiHeadAttention(16, 64, scale=1)
                assert attn.attn_backend == _Backend.FLASH_ATTN

            # Test CUDA with head_size=72 (not divisible by 32)
            # - with upstream FA not available
            # - should use xformers
            with (
                _patch_current_platform(CudaPlatform()),
                patch(
                    "vllm.attention.layer.check_upstream_fa_availability",
                    return_value=False,
                ),
            ):
                attn = MultiHeadAttention(16, 72, scale=1)
                assert attn.attn_backend == _Backend.XFORMERS

            # Test CUDA with head_size=72 (not divisible by 32)
            # - with upstream FA available
            # - should use upstream FA
            # A stub object is registered as ``flash_attn`` in sys.modules so
            # that importing it succeeds without the real package installed.
            mock_flash_attn = type(
                "MockFlashAttn",
                (),
                {"flash_attn_varlen_func": lambda *args, **kwargs: None},
            )()
            with (
                _patch_current_platform(CudaPlatform()),
                patch(
                    "vllm.attention.layer.check_upstream_fa_availability",
                    return_value=True,
                ),
                patch.dict("sys.modules", {"flash_attn": mock_flash_attn}),
            ):
                attn = MultiHeadAttention(16, 72, scale=1)
                assert attn.attn_backend == _Backend.FLASH_ATTN
    finally:
        torch.set_default_dtype(previous_dtype)

def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Native implementation of scaled dot product attention without mask:
    - query, key, value: [batch_size, seq_len, num_heads, head_size]
    - attn_mask: [batch_size, seq_len, seq_len]
    """
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.matmul(attn_weights, value).transpose(1, 2)
    return out


BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
# On ROCm the forward test only covers half precision (float16 / bfloat16);
# other platforms additionally cover float32.
DTYPES = (
    [torch.half, torch.bfloat16, torch.float]
    if not current_platform.is_rocm()
    else [torch.half, torch.bfloat16]
)
CUDA_DEVICES = ["cuda"]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    """Check ``MultiHeadAttention`` output against ``ref_attention``.

    Random q/k/v of shape ``[batch_size, seq_len, heads * head_size]`` are
    fed to the layer. The result is compared with the native reference after
    expanding k/v to ``num_heads`` (grouped-query attention).

    The global torch default device and dtype are set for the test body and
    restored afterwards, so they do not leak into other tests. The test is
    skipped when a CUDA device is requested but none is available.
    """
    if device.startswith("cuda") and not torch.cuda.is_available():
        pytest.skip("test_mha_attn_forward requires a CUDA device")
    # Each KV head must serve a whole number of query heads.
    assert num_heads % num_kv_heads == 0

    current_platform.seed_everything(0)
    previous_device = torch.get_default_device()
    previous_dtype = torch.get_default_dtype()
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    try:
        q = torch.randn(batch_size, seq_len, num_heads * head_size)
        k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
        v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
        scale = 1.0 / head_size**0.5

        attn = MultiHeadAttention(
            num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
        )
        output = attn(q, k, v)

        # Build the reference on [batch, seq, heads, head_size] views, with
        # each KV head repeated for the query heads that share it.
        num_queries_per_kv = num_heads // num_kv_heads
        q = q.reshape(batch_size, seq_len, num_heads, head_size)
        k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
        v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
        if num_queries_per_kv > 1:
            k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
            v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

        ref_output = ref_attention(
            q,
            k,
            v,
            scale=scale,
        ).reshape(batch_size, seq_len, num_heads * head_size)
        torch.testing.assert_close(output, ref_output)
    finally:
        torch.set_default_device(previous_device)
        torch.set_default_dtype(previous_dtype)