# test_flash_attn.py: checks vLLM's flash_attn_with_kvcache and
# flash_attn_varlen_func against a pure-PyTorch paged-attention reference.
from typing import List, Optional, Tuple

import pytest
import torch

6
from vllm.utils import seed_everything
7
8
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                  flash_attn_with_kvcache)
9
10

# (num_query_heads, num_kv_heads) pairs. The tests require the query-head
# count to be a multiple of the KV-head count; (8, 2) and (16, 2) exercise
# grouped-query attention.
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
# One value large enough to test overflow in the index calculation, and one
# small enough to test the schema op check.
NUM_BLOCKS = [32768, 2048]


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Compute causal attention over a paged KV cache, one sequence at a time.

    This is a slow, unfused reference used to validate the flash-attention
    kernels.

    Args:
        query: Query tokens of all sequences concatenated along dim 0,
            shaped (sum(query_lens), num_query_heads, head_size).
        key_cache: Paged keys shaped
            (num_blocks, block_size, num_kv_heads, head_size).
        value_cache: Paged values with the same shape as ``key_cache``.
        query_lens: Number of query tokens per sequence.
        kv_lens: Number of cached key/value tokens per sequence.
        block_tables: 2-D table; row ``i`` lists the cache block ids holding
            sequence ``i``'s keys/values, in order.
        scale: Multiplier applied to the queries before the dot product.
        sliding_window: If given, each query token only attends to this many
            most recent key positions (including its own).
        soft_cap: If given, logits are squashed with
            ``soft_cap * tanh(logits / soft_cap)`` before masking.

    Returns:
        Attention output shaped (sum(query_lens), num_query_heads, head_size).
        The inputs are not modified.
    """
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: List[torch.Tensor] = []
    start_idx = 0
    for i, (query_len, kv_len) in enumerate(zip(query_lens, kv_lens)):
        # Out-of-place multiply: slicing `query` yields a view, so an in-place
        # `*=` would silently rescale the caller's query tensor.
        q = query[start_idx:start_idx + query_len] * scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        # Gather this sequence's blocks, flatten them into a token axis and
        # drop the padding in the last, partially filled block.
        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        # Grouped-query attention: replicate each KV head for the query heads
        # that share it.
        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()

        # Causal mask with the last query token aligned to the last key:
        # query row r may see key columns <= kv_len - query_len + r.
        empty_mask = torch.ones(query_len, kv_len, device=attn.device)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            # Additionally hide keys more than `sliding_window - 1` positions
            # behind each query's aligned position.
            sliding_window_mask = torch.triu(empty_mask,
                                             diagonal=kv_len -
                                             (query_len + sliding_window) +
                                             1).bool().logical_not()
            mask |= sliding_window_mask

        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)

        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256])
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
    kv_lens: List[int],
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    sliding_window: Optional[int],
) -> None:
    """Compare flash_attn_with_kvcache against ref_paged_attn.

    Each sequence contributes exactly one query token (decode-style) and
    attends over its paged KV cache with random block tables.
    """
    # NOTE(review): this changes the process-wide default device and is not
    # restored afterwards.
    torch.set_default_device("cuda")
    seed_everything(0)
    num_seqs = len(kv_lens)
    num_query_heads, num_kv_heads = num_heads
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5
    # Kernel convention here: (left, right) window; (-1, -1) is passed when
    # no sliding window is requested.
    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                   (-1, -1))

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    # The kernel takes a (batch, seqlen_q, heads, dim) query; add and then
    # remove the singleton seqlen_q axis.
    output = flash_attn_with_kvcache(
        q=query.unsqueeze(1),
        k_cache=key_cache,
        v_cache=value_cache,
        softmax_scale=scale,
        causal=True,
        block_table=block_tables,
        cache_seqlens=kv_lens_tensor,
        softcap=soft_cap if soft_cap is not None else 0,
        window_size=window_size,
    ).squeeze(1)

    ref_output = ref_paged_attn(query=query,
                                key_cache=key_cache,
                                value_cache=value_cache,
                                query_lens=[1] * num_seqs,
                                kv_lens=kv_lens,
                                block_tables=block_tables,
                                scale=scale,
                                soft_cap=soft_cap,
                                sliding_window=sliding_window)
    # assert_close already reports the greatest absolute/relative difference
    # on failure; the previous `assert_close(...), f"..."` only built an
    # unused tuple.
    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2)


@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_varlen_with_paged_kv(
    seq_lens: List[Tuple[int, int]],
    num_heads: Tuple[int, int],
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
) -> None:
    """Compare flash_attn_varlen_func (paged KV) against ref_paged_attn.

    Each entry of ``seq_lens`` is a (query_len, kv_len) pair. Queries of all
    sequences are packed along dim 0 and delimited by cumulative lengths.
    """
    # NOTE(review): this changes the process-wide default device and is not
    # restored afterwards.
    torch.set_default_device("cuda")
    seed_everything(0)
    num_seqs = len(seq_lens)
    query_lens = [q_len for q_len, _ in seq_lens]
    kv_lens = [kv_len for _, kv_len in seq_lens]
    num_query_heads, num_kv_heads = num_heads
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    # Kernel convention here: (left, right) window; (-1, -1) is passed when
    # no sliding window is requested.
    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                   (-1, -1))
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    # Exclusive prefix sums: sequence i spans [cu[i], cu[i + 1]).
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)
    cu_kv_lens = torch.tensor([0] + kv_lens,
                              dtype=torch.int32).cumsum(dim=0,
                                                        dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    output = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        cu_seqlens_k=cu_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
        softcap=soft_cap if soft_cap is not None else 0,
    )

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )
    # assert_close already reports the greatest absolute/relative difference
    # on failure; the previous `assert_close(...), f"..."` only built an
    # unused tuple.
    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2)