# test_prefix_prefill.py 19.6 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
6
import random
import time
7
from collections.abc import Callable
8

9
import pytest
10
11
12
13
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

14
from tests.kernels.utils import make_alibi_bias
15
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
16
from vllm.attention.ops.prefix_prefill import context_attention_fwd
17
18
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
19

20
# Parametrization grids shared by the tests below.
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 64]
# Head sizes must be divisible by 8: the key cache is reshaped to
# [..., head_size // 8, block_size, 8] before being passed to the kernels.
HEAD_SIZES = [24, 128]
DTYPES = [torch.float16]
# cuda:0 only when exactly one GPU is visible; otherwise cuda:0 and cuda:1.
# NOTE(review): with zero visible GPUs this still yields two devices, so the
# tests fail rather than skip on CPU-only machines.
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
# 0 disables the sliding window.
SLIDING_WINDOW = [0, 16, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]

# Kernel entry points under test; both share the same call signature here.
OPS = [chunked_prefill_paged_decode, context_attention_fwd]

@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Check a paged prefix-prefill attention ``op`` against xformers.

    Builds a batch of ``BS`` sequences. Each one is a cached context, scattered
    into a paged K/V cache addressed through ``block_table``, followed by new
    tokens; the last sequence is forced to a single-token decode. ``op`` is
    run over the new tokens and compared with
    ``xops.memory_efficient_attention_forward`` under a bottom-right causal
    block-diagonal mask, additionally made local when ``sliding_window > 0``.
    Each implementation is run once as warm-up and once timed (printed).
    """
    if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    current_platform.seed_everything(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    # ensure one sequence in batch is a decode
    query_lens[-1] = 1

    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    # Full K/V for every token of every sequence (context + new tokens).
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(
        cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
    )
    v_cache = torch.zeros(
        cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
    )
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    # Random, non-repeating physical block ids for each request.
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
    max_input_len = MAX_SEQ_LEN
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
    )

    # copy kv to cache: the new tokens of each sequence go into the contiguous
    # ``k``/``v`` inputs, the context tokens are scattered block by block into
    # the paged cache. Offsets are pulled to the host once (``tolist``) so the
    # loops do whole-slice copies instead of per-token device indexing.
    k_cache_flat = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_flat = v_cache.view(-1, num_kv_heads, head_size)
    for block_ids, query_start, seq_start, query_len, ctx_len in zip(
        block_table.tolist(),
        b_start_loc.tolist(),
        b_seq_start_loc.tolist(),
        query_lens,
        ctx_lens,
    ):
        new_start = seq_start + ctx_len
        dst_new = slice(query_start, query_start + query_len)
        src_new = slice(new_start, new_start + query_len)
        k[dst_new].copy_(key[src_new])
        v[dst_new].copy_(value[src_new])

        for block, cur_ctx in zip(block_ids, range(0, ctx_len, block_size)):
            # The last block of the context may be only partially filled.
            num_copy = min(block_size, ctx_len - cur_ctx)
            src = slice(seq_start + cur_ctx, seq_start + cur_ctx + num_copy)
            dst = slice(block * block_size, block * block_size + num_copy)
            k_cache_flat[dst].copy_(key[src])
            v_cache_flat[dst].copy_(value[src])

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = (
        k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
        .permute(0, 2, 3, 1, 4)
        .contiguous()
    )
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = (
        v_cache.view(-1, block_size, num_kv_heads, head_size)
        .permute(0, 2, 3, 1)
        .contiguous()
    )
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Positional arguments shared by the warm-up and the timed call.
    op_args = (
        query,
        k,
        v,
        output,
        kv_cache_dtype,
        k_cache,
        v_cache,
        block_table,
        b_start_loc,
        b_seq_len,
        MAX_CTX_LEN,
        max_input_len,
        k_scale,
        v_scale,
    )

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    op(*op_args, sliding_window=sliding_window)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    op(*op_args, sliding_window=sliding_window)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")

    scale = float(1.0 / (head_size**0.5))

    attn_op = xops.fmha.cutlass.FwOp()

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        query = query.view(
            query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1]
        )
        key = key[:, :, None, :].expand(
            key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
        )
        value = value[:, :, None, :].expand(
            value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
        )
    # Add the leading batch dimension of 1 expected by xformers.
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        query_lens, seq_lens
    )
    if sliding_window > 0:
        attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window)

    def run_reference() -> torch.Tensor:
        # Reference attention over the full (context + new) sequences.
        return xops.memory_efficient_attention_forward(
            query,
            key,
            value,
            attn_bias=attn_bias,
            p=0.0,
            scale=scale,
            op=attn_op,
        )

    # Warm up xformers as well so both timings exclude first-call overhead.
    run_reference()
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    output_ref = run_reference()
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
    output_ref = output_ref.reshape(output.shape)
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Check a paged prefix-prefill attention ``op`` with ALiBi slopes.

    Uses the same batch layout as ``test_contexted_kv_attention``: a paged
    context cache plus new tokens per sequence. ``op`` is called with
    ``alibi_slopes``. The reference runs xformers one sequence at a time with
    the bias produced by ``make_alibi_bias``, on a query padded to the full
    sequence length, and keeps only the rows for the new tokens.
    """
    if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    current_platform.seed_everything(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
        # Fork from: vllm/vllm/model_executor/models/bloom.py#L44
        closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
        base = torch.tensor(
            2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
        slopes = torch.pow(base, powers)

        if closest_power_of_2 != total_num_heads:
            extra_base = torch.tensor(
                2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
                dtype=torch.float32,
            )
            num_remaining_heads = min(
                closest_power_of_2, total_num_heads - closest_power_of_2
            )
            extra_powers = torch.arange(
                start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
            )
            slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
        return slopes

    alibi_slopes = _get_alibi_slopes(num_heads).to(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    # Full K/V for every token of every sequence (context + new tokens).
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(
        cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
    )
    v_cache = torch.zeros(
        cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
    )
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    # Random, non-repeating physical block ids for each request.
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
    max_input_len = MAX_SEQ_LEN
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
    )

    # copy kv to cache: the new tokens of each sequence go into the contiguous
    # ``k``/``v`` inputs, the context tokens are scattered block by block into
    # the paged cache. Offsets are pulled to the host once (``tolist``) so the
    # loops do whole-slice copies instead of per-token device indexing.
    k_cache_flat = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_flat = v_cache.view(-1, num_kv_heads, head_size)
    for block_ids, query_start, seq_start, query_len, ctx_len in zip(
        block_table.tolist(),
        b_start_loc.tolist(),
        b_seq_start_loc.tolist(),
        query_lens,
        ctx_lens,
    ):
        new_start = seq_start + ctx_len
        dst_new = slice(query_start, query_start + query_len)
        src_new = slice(new_start, new_start + query_len)
        k[dst_new].copy_(key[src_new])
        v[dst_new].copy_(value[src_new])

        for block, cur_ctx in zip(block_ids, range(0, ctx_len, block_size)):
            # The last block of the context may be only partially filled.
            num_copy = min(block_size, ctx_len - cur_ctx)
            src = slice(seq_start + cur_ctx, seq_start + cur_ctx + num_copy)
            dst = slice(block * block_size, block * block_size + num_copy)
            k_cache_flat[dst].copy_(key[src])
            v_cache_flat[dst].copy_(value[src])

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = (
        k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
        .permute(0, 2, 3, 1, 4)
        .contiguous()
    )
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = (
        v_cache.view(-1, block_size, num_kv_heads, head_size)
        .permute(0, 2, 3, 1)
        .contiguous()
    )
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Positional arguments shared by the warm-up and the timed call.
    op_args = (
        query,
        k,
        v,
        output,
        kv_cache_dtype,
        k_cache,
        v_cache,
        block_table,
        b_start_loc,
        b_seq_len,
        MAX_CTX_LEN,
        max_input_len,
        k_scale,
        v_scale,
    )

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    op(*op_args, alibi_slopes=alibi_slopes)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    op(*op_args, alibi_slopes=alibi_slopes)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
    scale = float(1.0 / (head_size**0.5))

    # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
    # we have to pad query tensor before MQA/GQA expanding.
    if query.shape[0] != key.shape[0]:
        query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype)
        query_pad.uniform_(-1e-3, 1e-3)
        seq_start = 0
        query_start = 0
        for query_len, seq_len in zip(query_lens, seq_lens):
            seq_end = seq_start + seq_len
            query_end = query_start + query_len
            # Zeros for the context positions, then the real new-token queries.
            query_pad[seq_start:seq_end, ...] = torch.cat(
                [
                    torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype),
                    query[query_start:query_end, ...],
                ],
                dim=0,
            )
            seq_start += seq_len
            query_start += query_len
        query = query_pad

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        key = key[:, :, None, :].expand(
            key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
        )
        value = value[:, :, None, :].expand(
            value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
        )
        # [seq, num_kv_heads, num_queries_per_kv, dk]=>
        # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
        # codebase. We save some time reshaping alibi matrix at runtime.
        key = key.reshape(key.shape[0], -1, key.shape[-1])
        value = value.reshape(value.shape[0], -1, value.shape[-1])
    # Add the leading batch dimension of 1 expected by xformers.
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
    output_ref = torch.empty_like(output)
    seq_start = 0
    query_start = 0
    # Drain queued GPU work (padding, bias construction) so it is not
    # attributed to the xformers timing below.
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    # Attention with alibi slopes.
    # FIXME(DefTruth): Because xformers does not support dynamic sequence
    # lengths with custom attention bias, we process each prompt one by
    # one. This is inefficient, especially when we have many short prompts.
    # modified from: vllm/v1/attention/backends/xformers.py#L343
    for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
        seq_end = seq_start + seq_len
        query_end = query_start + query_len
        out = xops.memory_efficient_attention_forward(
            query[:, seq_start:seq_end],
            key[:, seq_start:seq_end],
            value[:, seq_start:seq_end],
            attn_bias=attn_bias[i],
            p=0.0,
            scale=scale,
        )
        out = out.view_as(query[:, seq_start:seq_end]).view(
            seq_len, num_heads, head_size
        )
        # Keep only the rows that correspond to the new (non-context) tokens.
        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...])
        seq_start += seq_len
        query_start += query_len
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


# These tests are optional to only run when explicitly invoked
#
# pytest -v -s --optional \
# tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
#
# These tests are useful to test model dtype float32 on Turing devices.
# We skip them to not increase the time when running tests on CI
@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Run ``test_contexted_kv_attention`` with a float32 model dtype."""
    test_contexted_kv_attention(
        num_heads,
        num_queries_per_kv,
        head_size,
        sliding_window,
        dtype,
        kv_cache_dtype,
        device,
        op,
    )


@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Run ``test_contexted_kv_attention_alibi`` with a float32 model dtype.

    Marked ``optional``, so it only runs when explicitly requested.
    """
    test_contexted_kv_attention_alibi(
        num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
    )