# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the MultiHeadAttention layer:

* attention backend selection on each available platform/device
* forward output compared against a native reference implementation
"""

from unittest.mock import patch

import pytest
import torch

14
from vllm.attention.backends.registry import AttentionBackendEnum
15
from vllm.attention.layer import MultiHeadAttention
16
from vllm.attention.selector import _cached_get_attn_backend
17
18
19
20
21
22
23
24
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear lru cache to ensure each test case runs without caching.

    ``_cached_get_attn_backend`` is cleared before each test so backend
    selection is recomputed. It is cleared again on teardown, so a result
    that may have been cached while a test had ``current_platform`` patched
    is not reused by tests that run afterwards.
    """
    _cached_get_attn_backend.cache_clear()
    yield
    _cached_get_attn_backend.cache_clear()


# Devices exercised by ``test_mha_attn_platform``: CPU always, plus the
# accelerator reported by ``current_platform`` ("hip" is used for ROCm).
devices = ["cpu"]
if current_platform.is_cuda():
    devices.append("cuda")
if current_platform.is_rocm():
    devices.append("hip")


@pytest.mark.parametrize("device", devices)
def test_mha_attn_platform(device: str):
    """
    Test the attention selector between different platform and device.

    For each device, ``current_platform`` is patched in both
    ``vllm.attention.layer`` and ``vllm.model_executor.models.vision`` with
    the matching platform object. The backend chosen by ``MultiHeadAttention``
    is then compared with the expected one.
    """
    # (head_size, expected backend) pairs checked for each device.
    if device == "cpu":
        platform_cls = CpuPlatform
        cases = [(64, AttentionBackendEnum.TORCH_SDPA)]
    elif device == "hip":
        platform_cls = RocmPlatform
        cases = [(64, AttentionBackendEnum.FLASH_ATTN)]
    else:
        platform_cls = CudaPlatform
        cases = [
            # head_size=64 (divisible by 32) - should use vLLM's FlashAttention
            (64, AttentionBackendEnum.FLASH_ATTN),
            # head_size=72 (not divisible by 32) - should use vLLM's FlashAttention
            (72, AttentionBackendEnum.FLASH_ATTN),
        ]

    # Backend selection runs with an fp16 default dtype. Restore the previous
    # default afterwards so this process-wide setting does not leak into
    # tests that run later.
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float16)
    try:
        for head_size, expected_backend in cases:
            # A fresh platform instance per patch target, as before.
            with (
                patch("vllm.attention.layer.current_platform", platform_cls()),
                patch(
                    "vllm.model_executor.models.vision.current_platform",
                    platform_cls(),
                ),
            ):
                attn = MultiHeadAttention(16, head_size, scale=1)
                assert attn.attn_backend == expected_backend, (
                    f"{device=} {head_size=}: got {attn.attn_backend}"
                )
    finally:
        torch.set_default_dtype(prev_dtype)


def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Native implementation of scaled dot product attention without mask:
    - query, key, value: [batch_size, seq_len, num_heads, head_size]
    - attn_mask: [batch_size, seq_len, seq_len]
    """
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.matmul(attn_weights, value).transpose(1, 2)
    return out


# Parameter grids for ``test_mha_attn_forward``.
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
# NOTE(review): torch.float is only exercised off ROCm, presumably because of
# the kernel dtype support noted above -- confirm.
DTYPES = (
    [torch.half, torch.bfloat16, torch.float]
    if not current_platform.is_rocm()
    else [torch.half, torch.bfloat16]
)
# Device(s) the forward test runs on.
CUDA_DEVICES = ["cuda"]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    """
    Check MultiHeadAttention's forward output against ``ref_attention``.

    ``q`` is built as ``[batch_size, seq_len, num_heads * head_size]`` and
    ``k``/``v`` as ``[batch_size, seq_len, num_kv_heads * head_size]``. For the
    reference, they are reshaped per head, and K/V heads are repeated
    ``num_heads // num_kv_heads`` times so every query head has a K/V head.
    """
    # The K/V head expansion below needs an integer group size. Check it up
    # front so a bad parametrization fails before any device work.
    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads

    current_platform.seed_everything(0)
    # Scope the default device and dtype to this test instead of setting them
    # process-wide, so they do not leak into tests that run afterwards.
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        with torch.device(device):
            q = torch.randn(batch_size, seq_len, num_heads * head_size)
            k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
            v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
            scale = 1.0 / head_size**0.5
            attn = MultiHeadAttention(
                num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
            )
            output = attn(q, k, v)

            q = q.reshape(batch_size, seq_len, num_heads, head_size)
            k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
            v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
            if num_queries_per_kv > 1:
                k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
                v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

            ref_output = ref_attention(q, k, v, scale=scale).reshape(
                batch_size, seq_len, num_heads * head_size
            )
            torch.testing.assert_close(output, ref_output)
    finally:
        torch.set_default_dtype(prev_dtype)