"vllm/vscode:/vscode.git/clone" did not exist on "d143271234454026454c5ee6a55fc516dd298dac"
test_flash_attn.py 7.12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7

import pytest
import torch

8
from vllm.platforms import current_platform
9
10
11
12
13
14
15
16
17
18
19
20
21
22

try:
    from vllm.vllm_flash_attn import (
        fa_version_unsupported_reason,
        flash_attn_varlen_func,
        is_fa_version_supported,
    )
except ImportError:
    if current_platform.is_rocm():
        pytest.skip(
            "vllm_flash_attn is not supported for vLLM on ROCm.",
            allow_module_level=True,
        )
    # On every other platform the extension is expected to be present.
    # Re-raise so a broken build surfaces the real import error instead of
    # a later NameError inside the test body.
    raise

# (num_query_heads, num_kv_heads) pairs. The second pair exercises
# grouped-query attention with 4 query heads per KV head.
NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [40, 72, 80, 128, 256]
BLOCK_SIZES = [16]
# Base dtype of the query and of the KV cache.
DTYPES = [torch.bfloat16]
# None runs unquantized; torch.float8_e4m3fn casts Q/K/V to FP8 for the kernel.
QDTYPES = [None, torch.float8_e4m3fn]
# One value large enough to test overflow in index calculation;
# one value small enough to test the schema op check.
NUM_BLOCKS = [32768, 2048]
# None disables logit soft-capping.
SOFT_CAPS = [None]
# None disables the sliding window (plain causal attention).
SLIDING_WINDOWS = [None, 256]


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
40
41
    query_lens: list[int],
    kv_lens: list[int],
42
43
    block_tables: torch.Tensor,
    scale: float,
44
45
    sliding_window: int | None = None,
    soft_cap: float | None = None,
46
47
48
49
50
) -> torch.Tensor:
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

51
    outputs: list[torch.Tensor] = []
52
53
54
55
    start_idx = 0
    for i in range(num_seqs):
        query_len = query_lens[i]
        kv_len = kv_lens[i]
56
        q = query[start_idx : start_idx + query_len]
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
        q *= scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
74
75
76
77
78
79
80
            sliding_window_mask = (
                torch.triu(
                    empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
                )
                .bool()
                .logical_not()
            )
81
            mask |= sliding_window_mask
82
83
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
84
85
86
87
88
89
90
91
92
93
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)


94
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize(
    "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_varlen_with_paged_kv(
    use_out: bool,
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: int | None,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: float | None,
    num_blocks: int,
    fa_version: int,
    q_dtype: torch.dtype | None,
) -> None:
    """Compare ``flash_attn_varlen_func`` on a paged KV cache with ``ref_paged_attn``.

    Each ``seq_lens`` entry is ``(query_len, kv_len)`` for one sequence.
    Block tables are drawn at random with replacement, so sequences may share
    cache blocks. When ``q_dtype`` is set, Q/K/V are cast to it for the
    kernel while the reference uses the unquantized tensors. For that reason
    the tolerances are loosened in that case.
    """
    torch.set_default_device("cuda")
    if not is_fa_version_supported(fa_version):
        pytest.skip(
            f"Flash attention version {fa_version} not supported due "
            f'to: "{fa_version_unsupported_reason(fa_version)}"'
        )
    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
        pytest.skip(
            "Flash attention with quantized inputs is only "
            "supported on version 3 with bfloat16 base type"
        )
    current_platform.seed_everything(0)

    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads, num_kv_heads = num_heads
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    # (left, right) window; (-1, -1) means unlimited.
    window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
    key_cache = torch.randn(
        num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
    )
    value_cache = torch.randn_like(key_cache)
    cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32
    )
    # Device copy for the kernel; the Python list is kept for the reference.
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    out = torch.empty_like(query) if use_out else None

    maybe_quantized_query = query
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    q_descale = None
    k_descale = None
    v_descale = None
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
        maybe_quantized_query = query.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        scale_shape = (num_seqs, num_kv_heads)
        q_descale = torch.ones(scale_shape, dtype=torch.float32)
        k_descale = torch.ones(scale_shape, dtype=torch.float32)
        v_descale = torch.ones(scale_shape, dtype=torch.float32)

    output = flash_attn_varlen_func(
        q=maybe_quantized_query,
        k=maybe_quantized_key_cache,
        v=maybe_quantized_value_cache,
        out=out,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens_tensor,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
        softcap=soft_cap if soft_cap is not None else 0,
        fa_version=fa_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )
    output = output if not use_out else out

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )

    atol, rtol = 1.5e-2, 1e-2
    if q_dtype is not None:
        atol, rtol = 1.5e-1, 1.5e-1
    # assert_close's failure message already reports the greatest absolute
    # difference, which the old discarded f-string tried to show.
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)