# test_flash_attn.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional
5
6
7
8

import pytest
import torch

9
from vllm.platforms import current_platform
10
11
12
13
14
15
from vllm.vllm_flash_attn import (
    fa_version_unsupported_reason,
    flash_attn_varlen_func,
    flash_attn_with_kvcache,
    is_fa_version_supported,
)
16

17
NUM_HEADS = [(4, 4), (8, 2)]
18
HEAD_SIZES = [128, 256]
19
20
BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16]
21
QDTYPES = [None, torch.float8_e4m3fn]
22
23
24
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
25
26
SOFT_CAPS = [None, 50.0]
SLIDING_WINDOWS = [None, 256]
27
28
29
30
31
32


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: list[int],
    kv_lens: list[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Reference (pure PyTorch) attention over a paged KV cache.

    ``query`` holds the query tokens of all sequences concatenated along
    dim 0 (``sum(query_lens)`` rows), each row shaped
    ``(num_query_heads, head_size)``. ``key_cache`` / ``value_cache`` are
    laid out as ``(num_blocks, block_size, num_kv_heads, head_size)``.
    Row ``i`` of ``block_tables`` lists the cache blocks that hold the keys
    and values of sequence ``i``, in order.

    Attention is causal, with each sequence's queries aligned to the *end*
    of its KV range. Query row ``j`` of a sequence sits at KV position
    ``kv_len - query_len + j``. Two options further restrict it:

    - ``sliding_window``: a query additionally sees only itself and the
      ``sliding_window - 1`` keys before it.
    - ``soft_cap``: logits are transformed as
      ``soft_cap * tanh(logits / soft_cap)`` before masking.

    If there are more query heads than KV heads, each KV head is repeated
    to match.

    The input tensors are not modified. Returns the outputs of all
    sequences concatenated along dim 0, with the same
    ``(tokens, num_query_heads, head_size)`` layout as ``query``.
    """
    num_seqs = len(query_lens)
    block_tables_np = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: list[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        # int() so tensor-valued lengths behave exactly like Python ints.
        query_len = int(query_lens[i])
        kv_len = int(kv_lens[i])
        # Scale out of place: the slice is a view of ``query``, so an
        # in-place ``*=`` would silently rescale the caller's tensor.
        q = query[start_idx : start_idx + query_len] * scale

        # Gather this sequence's keys/values from its cache blocks, then drop
        # the unused tail of the last (partially filled) block.
        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables_np[i, :num_kv_blocks]
        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        # Fewer KV heads than query heads: repeat each KV head to match.
        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()

        # True = masked out. Causal: hide keys after the query's own position.
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            # Also hide keys more than sliding_window - 1 positions behind.
            sliding_window_mask = (
                torch.triu(
                    empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
                )
                .bool()
                .logical_not()
            )
            mask |= sliding_window_mask
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
    use_out: bool,
    kv_lens: list[int],
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    sliding_window: Optional[int],
    fa_version: int,
    q_dtype: Optional[torch.dtype],
) -> None:
    """Compare ``flash_attn_with_kvcache`` against ``ref_paged_attn``.

    Each sequence has exactly one query token (the reference is called
    with ``query_lens=[1] * num_seqs``). It attends to ``kv_lens[i]``
    cached tokens spread over randomly chosen cache blocks.

    When ``q_dtype`` is set, Q/K/V are cast to it for the kernel call, and
    the descale factors are all ones. The reference still runs on the
    unquantized tensors, so the tolerances are loosened in that case.
    """
    torch.set_default_device("cuda")
    if not is_fa_version_supported(fa_version):
        pytest.skip(
            f"Flash attention version {fa_version} not supported due "
            f'to: "{fa_version_unsupported_reason(fa_version)}"'
        )
    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
        pytest.skip(
            "Flash attention with quantized inputs is only "
            "supported on version 3 with bfloat16 base type"
        )

    current_platform.seed_everything(0)
    num_seqs = len(kv_lens)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5
    # Chosen to match the reference mask, where a query sees itself plus the
    # sliding_window - 1 keys before it; (-1, -1) is passed when no window is set.
    window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    key_cache = torch.randn(
        num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
    )
    value_cache = torch.randn_like(key_cache)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    # Random (possibly repeated) cache block ids for every sequence.
    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    # Insert a length-1 query dim: (num_seqs, 1, num_query_heads, head_size).
    q = query.unsqueeze(1)
    out = torch.empty_like(q) if use_out else None

    maybe_quantized_query = q
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    q_descale = None
    k_descale = None
    v_descale = None
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
        maybe_quantized_query = q.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        scale_shape = (num_seqs, num_kv_heads)
        q_descale = torch.ones(scale_shape, dtype=torch.float32)
        k_descale = torch.ones(scale_shape, dtype=torch.float32)
        v_descale = torch.ones(scale_shape, dtype=torch.float32)

    output = flash_attn_with_kvcache(
        q=maybe_quantized_query,
        k_cache=maybe_quantized_key_cache,
        v_cache=maybe_quantized_value_cache,
        out=out,
        softmax_scale=scale,
        causal=True,
        block_table=block_tables,
        cache_seqlens=kv_lens_tensor,
        # The reference uses None for "no soft cap"; the kernel is given 0.
        softcap=soft_cap if soft_cap is not None else 0,
        window_size=window_size,
        fa_version=fa_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )
    # With use_out, check the caller-provided buffer rather than the return.
    output = output if not use_out else out
    output = output.squeeze(1)

    atol, rtol = 1.5e-2, 1e-2
    if q_dtype is not None:
        atol, rtol = 1.5e-1, 1.5e-1

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=[1] * num_seqs,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        soft_cap=soft_cap,
        sliding_window=sliding_window,
    )
    # Plain call: assert_close's own failure message already reports the
    # largest mismatch, so no extra message string is needed.
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize(
    "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_varlen_with_paged_kv(
    use_out: bool,
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    fa_version: int,
    q_dtype: Optional[torch.dtype],
) -> None:
    """Compare ``flash_attn_varlen_func`` against ``ref_paged_attn``.

    ``seq_lens`` holds one ``(query_len, kv_len)`` pair per sequence. The
    queries of all sequences are packed along dim 0 and delimited by
    ``cu_query_lens``. Keys and values live in a paged cache addressed
    through ``block_tables``.

    When ``q_dtype`` is set, Q/K/V are cast to it for the kernel call, and
    the descale factors are all ones. The reference still runs on the
    unquantized tensors, so the tolerances are loosened in that case.
    """
    torch.set_default_device("cuda")
    if not is_fa_version_supported(fa_version):
        pytest.skip(
            f"Flash attention version {fa_version} not supported due "
            f'to: "{fa_version_unsupported_reason(fa_version)}"'
        )
    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
        pytest.skip(
            "Flash attention with quantized inputs is only "
            "supported on version 3 with bfloat16 base type"
        )
    current_platform.seed_everything(0)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    # Chosen to match the reference mask, where a query sees itself plus the
    # sliding_window - 1 keys before it; (-1, -1) is passed when no window is set.
    window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
    key_cache = torch.randn(
        num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
    )
    value_cache = torch.randn_like(key_cache)
    # Prefix sums [0, q0, q0+q1, ...] marking where each sequence's queries start.
    cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32
    )
    # Separate tensor for the kernel; the reference keeps the list[int] it expects.
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    # Random (possibly repeated) cache block ids for every sequence.
    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    out = torch.empty_like(query) if use_out else None

    maybe_quantized_query = query
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    q_descale = None
    k_descale = None
    v_descale = None
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
        maybe_quantized_query = query.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        scale_shape = (num_seqs, num_kv_heads)
        q_descale = torch.ones(scale_shape, dtype=torch.float32)
        k_descale = torch.ones(scale_shape, dtype=torch.float32)
        v_descale = torch.ones(scale_shape, dtype=torch.float32)

    output = flash_attn_varlen_func(
        q=maybe_quantized_query,
        k=maybe_quantized_key_cache,
        v=maybe_quantized_value_cache,
        out=out,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens_tensor,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
        # The reference uses None for "no soft cap"; the kernel is given 0.
        softcap=soft_cap if soft_cap is not None else 0,
        fa_version=fa_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )
    # With use_out, check the caller-provided buffer rather than the return.
    output = output if not use_out else out

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )
    atol, rtol = 1.5e-2, 1e-2
    if q_dtype is not None:
        atol, rtol = 1.5e-1, 1.5e-1
    # Plain call: assert_close's own failure message already reports the
    # largest mismatch, so no extra message string is needed.
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)