# test_prefix_prefill.py (22.4 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
6
import random
import time
7
from collections.abc import Callable
8

9
import pytest
10
import torch
11
import torch.nn.functional as F
12

13
from vllm.platforms import current_platform
14
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
15
16
17
18
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
    chunked_prefill_paged_decode,
)
from vllm.v1.attention.ops.prefix_prefill import context_attention_fwd
19

if not current_platform.is_rocm():
    # NOTE(review): xops, BlockDiagonalCausalFromBottomRightMask and
    # _make_alibi_bias appear unused in this module (the references below are
    # computed with PyTorch SDPA), yet these imports still make xformers and
    # the vllm xformers backend module import-time requirements on non-ROCm
    # platforms. Consider removing them.
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
    from vllm.attention.backends.xformers import _make_alibi_bias

NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 64]
HEAD_SIZES = [24, 128]
DTYPES = [torch.float16]
# A single device when exactly one GPU is visible, otherwise cuda:0 and cuda:1.
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
# 0 disables the sliding window (see create_causal_attention_mask_for_sdpa).
SLIDING_WINDOW = [0, 16, 2048]
# Quantized (fp8) KV caches are only exercised on non-ROCm platforms.
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] if not current_platform.is_rocm() else ["auto"]

# Kernels under test; every test calls them with the same argument list.
OPS = [chunked_prefill_paged_decode, context_attention_fwd]
def create_causal_attention_mask_for_sdpa(
    query_lens: list[int],
    seq_lens: list[int],
    sliding_window: int = 0,
    device: torch.device = None,
    dtype: torch.dtype = None,
) -> torch.Tensor:
    """Build an additive block-diagonal causal mask for packed sequences.

    Queries of all sequences are concatenated along the rows and keys along
    the columns. Entry ``[q, k]`` is ``0.0`` when query ``q`` may attend to
    key ``k`` and ``-inf`` otherwise. Within one sequence the queries are
    the last ``query_len`` of its ``seq_len`` positions, so query ``j`` sits
    at position ``seq_len - query_len + j`` and may see every key at or
    before that position. With ``sliding_window > 0`` it only sees the most
    recent ``sliding_window`` keys (itself included). Keys of other
    sequences are always masked.

    Returns:
        Tensor of shape ``(sum(query_lens), sum(seq_lens))``.
    """
    mask = torch.full(
        (sum(query_lens), sum(seq_lens)), float("-inf"), device=device, dtype=dtype
    )

    row, col = 0, 0
    for n_q, n_k in zip(query_lens, seq_lens):
        # Position of each query row inside its own sequence.
        q_pos = torch.arange(n_q, device=device) + (n_k - n_q)
        k_pos = torch.arange(n_k, device=device)
        # How many positions each key lies behind the query (negative = future).
        lag = q_pos[:, None] - k_pos[None, :]
        allowed = lag >= 0
        if sliding_window > 0:
            allowed = allowed & (lag < sliding_window)
        mask[row : row + n_q, col : col + n_k].masked_fill_(allowed, 0.0)
        row += n_q
        col += n_k

    return mask


