"vllm/vscode:/vscode.git/clone" did not exist on "6ebaf9ac71387228951fd1642662a020080eb037"
test_prefix_prefill.py 21.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
6
import random
import time
7
from collections.abc import Callable
8

9
import pytest
10
import torch
11
import torch.nn.functional as F
12

13
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
14
from vllm.attention.ops.prefix_prefill import context_attention_fwd
15
from vllm.platforms import current_platform
16
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
17

18
# Parameter grids shared by the tests below.
NUM_HEADS = [64]
# num_heads // num_queries_per_kv gives the number of KV heads (1 -> MHA, 64 -> MQA).
NUM_QUERIES_PER_KV = [1, 64]
HEAD_SIZES = [24, 128]
DTYPES = [torch.float16]
# "cuda:0" only when exactly one GPU is visible; otherwise "cuda:0" and "cuda:1".
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
# 0 disables the sliding window in the reference mask.
SLIDING_WINDOW = [0, 16, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]

# Attention kernels under test; both are called with the same argument list.
OPS = [chunked_prefill_paged_decode, context_attention_fwd]
def create_causal_attention_mask_for_sdpa(
    query_lens: list[int],
    seq_lens: list[int],
    sliding_window: int = 0,
    device: torch.device = None,
    dtype: torch.dtype = None,
) -> torch.Tensor:
    total_queries = sum(query_lens)
    total_keys = sum(seq_lens)

    # Create a mask filled with -inf
    mask = torch.full(
        (total_queries, total_keys), float("-inf"), device=device, dtype=dtype
    )

    query_start = 0
    key_start = 0

    for query_len, seq_len in zip(query_lens, seq_lens):
        query_end = query_start + query_len
        key_end = key_start + seq_len
        q_indices = torch.arange(query_len, device=device)
        k_indices = torch.arange(seq_len, device=device)
        q_pos_in_seq = seq_len - query_len + q_indices

        valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None]

        if sliding_window > 0:
            valid_mask &= k_indices[None, :] >= (
                q_pos_in_seq[:, None] - sliding_window + 1
            )

        mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0

        query_start = query_end
        key_start = key_end

    return mask


