test_flash_attn.py 9.3 KB
Newer Older
1
2
3
4
5
from typing import List, Optional, Tuple

import pytest
import torch

6
from vllm.platforms import current_platform
7
8
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                  flash_attn_with_kvcache)
9
10

# Parametrization grids shared by the tests in this module.

# (num_query_heads, num_kv_heads) pairs.
NUM_HEADS = [
    (4, 4),
    (8, 2),
    (16, 2),
]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
# Two cache sizes: a large one to exercise overflow in index calculation,
# and a small one to exercise the schema op check.
NUM_BLOCKS = [32768, 2048]
17
18
19
20
21
22
23
24
25
26
27


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Compute attention over a paged KV cache, one sequence at a time.

    This is the plain (unfused) reference the kernel outputs are compared
    against.

    Args:
        query: Queries of all sequences packed along dim 0, laid out as
            (total_query_tokens, num_query_heads, head_size).
        key_cache: Paged keys, laid out as
            (num_blocks, block_size, num_kv_heads, head_size).
        value_cache: Paged values, indexed the same way as ``key_cache``.
        query_lens: Number of query tokens of each sequence.
        kv_lens: Number of cached key/value tokens of each sequence.
        block_tables: Row ``i`` lists the cache blocks of sequence ``i``.
        scale: Factor applied to the queries before the dot product.
        sliding_window: If given, keys further back than this window are
            masked out in addition to the causal mask.
        soft_cap: If given, logits are squashed to
            ``soft_cap * tanh(logits / soft_cap)`` before masking.

    Returns:
        The per-sequence outputs concatenated along dim 0.

    The ``query`` tensor is left unmodified.
    """
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: List[torch.Tensor] = []
    start_idx = 0
    for i in range(len(query_lens)):
        # Normalise to plain ints so slicing and mask construction behave
        # the same whether the lengths arrive as Python ints or 0-d tensors.
        query_len = int(query_lens[i])
        kv_len = int(kv_lens[i])

        # Out-of-place multiply: the slice is a view of ``query``, so an
        # in-place ``*=`` would rescale the caller's tensor.
        q = query[start_idx:start_idx + query_len] * scale

        # Gather this sequence's blocks and flatten them into a token axis;
        # the last block may be only partially filled, hence the trim.
        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        # Fewer KV heads than query heads: repeat each KV head so that the
        # head dimensions line up.
        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)

        attn = torch.einsum("qhd,khd->hqk", q, k).float()

        # Causal mask with the queries aligned to the END of the kv range:
        # True marks positions that must not be attended to.
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = torch.triu(
                empty_mask,
                diagonal=kv_len - (query_len + sliding_window) + 1,
            ).bool().logical_not()
            mask |= sliding_window_mask

        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)

        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)


74
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
    use_out: bool,
    kv_lens: List[int],
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    sliding_window: Optional[int],
    fa_version: int,
) -> None:
    """Compare ``flash_attn_with_kvcache`` against ``ref_paged_attn``.

    Every sequence contributes exactly one query token; keys and values are
    read from a randomly filled paged cache through a random block table.
    """
    torch.set_default_device("cuda")
    if fa_version == 3 and torch.cuda.get_device_capability() in ((8, 6),
                                                                  (8, 9)):
        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
                    "insufficient shared memory for some shapes")

    current_platform.seed_everything(0)
    num_seqs = len(kv_lens)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5
    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                   (-1, -1))

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    # The kernel is called with an extra length-1 axis after the batch axis.
    q = query.unsqueeze(1)
    out = torch.empty_like(q) if use_out else None
    output = flash_attn_with_kvcache(
        q=q,
        k_cache=key_cache,
        v_cache=value_cache,
        out=out,
        softmax_scale=scale,
        causal=True,
        block_table=block_tables,
        cache_seqlens=kv_lens_tensor,
        softcap=soft_cap if soft_cap is not None else 0,
        window_size=window_size,
        fa_version=fa_version,
    )
    # With use_out the result is read back from the caller-provided buffer.
    output = output if not use_out else out
    output = output.squeeze(1)

    ref_output = ref_paged_attn(query=query,
                                key_cache=key_cache,
                                value_cache=value_cache,
                                query_lens=[1] * num_seqs,
                                kv_lens=kv_lens,
                                block_tables=block_tables,
                                scale=scale,
                                soft_cap=soft_cap,
                                sliding_window=sliding_window)
    # assert_close raises with its own mismatch report on failure, so no
    # separate message expression is needed.
    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2)


159
160
161
162
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens",
                         [[(1, 1328), (5, 18),
                           (129, 463)], [(1, 523), (1, 37), (1, 2011)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_varlen_with_paged_kv(
    use_out: bool,
    seq_lens: List[Tuple[int, int]],
    num_heads: Tuple[int, int],
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    fa_version: int,
) -> None:
    """Compare ``flash_attn_varlen_func`` against ``ref_paged_attn``.

    Each entry of ``seq_lens`` is a ``(query_len, kv_len)`` pair. The queries
    of all sequences are packed along dim 0; keys and values are read from a
    randomly filled paged cache through a random block table.
    """
    torch.set_default_device("cuda")
    if fa_version == 3 and torch.cuda.get_device_capability() in ((8, 6),
                                                                  (8, 9)):
        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
                    "insufficient shared memory for some shapes")

    current_platform.seed_everything(0)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                   (-1, -1))
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    # Cumulative query offsets with a leading 0: [0, l0, l0 + l1, ...].
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)
    # Keep kv_lens as a list for the reference and hand the kernel its own
    # tensor, instead of rebinding one name to two different types.
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    out = torch.empty_like(query) if use_out else None
    output = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        out=out,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens_tensor,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
        softcap=soft_cap if soft_cap is not None else 0,
        fa_version=fa_version,
    )
    # With use_out the result is read back from the caller-provided buffer.
    output = output if not use_out else out

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )
    # assert_close raises with its own mismatch report on failure, so no
    # separate message expression is needed.
    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2)