test_prefix_prefill.py 24.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import math
4
5
import random
import time
6
from collections.abc import Callable
7

8
import pytest
9
10
import torch

11
from vllm.attention.backends.xformers import _make_alibi_bias
12
13
from vllm.attention.ops.chunked_prefill_paged_decode import (
    chunked_prefill_paged_decode)
14
from vllm.attention.ops.prefix_prefill import context_attention_fwd
15
16
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
17

zhuwenwen's avatar
zhuwenwen committed
18
if not current_platform.is_rocm():
zhuwenwen's avatar
zhuwenwen committed
19
20
21
22
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
    from vllm.attention.backends.xformers import _make_alibi_bias

23
24
# Parameter grids shared by the tests in this module.
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
HEAD_SIZES = [128, 96, 24]
DTYPES = [torch.float16]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
# `current_platform` is used as an object with methods throughout this file
# (`is_rocm()`, `has_device_capability()`, `seed_everything()`), so it must be
# queried via `is_rocm()` rather than called directly. The fp8 cache variants
# are only exercised when not running on ROCm.
KV_CACHE_DTYPES = (["auto"] if current_platform.is_rocm() else
                   ["auto", "fp8", "fp8_e5m2"])

# Kernels under test; each is invoked with the same argument list.
OPS = [chunked_prefill_paged_decode, context_attention_fwd]

35
36