def create_alibi_causal_mask(
    query_len: int,
    seq_len: int,
    alibi_slopes: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Build the ALiBi bias combined with a causal mask for one sequence.

    The ``query_len`` queries are the last positions of a ``seq_len``-long
    sequence. For head ``h``, query ``i`` and key ``j`` the entry is
    ``alibi_slopes[h] * (key_pos[j] - query_pos[i])`` cast to ``dtype``, or
    ``-inf`` wherever the key lies after the query.

    Returns:
        Tensor of shape ``(1, alibi_slopes.shape[0], query_len, seq_len)``;
        the leading singleton batch dimension is what SDPA expects for a
        single sequence.
    """
    q_positions = torch.arange(
        seq_len - query_len, seq_len, device=device, dtype=torch.float32
    )
    k_positions = torch.arange(seq_len, device=device, dtype=torch.float32)

    # [query_len, seq_len]: signed distance from each query to each key.
    distance = k_positions.unsqueeze(0) - q_positions.unsqueeze(1)

    # [num_heads, query_len, seq_len]
    bias = (alibi_slopes[:, None, None] * distance.unsqueeze(0)).to(dtype)

    # Keys strictly after the query are not attendable.
    is_future = k_positions.unsqueeze(0) > q_positions.unsqueeze(1)
    bias = bias.masked_fill(is_future.unsqueeze(0), float("-inf"))

    return bias[None]


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
    block_size: int = 32,
) -> None:
    """Check ``op`` against a PyTorch SDPA reference on a mixed batch.

    Each of the ``BS`` sequences has a random context, written into a paged
    KV cache, followed by random new query tokens. The last sequence is
    forced to a single-token decode. The kernel output must match
    ``F.scaled_dot_product_attention`` with an explicit causal (optionally
    sliding-window) mask.
    """
    if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    if (
        current_platform.is_rocm()
        and op is chunked_prefill_paged_decode
        and kv_cache_dtype == "fp8_e5m2"
    ):
        pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")

    set_random_seed(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    # ensure one sequence in batch is a decode
    query_lens[-1] = 1

    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    # Keys/values for every token (context followed by new tokens) of every
    # sequence, packed back to back.
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(
        cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
    )
    v_cache = torch.zeros(
        cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
    )
    # k/v hold only the new (non-cached) tokens, packed per sequence.
    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    values = torch.arange(0, cache_size, dtype=torch.int32)
    values = values[torch.randperm(cache_size)]
    block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
    max_input_len = MAX_SEQ_LEN

    # copy kv to cache
    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
        torch.int32
    )
    # Pull the offsets to the host once so the loop slices with plain ints
    # instead of issuing one tiny device copy per token.
    query_starts = b_start_loc.tolist()
    seq_starts = b_seq_start_loc.tolist()
    block_ids = block_table.tolist()
    k_cache_flat = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_flat = v_cache.view(-1, num_kv_heads, head_size)
    for i in range(BS):
        q_start, q_len = query_starts[i], query_lens[i]
        seq_start, ctx_len = seq_starts[i], ctx_lens[i]

        # The new tokens follow the context inside key/value.
        new_start = seq_start + ctx_len
        k[q_start : q_start + q_len].copy_(key[new_start : new_start + q_len])
        v[q_start : q_start + q_len].copy_(value[new_start : new_start + q_len])

        # Scatter the context into its cache blocks; the last block may be
        # only partially filled.
        for block_idx, offset in enumerate(range(0, ctx_len, block_size)):
            n = min(block_size, ctx_len - offset)
            slot = block_ids[i][block_idx] * block_size
            src = slice(seq_start + offset, seq_start + offset + n)
            k_cache_flat[slot : slot + n].copy_(key[src])
            v_cache_flat[slot : slot + n].copy_(value[src])

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = (
        k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
        .permute(0, 2, 3, 1, 4)
        .contiguous()
    )
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = (
        v_cache.view(-1, block_size, num_kv_heads, head_size)
        .permute(0, 2, 3, 1)
        .contiguous()
    )
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    def run_kernel() -> None:
        op(
            query,
            k,
            v,
            output,
            kv_cache_dtype,
            k_cache,
            v_cache,
            block_table,
            b_start_loc,
            b_seq_len,
            MAX_CTX_LEN,
            max_input_len,
            k_scale,
            v_scale,
            sliding_window=sliding_window,
        )

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    run_kernel()
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    run_kernel()
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")

    scale = float(1.0 / (head_size**0.5))
    total_kv_tokens = sum(seq_lens)

    # Reshape for SDPA: (num_tokens, num_heads, head_size) ->
    # (1, num_heads, num_tokens, head_size); query head h pairs with kv head
    # h // num_queries_per_kv.
    query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
    query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape(
        1, num_heads, num_tokens, head_size
    )

    def expand_kv_for_sdpa(x: torch.Tensor) -> torch.Tensor:
        # Repeat every kv head num_queries_per_kv times (GQA/MQA) and lay the
        # result out as (1, num_heads, total_kv_tokens, head_size).
        x = x[:, :, None, :].expand(
            x.shape[0], num_kv_heads, num_queries_per_kv, x.shape[-1]
        )
        return x.permute(1, 2, 0, 3).reshape(1, num_heads, total_kv_tokens, head_size)

    key_sdpa = expand_kv_for_sdpa(key)
    value_sdpa = expand_kv_for_sdpa(value)

    attn_mask = create_causal_attention_mask_for_sdpa(
        query_lens, seq_lens, sliding_window, device=device, dtype=dtype
    )

    def run_reference() -> torch.Tensor:
        return F.scaled_dot_product_attention(
            query_sdpa,
            key_sdpa,
            value_sdpa,
            attn_mask=attn_mask,
            dropout_p=0.0,
            scale=scale,
        )

    # Warm up once, then time a second run.
    run_reference()
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    output_ref = run_reference()
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")

    # Reshape output back to (num_tokens, num_heads, head_size)
    output_ref = output_ref.view(num_heads, num_tokens, head_size)
    output_ref = output_ref.permute(1, 0, 2).contiguous()
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
    block_size: int = 32,
) -> None:
    """Check ``op`` with ALiBi slopes against a per-sequence SDPA reference.

    Each of the ``BS`` sequences has a random context, written into a paged
    KV cache, followed by random new query tokens. The reference runs
    ``F.scaled_dot_product_attention`` once per sequence with an ALiBi bias
    plus causal mask from ``create_alibi_causal_mask``.
    """
    if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    if (
        current_platform.is_rocm()
        and op is chunked_prefill_paged_decode
        and kv_cache_dtype == "fp8_e5m2"
    ):
        pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")

    set_random_seed(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
        # Fork from: vllm/vllm/model_executor/models/bloom.py#L44
        closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
        base = torch.tensor(
            2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
        slopes = torch.pow(base, powers)

        if closest_power_of_2 != total_num_heads:
            extra_base = torch.tensor(
                2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
                dtype=torch.float32,
            )
            num_remaining_heads = min(
                closest_power_of_2, total_num_heads - closest_power_of_2
            )
            extra_powers = torch.arange(
                start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
            )
            slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
        return slopes

    alibi_slopes = _get_alibi_slopes(num_heads).to(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    # Keys/values for every token (context followed by new tokens) of every
    # sequence, packed back to back.
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(
        cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
    )
    v_cache = torch.zeros(
        cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
    )
    # k/v hold only the new (non-cached) tokens, packed per sequence.
    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    values = torch.arange(0, cache_size, dtype=torch.int32)
    values = values[torch.randperm(cache_size)]
    block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
    max_input_len = MAX_SEQ_LEN

    # copy kv to cache
    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
        torch.int32
    )
    # Pull the offsets to the host once so the loop slices with plain ints
    # instead of issuing one tiny device copy per token.
    query_starts = b_start_loc.tolist()
    seq_starts = b_seq_start_loc.tolist()
    block_ids = block_table.tolist()
    k_cache_flat = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_flat = v_cache.view(-1, num_kv_heads, head_size)
    for i in range(BS):
        q_start, q_len = query_starts[i], query_lens[i]
        seq_start, ctx_len = seq_starts[i], ctx_lens[i]

        # The new tokens follow the context inside key/value.
        new_start = seq_start + ctx_len
        k[q_start : q_start + q_len].copy_(key[new_start : new_start + q_len])
        v[q_start : q_start + q_len].copy_(value[new_start : new_start + q_len])

        # Scatter the context into its cache blocks; the last block may be
        # only partially filled.
        for block_idx, offset in enumerate(range(0, ctx_len, block_size)):
            n = min(block_size, ctx_len - offset)
            slot = block_ids[i][block_idx] * block_size
            src = slice(seq_start + offset, seq_start + offset + n)
            k_cache_flat[slot : slot + n].copy_(key[src])
            v_cache_flat[slot : slot + n].copy_(value[src])

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = (
        k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
        .permute(0, 2, 3, 1, 4)
        .contiguous()
    )
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = (
        v_cache.view(-1, block_size, num_kv_heads, head_size)
        .permute(0, 2, 3, 1)
        .contiguous()
    )
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    def run_kernel() -> None:
        op(
            query,
            k,
            v,
            output,
            kv_cache_dtype,
            k_cache,
            v_cache,
            block_table,
            b_start_loc,
            b_seq_len,
            MAX_CTX_LEN,
            max_input_len,
            k_scale,
            v_scale,
            alibi_slopes=alibi_slopes,
        )

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    run_kernel()
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    run_kernel()
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")

    scale = float(1.0 / (head_size**0.5))

    # Expand key and value for GQA/MQA to match query heads:
    # [total_kv_tokens, num_kv_heads, num_queries_per_kv, head_size]
    key_expanded = key[:, :, None, :].expand(
        key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
    )
    value_expanded = value[:, :, None, :].expand(
        value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
    )

    output_ref = torch.empty_like(output)

    torch.cuda.synchronize()
    start_time = time.perf_counter()

    query_start = 0
    key_start = 0
    for query_len, seq_len in zip(query_lens, seq_lens):
        query_end = query_start + query_len
        key_end = key_start + seq_len

        # Reshape for SDPA: (batch=1, num_heads, len, head_size); query head h
        # pairs with kv head h // num_queries_per_kv.
        q_sdpa = (
            query[query_start:query_end]
            .view(query_len, num_kv_heads, num_queries_per_kv, head_size)
            .permute(1, 2, 0, 3)
            .reshape(1, num_heads, query_len, head_size)
            .contiguous()
        )
        # Named *_seq so the kernel inputs k/v are not shadowed.
        k_seq = (
            key_expanded[key_start:key_end]
            .permute(1, 2, 0, 3)
            .reshape(1, num_heads, seq_len, head_size)
            .contiguous()
        )
        v_seq = (
            value_expanded[key_start:key_end]
            .permute(1, 2, 0, 3)
            .reshape(1, num_heads, seq_len, head_size)
            .contiguous()
        )

        alibi_mask = create_alibi_causal_mask(
            query_len, seq_len, alibi_slopes, device, dtype
        )

        out = F.scaled_dot_product_attention(
            q_sdpa,
            k_seq,
            v_seq,
            attn_mask=alibi_mask,
            dropout_p=0.0,
            scale=scale,
        )

        # Reshape output back to [query_len, num_heads, head_size]
        out = out.view(num_heads, query_len, head_size).permute(1, 0, 2)
        output_ref[query_start:query_end].copy_(out)

        query_start = query_end
        key_start = key_end

    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


# These tests are optional and only run when explicitly invoked:
#
# pytest -v -s --optional \
# tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
#
# They are useful for testing the float32 model dtype on Turing devices.
# They are skipped by default so they do not increase CI test time.
@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Run ``test_contexted_kv_attention`` with a float32 model dtype.

    Arguments are forwarded by keyword so this wrapper does not depend on
    the wrapped test's parameter order.
    """
    test_contexted_kv_attention(
        num_heads=num_heads,
        num_queries_per_kv=num_queries_per_kv,
        head_size=head_size,
        sliding_window=sliding_window,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
        op=op,
    )


@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Run ``test_contexted_kv_attention_alibi`` with a float32 model dtype.

    Arguments are forwarded by keyword so this wrapper does not depend on
    the wrapped test's parameter order.
    """
    test_contexted_kv_attention_alibi(
        num_heads=num_heads,
        num_queries_per_kv=num_queries_per_kv,
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
        op=op,
    )


@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_qwen3_nonstandard_block_size(
    head_size: int,
    dtype: torch.dtype,
    device: str,
    op: Callable,
) -> None:
    """Exercise the 544-token KV-cache block size used by Qwen3-Next-80B.

    Delegates to ``test_contexted_kv_attention`` with ``block_size=544``,
    64 query heads with one query per KV head, no sliding window and an
    unquantized ("auto") KV cache. The test runs on ROCm only.
    """
    if current_platform.is_rocm():
        qwen3_case = dict(
            num_heads=64,
            num_queries_per_kv=1,
            head_size=head_size,
            block_size=544,
            sliding_window=0,
            dtype=dtype,
            kv_cache_dtype="auto",
            device=device,
            op=op,
        )
        test_contexted_kv_attention(**qwen3_case)
    else:
        pytest.skip("544 block size optimization is only for ROCm.")