# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional
5
6
7
8

import pytest
import torch

9
from vllm.platforms import current_platform
10
11
12
13
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                  flash_attn_varlen_func,
                                  flash_attn_with_kvcache,
                                  is_fa_version_supported)
14

15
NUM_HEADS = [(4, 4), (8, 2)]
16
HEAD_SIZES = [128, 256]
17
18
BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16]
19
QDTYPES = [None, torch.float8_e4m3fn]
20
21
22
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
23
24
SOFT_CAPS = [None, 50.0]
SLIDING_WINDOWS = [None, 256]
25
26
27
28
29
30


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: list[int],
    kv_lens: list[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Compute causal paged attention with plain PyTorch ops as a reference.

    Queries of each sequence are aligned to the END of that sequence's
    key/value tokens (bottom-right causal mask), so query row ``r`` sits at
    absolute position ``kv_len - query_len + r``.

    Args:
        query: Packed queries of all sequences, laid out as
            ``(total_query_tokens, num_query_heads, head_size)``.
            It is not modified.
        key_cache: Paged key cache of shape
            ``(num_blocks, block_size, num_kv_heads, head_size)``.
        value_cache: Paged value cache laid out like ``key_cache``.
        query_lens: Number of query tokens per sequence, in packing order.
            Python ints or integer (0-d) tensors are accepted.
        kv_lens: Number of key/value tokens per sequence. Python ints or
            integer (0-d) tensors are accepted.
        block_tables: 2-D table; row ``i`` lists the cache block indices
            holding sequence ``i``'s tokens, in order.
        scale: Multiplier applied to the queries before ``Q @ K^T``.
        sliding_window: If given, each query attends only to the most recent
            ``sliding_window`` key positions up to and including its own.
        soft_cap: If given, logits are squashed as
            ``soft_cap * tanh(logits / soft_cap)`` before masking.

    Returns:
        Tensor of shape ``(total_query_tokens, num_query_heads, head_size)``
        in the value cache's dtype.
    """
    num_seqs = len(query_lens)
    # Host (NumPy) copy of the block table for per-sequence slicing.
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: list[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        # int() accepts both Python ints and 0-d integer tensors.
        query_len = int(query_lens[i])
        kv_len = int(kv_lens[i])
        # Out-of-place multiply: the slice is a view of ``query``, so an
        # in-place ``*=`` would silently rescale the caller's tensor.
        q = query[start_idx:start_idx + query_len] * scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        # Gather this sequence's blocks, flatten them into one token axis and
        # drop the unused tail of the last (partially filled) block.
        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        # Grouped-query attention: replicate each KV head for its query heads.
        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()

        # True entries are masked out. The causal part masks every key
        # position after the query's own absolute position.
        empty_mask = torch.ones(query_len, kv_len, device=attn.device)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            # Mask keys that fall more than ``sliding_window - 1`` positions
            # before the query's own position.
            sliding_window_mask = torch.triu(empty_mask,
                                             diagonal=kv_len -
                                             (query_len + sliding_window) +
                                             1).bool().logical_not()
            mask |= sliding_window_mask

        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
    use_out: bool,
    kv_lens: list[int],
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    sliding_window: Optional[int],
    fa_version: int,
    q_dtype: Optional[torch.dtype],
) -> None:
    """Check ``flash_attn_with_kvcache`` with one query token per sequence.

    The kernel runs against a randomly paged KV cache and is compared with
    ``ref_paged_attn``. It optionally writes into a preallocated ``out``
    buffer and optionally takes Q/K/V cast to ``q_dtype``.
    """
    torch.set_default_device("cuda")
    if not is_fa_version_supported(fa_version):
        pytest.skip(f"Flash attention version {fa_version} not supported due "
                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
        pytest.skip("Flash attention with quantized inputs is only "
                    "supported on version 3 with bfloat16 base type")

    current_platform.seed_everything(0)
    num_seqs = len(kv_lens)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5
    # Map the reference's sliding_window onto the kernel's window_size pair;
    # (-1, -1) is passed when no window is used.
    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                   (-1, -1))

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    # Insert a length-1 query-sequence axis:
    # (num_seqs, 1, num_query_heads, head_size).
    q = query.unsqueeze(1)
    out = torch.empty_like(q) if use_out else None

    maybe_quantized_query = q
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    q_descale = None
    k_descale = None
    v_descale = None
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
        maybe_quantized_query = q.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        scale_shape = (num_seqs, num_kv_heads)
        q_descale = torch.ones(scale_shape, dtype=torch.float32)
        k_descale = torch.ones(scale_shape, dtype=torch.float32)
        v_descale = torch.ones(scale_shape, dtype=torch.float32)

    output = flash_attn_with_kvcache(
        q=maybe_quantized_query,
        k_cache=maybe_quantized_key_cache,
        v_cache=maybe_quantized_value_cache,
        out=out,
        softmax_scale=scale,
        causal=True,
        block_table=block_tables,
        cache_seqlens=kv_lens_tensor,
        softcap=soft_cap if soft_cap is not None else 0,
        window_size=window_size,
        fa_version=fa_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )
    # When a buffer was supplied, verify what the kernel wrote into it.
    output = output if not use_out else out
    output = output.squeeze(1)

    # Quantized inputs lose precision, so widen the tolerances.
    atol, rtol = 1.5e-2, 1e-2
    if q_dtype is not None:
        atol, rtol = 1.5e-1, 1.5e-1

    ref_output = ref_paged_attn(query=query,
                                key_cache=key_cache,
                                value_cache=value_cache,
                                query_lens=[1] * num_seqs,
                                kv_lens=kv_lens,
                                block_tables=block_tables,
                                scale=scale,
                                soft_cap=soft_cap,
                                sliding_window=sliding_window)
    # assert_close raises with its own mismatch report (including the
    # greatest absolute difference).
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens",
                         [[(1, 1328), (5, 18),
                           (129, 463)], [(1, 523), (1, 37), (1, 2011)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_varlen_with_paged_kv(
    use_out: bool,
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    fa_version: int,
    q_dtype: Optional[torch.dtype],
) -> None:
    """Check ``flash_attn_varlen_func`` on packed variable-length queries.

    Each ``seq_lens`` entry is a ``(query_len, kv_len)`` pair. The kernel
    runs against a randomly paged KV cache and is compared with
    ``ref_paged_attn``. It optionally writes into a preallocated ``out``
    buffer and optionally takes Q/K/V cast to ``q_dtype``.
    """
    torch.set_default_device("cuda")
    if not is_fa_version_supported(fa_version):
        pytest.skip(f"Flash attention version {fa_version} not supported due "
                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
        pytest.skip("Flash attention with quantized inputs is only "
                    "supported on version 3 with bfloat16 base type")
    current_platform.seed_everything(0)

    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    # Map the reference's sliding_window onto the kernel's window_size pair;
    # (-1, -1) is passed when no window is used.
    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                   (-1, -1))
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    # Exclusive prefix sum of query lengths: sequence i's queries occupy
    # rows cu_query_lens[i]:cu_query_lens[i + 1] of ``query``.
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)
    # Keep ``kv_lens`` as a Python list for the reference implementation;
    # the kernel gets a separate int32 tensor.
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    out = torch.empty_like(query) if use_out else None

    maybe_quantized_query = query
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    q_descale = None
    k_descale = None
    v_descale = None
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
        maybe_quantized_query = query.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        scale_shape = (num_seqs, num_kv_heads)
        q_descale = torch.ones(scale_shape, dtype=torch.float32)
        k_descale = torch.ones(scale_shape, dtype=torch.float32)
        v_descale = torch.ones(scale_shape, dtype=torch.float32)

    output = flash_attn_varlen_func(
        q=maybe_quantized_query,
        k=maybe_quantized_key_cache,
        v=maybe_quantized_value_cache,
        out=out,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens_tensor,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
        softcap=soft_cap if soft_cap is not None else 0,
        fa_version=fa_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )
    # When a buffer was supplied, verify what the kernel wrote into it.
    output = output if not use_out else out

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )
    # Quantized inputs lose precision, so widen the tolerances.
    atol, rtol = 1.5e-2, 1e-2
    if q_dtype is not None:
        atol, rtol = 1.5e-1, 1.5e-1
    # assert_close raises with its own mismatch report (including the
    # greatest absolute difference).
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)