@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Check a prefix-prefill attention kernel against an xformers reference.

    A batch of ``BS`` sequences is generated; each sequence consists of a
    context part, which is written into a paged KV cache, and a query part,
    which is passed to ``op`` as fresh ``k``/``v`` tensors. ``op`` is run
    twice (warm-up, then timed). When not on ROCm, the output is compared
    with ``xops.memory_efficient_attention_forward`` under a block-diagonal
    causal-from-bottom-right mask. On ROCm no reference is computed, so the
    test only checks that ``op`` runs.

    Args:
        num_heads: Number of query heads.
        num_queries_per_kv: Number of query heads per KV head.
        head_size: Size of each attention head.
        sliding_window: Forwarded to ``op``; the reference mask is made local
            only when this is greater than zero.
        dtype: dtype of the query/key/value/output tensors.
        kv_cache_dtype: ``"auto"`` stores the cache in ``dtype``; any other
            value is looked up in ``STR_DTYPE_TO_TORCH_DTYPE``.
        device: Device string used as default device for all tensors.
        op: Kernel under test.
    """
    if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
            89):
        pytest.skip(
            'Triton limitation: fp8e4nv data type is not supported on CUDA'
            ' arch < 89')

    current_platform.seed_everything(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    # ensure one sequence in batch is a decode
    query_lens[-1] = 1

    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    # `key`/`value` hold context + query tokens of every sequence back to back.
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    # Random, non-overlapping assignment of cache blocks to sequences.
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                            dtype=torch.long),
                               dim=0)
    max_input_len = MAX_SEQ_LEN

    # copy kv to cache
    #
    # Offsets are tracked as Python ints and whole slices are copied at once,
    # instead of indexing token by token with device tensors.
    k_cache_flat = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_flat = v_cache.view(-1, num_kv_heads, head_size)
    block_ids = block_table.tolist()
    query_offset = 0  # start of sequence i inside k / v / query
    seq_offset = 0  # start of sequence i inside key / value
    for i, (query_len, ctx_len) in enumerate(zip(query_lens, ctx_lens)):
        # The query part of the sequence follows its context part.
        new_start = seq_offset + ctx_len
        k[query_offset:query_offset + query_len].copy_(
            key[new_start:new_start + query_len])
        v[query_offset:query_offset + query_len].copy_(
            value[new_start:new_start + query_len])

        # The context part goes into the paged cache, one block at a time;
        # the last block may be only partially filled.
        for block_id, ctx_pos in enumerate(range(0, ctx_len, block_size)):
            chunk = min(block_size, ctx_len - ctx_pos)
            src = seq_offset + ctx_pos
            slot = block_ids[i][block_id] * block_size
            k_cache_flat[slot:slot + chunk].copy_(key[src:src + chunk])
            v_cache_flat[slot:slot + chunk].copy_(value[src:src + chunk])

        query_offset += query_len
        seq_offset += query_len + ctx_len

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    op(query,
       k,
       v,
       output,
       kv_cache_dtype,
       k_cache,
       v_cache,
       block_table,
       b_start_loc,
       b_seq_len,
       MAX_CTX_LEN,
       max_input_len,
       k_scale,
       v_scale,
       sliding_window=sliding_window)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    op(query,
       k,
       v,
       output,
       kv_cache_dtype,
       k_cache,
       v_cache,
       block_table,
       b_start_loc,
       b_seq_len,
       MAX_CTX_LEN,
       max_input_len,
       k_scale,
       v_scale,
       sliding_window=sliding_window)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")

    if current_platform.is_rocm():
        # This file only imports xformers when not on ROCm, so the reference
        # comparison below is not available there.
        return

    scale = float(1.0 / (head_size**0.5))

    attn_op = xops.fmha.cutlass.FwOp()

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
                           query.shape[-1])
        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                        num_queries_per_kv, key.shape[-1])
        value = value[:, :,
                      None, :].expand(value.shape[0], num_kv_heads,
                                      num_queries_per_kv, value.shape[-1])
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        query_lens, seq_lens)
    if sliding_window > 0:
        attn_bias = attn_bias.make_local_attention_from_bottomright(
            sliding_window)
    # Warm-up run of the reference before timing it.
    output_ref = xops.memory_efficient_attention_forward(
        query,
        key,
        value,
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    output_ref = xops.memory_efficient_attention_forward(
        query,
        key,
        value,
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
    output_ref = output_ref.reshape(output.shape)
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
249
250
251
252
253
254


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Check a prefix-prefill attention kernel with ALiBi slopes.

    The setup mirrors the non-ALiBi test: each sequence has a context part
    written into a paged KV cache and a query part passed to ``op`` as fresh
    ``k``/``v``. ``op`` is run twice (warm-up, then timed) with
    ``alibi_slopes``. When not on ROCm, the result is compared with a
    per-sequence ``xops.memory_efficient_attention_forward`` call using the
    bias from ``_make_alibi_bias``. On ROCm no reference is computed, so the
    test only checks that ``op`` runs.

    Args:
        num_heads: Number of query heads; also the number of ALiBi slopes.
        num_queries_per_kv: Number of query heads per KV head.
        head_size: Size of each attention head.
        dtype: dtype of the query/key/value/output tensors.
        kv_cache_dtype: ``"auto"`` stores the cache in ``dtype``; any other
            value is looked up in ``STR_DTYPE_TO_TORCH_DTYPE``.
        device: Device string used as default device for all tensors.
        op: Kernel under test.
    """
    if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
            89):
        pytest.skip(
            'Triton limitation: fp8e4nv data type is not supported on CUDA'
            ' arch < 89')

    current_platform.seed_everything(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
        """Return one ALiBi slope per head as a float32 tensor."""
        # Fork from: vllm/vllm/model_executor/models/bloom.py#L44
        closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
        base = torch.tensor(
            2**(-(2**-(math.log2(closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
        slopes = torch.pow(base, powers)

        # Head counts that are not a power of two get extra slopes taken from
        # the odd powers of the next power of two's base.
        if closest_power_of_2 != total_num_heads:
            extra_base = torch.tensor(
                2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
                dtype=torch.float32,
            )
            num_remaining_heads = min(closest_power_of_2,
                                      total_num_heads - closest_power_of_2)
            extra_powers = torch.arange(start=1,
                                        end=1 + 2 * num_remaining_heads,
                                        step=2,
                                        dtype=torch.int32)
            slopes = torch.cat(
                [slopes, torch.pow(extra_base, extra_powers)], dim=0)
        return slopes

    alibi_slopes = _get_alibi_slopes(num_heads).to(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    # `key`/`value` hold context + query tokens of every sequence back to back.
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)
    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    # Random, non-overlapping assignment of cache blocks to sequences.
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                            dtype=torch.long),
                               dim=0)
    max_input_len = MAX_SEQ_LEN

    # copy kv to cache
    #
    # Offsets are tracked as Python ints and whole slices are copied at once,
    # instead of indexing token by token with device tensors.
    k_cache_flat = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_flat = v_cache.view(-1, num_kv_heads, head_size)
    block_ids = block_table.tolist()
    query_offset = 0  # start of sequence i inside k / v / query
    seq_offset = 0  # start of sequence i inside key / value
    for i, (query_len, ctx_len) in enumerate(zip(query_lens, ctx_lens)):
        # The query part of the sequence follows its context part.
        new_start = seq_offset + ctx_len
        k[query_offset:query_offset + query_len].copy_(
            key[new_start:new_start + query_len])
        v[query_offset:query_offset + query_len].copy_(
            value[new_start:new_start + query_len])

        # The context part goes into the paged cache, one block at a time;
        # the last block may be only partially filled.
        for block_id, ctx_pos in enumerate(range(0, ctx_len, block_size)):
            chunk = min(block_size, ctx_len - ctx_pos)
            src = seq_offset + ctx_pos
            slot = block_ids[i][block_id] * block_size
            k_cache_flat[slot:slot + chunk].copy_(key[src:src + chunk])
            v_cache_flat[slot:slot + chunk].copy_(value[src:src + chunk])

        query_offset += query_len
        seq_offset += query_len + ctx_len

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    op(query,
       k,
       v,
       output,
       kv_cache_dtype,
       k_cache,
       v_cache,
       block_table,
       b_start_loc,
       b_seq_len,
       MAX_CTX_LEN,
       max_input_len,
       k_scale,
       v_scale,
       alibi_slopes=alibi_slopes)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    op(query,
       k,
       v,
       output,
       kv_cache_dtype,
       k_cache,
       v_cache,
       block_table,
       b_start_loc,
       b_seq_len,
       MAX_CTX_LEN,
       max_input_len,
       k_scale,
       v_scale,
       alibi_slopes=alibi_slopes)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")

    if current_platform.is_rocm():
        # This file only imports xformers when not on ROCm, so the reference
        # comparison below is not available there.
        return

    scale = float(1.0 / (head_size**0.5))

    # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
    # we have to pad query tensor before MQA/GQA expanding.
    # Each sequence gets `seq_len` query rows: zeros for the context
    # positions followed by its real query rows.
    query_padded = torch.zeros(sum(seq_lens),
                               num_heads,
                               head_size,
                               dtype=dtype)
    seq_start = 0
    query_start = 0
    for query_len, seq_len in zip(query_lens, seq_lens):
        seq_end = seq_start + seq_len
        query_end = query_start + query_len
        query_padded[seq_end - query_len:seq_end].copy_(
            query[query_start:query_end])
        seq_start = seq_end
        query_start = query_end
    query = query_padded

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                        num_queries_per_kv, key.shape[-1])
        value = value[:, :,
                      None, :].expand(value.shape[0], num_kv_heads,
                                      num_queries_per_kv, value.shape[-1])
        # [seq, num_kv_heads, num_queries_per_kv, dk]=>
        # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
        # codebase. We save some time reshaping alibi matrix at runtime.
        key = key.reshape(key.shape[0], -1, key.shape[-1])
        value = value.reshape(value.shape[0], -1, value.shape[-1])
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
    output_ref = torch.empty_like(output)
    seq_start = 0
    query_start = 0
    start_time = time.perf_counter()
    # Attention with alibi slopes.
    # FIXME(DefTruth): Because xformers does not support dynamic sequence
    # lengths with custom attention bias, we process each prompt one by
    # one. This is inefficient, especially when we have many short prompts.
    # modified from: vllm/attention/backends/xformers.py#L343
    for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
        seq_end = seq_start + seq_len
        query_end = query_start + query_len
        out = xops.memory_efficient_attention_forward(
            query[:, seq_start:seq_end],
            key[:, seq_start:seq_end],
            value[:, seq_start:seq_end],
            attn_bias=attn_bias[i],
            p=0.0,
            scale=scale)
        out = out.view_as(query[:, seq_start:seq_end]).view(
            seq_len, num_heads, head_size)
        # Only the trailing `query_len` rows correspond to real queries; the
        # leading rows come from the zero padding added above.
        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
                                                         ...])
        seq_start += seq_len
        query_start += query_len
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577


# These tests are optional and only run when explicitly invoked, e.g.:
#
# pytest -v -s --optional \
# tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
#
# They are useful for testing model dtype float32 on Turing devices.
# We skip them by default so as not to increase the test time on CI.
@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Optional float32 variant; delegates to test_contexted_kv_attention."""
    test_contexted_kv_attention(
        num_heads=num_heads,
        num_queries_per_kv=num_queries_per_kv,
        head_size=head_size,
        sliding_window=sliding_window,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
        op=op,
    )
593
594
595
596
597
598
599
600
601


@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Optional float32 variant; delegates to the ALiBi test above it."""
    test_contexted_kv_attention_alibi(
        num_heads=num_heads,
        num_queries_per_kv=num_queries_per_kv,
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
        op=op,
    )