# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import pytest
import torch

from vllm.platforms import current_platform
if current_platform.is_rocm():
    # Only the varlen entry point is imported on ROCm; the tests below are
    # marked skipif on ROCm, so the vLLM-specific helpers are not needed.
    from flash_attn import flash_attn_varlen_func
else:
    from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                      flash_attn_varlen_func,
                                      flash_attn_with_kvcache,
                                      is_fa_version_supported)

# (num_query_heads, num_kv_heads): one MHA case and two GQA ratios.
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
# None runs the unquantized path; float8_e4m3fn exercises quantized Q/K/V.
QDTYPES = [None, torch.float8_e4m3fn]
# One value large enough to test overflow in the index calculation, and one
# small enough to test the schema op check.
NUM_BLOCKS = [32768, 2048]


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: list[int],
    kv_lens: list[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Eager PyTorch reference for attention over a paged KV cache.

    Queries of all sequences are packed along dim 0 of ``query``: sequence
    ``i`` owns ``query_lens[i]`` consecutive rows and attends over the first
    ``kv_lens[i]`` cached tokens, gathered from the blocks listed in
    ``block_tables[i]``. The query rows are treated as the last
    ``query_lens[i]`` positions of that sequence for causal masking.

    Args:
        query: Packed queries, ``(sum(query_lens), num_query_heads,
            head_size)``.
        key_cache: ``(num_blocks, block_size, num_kv_heads, head_size)``.
        value_cache: Same layout as ``key_cache``.
        query_lens: Number of query tokens per sequence.
        kv_lens: Number of cached KV tokens per sequence. Either a list of
            ints or a 1-D integer tensor is accepted.
        block_tables: ``(num_seqs, max_blocks_per_seq)`` block indices into
            the caches.
        scale: Softmax scale applied to the queries.
        sliding_window: If set, each query only sees the most recent
            ``sliding_window`` keys up to and including its own position.
        soft_cap: If set, logits become ``soft_cap * tanh(x / soft_cap)``
            before masking.

    Returns:
        Attention output in the same packed layout as ``query``. ``query``
        itself is left unmodified.
    """
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: list[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        # int() accepts Python ints and 0-d tensors alike, so the mask shapes
        # and slice bounds below are always plain ints.
        query_len = int(query_lens[i])
        kv_len = int(kv_lens[i])
        # Scale out of place: `q *= scale` on this slice view would silently
        # rescale the caller's `query` tensor as well.
        q = query[start_idx:start_idx + query_len] * scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        # Grouped-query attention: repeat each KV head for its query heads.
        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()

        # True marks a masked-out (query, key) pair. The causal diagonal
        # aligns the last query row with the last key.
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            # Additionally hide keys more than sliding_window - 1 positions
            # behind the query's absolute position.
            sliding_window_mask = torch.triu(empty_mask,
                                             diagonal=kv_len -
                                             (query_len + sliding_window) +
                                             1).bool().logical_not()
            mask |= sliding_window_mask

        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)

@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="flash_attn_with_paged_kv is not supported on ROCm.")
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
    use_out: bool,
    kv_lens: list[int],
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    sliding_window: Optional[int],
    fa_version: int,
    q_dtype: Optional[torch.dtype],
) -> None:
    """Compare ``flash_attn_with_kvcache`` against ``ref_paged_attn``.

    Every sequence contributes exactly one query token, which attends
    causally over ``kv_lens[i]`` tokens stored in a randomly paged KV cache.
    """
    torch.set_default_device("cuda")
    if not is_fa_version_supported(fa_version):
        pytest.skip(f"Flash attention version {fa_version} not supported due "
                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
        pytest.skip("Flash attention with quantized inputs is only "
                    "supported on version 3 with bfloat16 base type")

    current_platform.seed_everything(0)
    num_seqs = len(kv_lens)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5
    # A window of W tokens is passed as (W - 1, 0) to the kernel; (-1, -1)
    # is passed when no window is requested.
    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                   (-1, -1))

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    # Insert a length-1 query-sequence dim:
    # (num_seqs, 1, num_query_heads, head_size).
    q = query.unsqueeze(1)
    out = torch.empty_like(q) if use_out else None

    maybe_quantized_query = q
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    q_descale = None
    k_descale = None
    v_descale = None
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
        maybe_quantized_query = q.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        # Unit descale factors, one per (sequence, KV head).
        scale_shape = (num_seqs, num_kv_heads)
        q_descale = torch.ones(scale_shape, dtype=torch.float32)
        k_descale = torch.ones(scale_shape, dtype=torch.float32)
        v_descale = torch.ones(scale_shape, dtype=torch.float32)

    output = flash_attn_with_kvcache(
        q=maybe_quantized_query,
        k_cache=maybe_quantized_key_cache,
        v_cache=maybe_quantized_value_cache,
        out=out,
        softmax_scale=scale,
        causal=True,
        block_table=block_tables,
        cache_seqlens=kv_lens_tensor,
        softcap=soft_cap if soft_cap is not None else 0,
        window_size=window_size,
        fa_version=fa_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )
    output = output if not use_out else out
    output = output.squeeze(1)

    # fp8 inputs lose precision, so loosen the tolerances for that path.
    atol, rtol = 1.5e-2, 1e-2
    if q_dtype is not None:
        atol, rtol = 1.5e-1, 1.5e-1

    ref_output = ref_paged_attn(query=query,
                                key_cache=key_cache,
                                value_cache=value_cache,
                                query_lens=[1] * num_seqs,
                                kv_lens=kv_lens,
                                block_tables=block_tables,
                                scale=scale,
                                soft_cap=soft_cap,
                                sliding_window=sliding_window)
    # assert_close reports the greatest absolute/relative difference on
    # failure, so no hand-built message is needed.
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="varlen_with_paged_kv is not supported on ROCm.")
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens",
                         [[(1, 1328), (5, 18),
                           (129, 463)], [(1, 523), (1, 37), (1, 2011)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_varlen_with_paged_kv(
    use_out: bool,
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    fa_version: int,
    q_dtype: Optional[torch.dtype],
) -> None:
    """Compare varlen ``flash_attn_varlen_func`` against ``ref_paged_attn``.

    Each ``seq_lens`` entry is ``(query_len, kv_len)``. Queries of all
    sequences are packed along dim 0 and attend causally over a randomly
    paged KV cache.
    """
    torch.set_default_device("cuda")
    if not is_fa_version_supported(fa_version):
        pytest.skip(f"Flash attention version {fa_version} not supported due "
                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
        pytest.skip("Flash attention with quantized inputs is only "
                    "supported on version 3 with bfloat16 base type")
    current_platform.seed_everything(0)

    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    # A window of W tokens is passed as (W - 1, 0) to the kernel; (-1, -1)
    # is passed when no window is requested.
    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                   (-1, -1))
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    # Row offsets of each sequence's queries: sequence i owns rows
    # [cu_query_lens[i], cu_query_lens[i + 1]).
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)
    # The kernel gets an int32 tensor; the Python list `kv_lens` is kept
    # for the reference implementation.
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    out = torch.empty_like(query) if use_out else None

    maybe_quantized_query = query
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    q_descale = None
    k_descale = None
    v_descale = None
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
        maybe_quantized_query = query.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        # Unit descale factors, one per (sequence, KV head).
        scale_shape = (num_seqs, num_kv_heads)
        q_descale = torch.ones(scale_shape, dtype=torch.float32)
        k_descale = torch.ones(scale_shape, dtype=torch.float32)
        v_descale = torch.ones(scale_shape, dtype=torch.float32)

    output = flash_attn_varlen_func(
        q=maybe_quantized_query,
        k=maybe_quantized_key_cache,
        v=maybe_quantized_value_cache,
        out=out,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens_tensor,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
        softcap=soft_cap if soft_cap is not None else 0,
        fa_version=fa_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )
    output = output if not use_out else out

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )
    # fp8 inputs lose precision, so loosen the tolerances for that path.
    atol, rtol = 1.5e-2, 1e-2
    if q_dtype is not None:
        atol, rtol = 1.5e-1, 1.5e-1
    # assert_close reports the greatest absolute/relative difference on
    # failure, so no hand-built message is needed.
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)