def create_alibi_causal_mask(
    query_len: int,
    seq_len: int,
    alibi_slopes: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Return an additive ALiBi + causal mask for one sequence, shaped for SDPA.

    The queries are the last ``query_len`` of ``seq_len`` positions. For head
    ``h`` the entry at ``[0, h, i, j]`` is ``alibi_slopes[h] * (j - q_i)``
    (``q_i`` being query ``i``'s position), cast to ``dtype``. Keys after the
    query's position are set to ``-inf``.

    Returns:
        Tensor of shape ``[1, num_heads, query_len, seq_len]``; the leading
        batch dimension is what SDPA expects even for a single sequence.
    """
    positions = dict(
        query=torch.arange(
            seq_len - query_len, seq_len, device=device, dtype=torch.float32
        ),
        key=torch.arange(seq_len, device=device, dtype=torch.float32),
    )
    q_col = positions["query"].unsqueeze(1)
    k_row = positions["key"].unsqueeze(0)

    # Signed distance key - query: <= 0 for visible keys, > 0 for future ones.
    distance = k_row - q_col
    is_future = k_row > q_col

    # Slopes broadcast over (query, key): [num_heads, query_len, seq_len].
    bias = (alibi_slopes[:, None, None] * distance.unsqueeze(0)).to(dtype)
    bias = bias.masked_fill(is_future.unsqueeze(0), float("-inf"))
    return bias[None]


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Compare a paged prefix-prefill attention op against PyTorch SDPA.

    Builds a random batch in which every sequence has a cached context
    (written into a paged KV cache through a shuffled block table) followed by
    new query tokens, the last sequence being a single-token decode. It runs
    ``op`` on that batch and checks the result against
    ``F.scaled_dot_product_attention`` over the full, un-paged keys/values
    with a causal (optionally sliding-window) mask.
    """
    if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    if (
        current_platform.is_rocm()
        and op is chunked_prefill_paged_decode
        and kv_cache_dtype == "fp8_e5m2"
    ):
        pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")

    current_platform.seed_everything(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    # ensure one sequence in batch is a decode
    query_lens[-1] = 1

    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv
    total_seq_len = sum(seq_lens)

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    # Full (context + query) keys/values of every sequence, packed back to back.
    kv = torch.empty(total_seq_len, 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(
        cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
    )
    v_cache = torch.zeros(
        cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
    )
    # k/v hold only the new (query-position) tokens, packed like `query`.
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    # Shuffle physical block ids so the paged lookup is non-contiguous.
    values = torch.arange(0, cache_size, dtype=torch.int32)
    values = values[torch.randperm(cache_size)]
    block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
    max_input_len = MAX_SEQ_LEN

    # Scatter each sequence: query tokens into k/v, context tokens into the
    # paged cache. Offsets are plain Python ints so that slicing does not
    # force a device->host sync for every single token.
    k_cache_flat = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_flat = v_cache.view(-1, num_kv_heads, head_size)
    query_start = 0
    seq_start = 0
    for query_len, ctx_len, seq_len, blocks in zip(
        query_lens, ctx_lens, seq_lens, block_table.tolist()
    ):
        new_tokens = slice(seq_start + ctx_len, seq_start + seq_len)
        k[query_start : query_start + query_len].copy_(key[new_tokens])
        v[query_start : query_start + query_len].copy_(value[new_tokens])

        for block_id, offset in enumerate(range(0, ctx_len, block_size)):
            # The last context block may be only partially filled.
            chunk = min(block_size, ctx_len - offset)
            src = seq_start + offset
            dst = blocks[block_id] * block_size
            k_cache_flat[dst : dst + chunk].copy_(key[src : src + chunk])
            v_cache_flat[dst : dst + chunk].copy_(value[src : src + chunk])

        query_start += query_len
        seq_start += seq_len

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = (
        k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
        .permute(0, 2, 3, 1, 4)
        .contiguous()
    )
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = (
        v_cache.view(-1, block_size, num_kv_heads, head_size)
        .permute(0, 2, 3, 1)
        .contiguous()
    )
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    def run_op() -> None:
        # Writes the kernel result into `output`.
        op(
            query,
            k,
            v,
            output,
            kv_cache_dtype,
            k_cache,
            v_cache,
            block_table,
            b_start_loc,
            b_seq_len,
            MAX_CTX_LEN,
            max_input_len,
            k_scale,
            v_scale,
            sliding_window=sliding_window,
        )

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    run_op()
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    run_op()
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")

    scale = float(1.0 / (head_size**0.5))

    # Reshape for SDPA: (seq_len, num_heads, head_size) ->
    # (1, num_heads, seq_len, head_size)
    query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
    query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape(
        1, num_heads, num_tokens, head_size
    )

    # Expand key and value for GQA/MQA to match query heads
    key_sdpa = key[:, :, None, :].expand(
        key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
    )
    key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape(
        1, num_heads, total_seq_len, head_size
    )
    value_sdpa = value[:, :, None, :].expand(
        value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
    )
    value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape(
        1, num_heads, total_seq_len, head_size
    )

    attn_mask = create_causal_attention_mask_for_sdpa(
        query_lens, seq_lens, sliding_window, device=device, dtype=dtype
    )

    def run_sdpa() -> torch.Tensor:
        return F.scaled_dot_product_attention(
            query_sdpa,
            key_sdpa,
            value_sdpa,
            attn_mask=attn_mask,
            dropout_p=0.0,
            scale=scale,
        )

    # Warm-up call, so the timed call below excludes one-time setup.
    run_sdpa()
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    output_ref = run_sdpa()
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")

    # Reshape output back to (num_tokens, num_heads, head_size)
    output_ref = output_ref.view(num_heads, num_tokens, head_size)
    output_ref = output_ref.permute(1, 0, 2).contiguous()
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Compare a paged prefix-prefill attention op with ALiBi against SDPA.

    Same batch/paged-cache setup as the non-ALiBi test, but ``op`` is called
    with per-head ALiBi slopes. The reference runs SDPA one sequence at a time
    with an additive ALiBi + causal mask.
    """
    if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    if (
        current_platform.is_rocm()
        and op is chunked_prefill_paged_decode
        and kv_cache_dtype == "fp8_e5m2"
    ):
        pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")

    current_platform.seed_everything(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
        # Fork from: vllm/vllm/model_executor/models/bloom.py#L44
        closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
        base = torch.tensor(
            2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
        slopes = torch.pow(base, powers)

        # Heads beyond the largest power of two get interleaved (odd-power)
        # slopes from the next power of two.
        if closest_power_of_2 != total_num_heads:
            extra_base = torch.tensor(
                2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
                dtype=torch.float32,
            )
            num_remaining_heads = min(
                closest_power_of_2, total_num_heads - closest_power_of_2
            )
            extra_powers = torch.arange(
                start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
            )
            slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
        return slopes

    alibi_slopes = _get_alibi_slopes(num_heads).to(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    # Full (context + query) keys/values of every sequence, packed back to back.
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(
        cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
    )
    v_cache = torch.zeros(
        cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
    )
    # k/v hold only the new (query-position) tokens, packed like `query`.
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    # Shuffle physical block ids so the paged lookup is non-contiguous.
    values = torch.arange(0, cache_size, dtype=torch.int32)
    values = values[torch.randperm(cache_size)]
    block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
    max_input_len = MAX_SEQ_LEN

    # Scatter each sequence: query tokens into k/v, context tokens into the
    # paged cache. Offsets are plain Python ints so that slicing does not
    # force a device->host sync for every single token.
    k_cache_flat = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_flat = v_cache.view(-1, num_kv_heads, head_size)
    query_start = 0
    seq_start = 0
    for query_len, ctx_len, seq_len, blocks in zip(
        query_lens, ctx_lens, seq_lens, block_table.tolist()
    ):
        new_tokens = slice(seq_start + ctx_len, seq_start + seq_len)
        k[query_start : query_start + query_len].copy_(key[new_tokens])
        v[query_start : query_start + query_len].copy_(value[new_tokens])

        for block_id, offset in enumerate(range(0, ctx_len, block_size)):
            # The last context block may be only partially filled.
            chunk = min(block_size, ctx_len - offset)
            src = seq_start + offset
            dst = blocks[block_id] * block_size
            k_cache_flat[dst : dst + chunk].copy_(key[src : src + chunk])
            v_cache_flat[dst : dst + chunk].copy_(value[src : src + chunk])

        query_start += query_len
        seq_start += seq_len

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = (
        k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
        .permute(0, 2, 3, 1, 4)
        .contiguous()
    )
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = (
        v_cache.view(-1, block_size, num_kv_heads, head_size)
        .permute(0, 2, 3, 1)
        .contiguous()
    )
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    def run_op() -> None:
        # Writes the kernel result into `output`.
        op(
            query,
            k,
            v,
            output,
            kv_cache_dtype,
            k_cache,
            v_cache,
            block_table,
            b_start_loc,
            b_seq_len,
            MAX_CTX_LEN,
            max_input_len,
            k_scale,
            v_scale,
            alibi_slopes=alibi_slopes,
        )

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    run_op()
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    run_op()
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")

    scale = float(1.0 / (head_size**0.5))

    # Expand key and value for GQA/MQA to match query heads:
    # [total_seq_len, num_kv_heads, num_queries_per_kv, head_size]
    key_expanded = key[:, :, None, :].expand(
        key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
    )
    value_expanded = value[:, :, None, :].expand(
        value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
    )

    output_ref = torch.empty_like(output)

    torch.cuda.synchronize()
    start_time = time.perf_counter()

    query_start = 0
    key_start = 0
    for query_len, seq_len in zip(query_lens, seq_lens):
        query_end = query_start + query_len
        key_end = key_start + seq_len

        # Per-sequence slices. Named *_seq so the packed `k`/`v` inputs that
        # were handed to `op` above are not shadowed.
        q_seq = query[query_start:query_end]  # [query_len, num_heads, head_size]
        k_seq = key_expanded[key_start:key_end]
        v_seq = value_expanded[key_start:key_end]

        # Reshape for SDPA: (batch=1, num_heads, seq_len, head_size)
        q_sdpa = (
            q_seq.view(query_len, num_kv_heads, num_queries_per_kv, head_size)
            .permute(1, 2, 0, 3)
            .reshape(1, num_heads, query_len, head_size)
            .contiguous()
        )
        k_sdpa = (
            k_seq.permute(1, 2, 0, 3)
            .reshape(1, num_heads, seq_len, head_size)
            .contiguous()
        )
        v_sdpa = (
            v_seq.permute(1, 2, 0, 3)
            .reshape(1, num_heads, seq_len, head_size)
            .contiguous()
        )

        # Create ALiBi causal mask for this sequence using utility function
        alibi_mask = create_alibi_causal_mask(
            query_len, seq_len, alibi_slopes, device, dtype
        )

        out = F.scaled_dot_product_attention(
            q_sdpa,
            k_sdpa,
            v_sdpa,
            attn_mask=alibi_mask,
            dropout_p=0.0,
            scale=scale,
        )

        # Reshape output back to [query_len, num_heads, head_size]
        out = out.view(num_heads, query_len, head_size).permute(1, 0, 2)
        output_ref[query_start:query_end].copy_(out)

        query_start = query_end
        key_start = key_end

    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


# These tests are optional to only run when explicitly invoked
#
# pytest -v -s --optional \
# tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
#
# These tests are useful to test model dtype float32 on Turing devices.
# We skip them to not increase the time when running tests on CI
@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Run ``test_contexted_kv_attention`` with a float32 model dtype."""
    # Keyword arguments: the callee's positional order (sliding_window before
    # dtype) differs from the decorator order, so avoid relying on it.
    test_contexted_kv_attention(
        num_heads=num_heads,
        num_queries_per_kv=num_queries_per_kv,
        head_size=head_size,
        sliding_window=sliding_window,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
        op=op,
    )


@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Run ``test_contexted_kv_attention_alibi`` with a float32 model dtype."""
    # Keyword arguments keep the forwarding correct even if the callee's
    # parameter order changes.
    test_contexted_kv_attention_alibi(
        num_heads=num_heads,
        num_queries_per_kv=num_queries_per_kv,
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
        op=op,
    )