# test_attention.py (19 KB) -- header residue from a repository blame view
# ("Newer" / "Older" navigation links).
1
import random
2
from typing import List, Optional, Tuple
3

4
import pytest
5
6
import torch

7
from tests.kernels.utils import opcheck
8
from vllm import _custom_ops as ops
9
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
10

11
12
from .allclose_default import get_default_atol, get_default_rtol

13
14
15
16
if not is_hip():
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

17
18
19
20
# Size in bytes of one float32 element.
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# The longest sequence is the number of float32 values that fit in shared
# memory, minus 512 kept back as a buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough GPU memory due to the large NUM_BLOCKS.
# Reduce NUM_BLOCKS when that happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
          ] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
# Each entry is (number of query heads, number of KV heads).
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
              ] if not is_hip() else [64, 80, 96, 112, 128]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
# Only "cuda:0" when exactly one device is reported, otherwise "cuda:0" and
# "cuda:1".
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
44

45
46
47
48
49
50
51
52

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
53
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
54
    if attn_mask is not None:
55
56
57
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
58
59
60
61
62
63
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference single-query attention over a paged (blocked) KV cache.

    For each sequence, the keys and values of all its tokens are gathered
    one token at a time from the cache blocks named in its block table,
    attention is computed with ``ref_masked_attention``, and the result is
    written into ``output`` in place.

    Args:
        output: Destination tensor; row ``i`` receives the result for
            sequence ``i`` as (num query heads, head size).
        query: One query per sequence, indexed as
            (sequence, query head, head size).
        num_queries_per_kv: Number of query heads sharing one KV head. When
            greater than 1 the KV heads are repeated to match.
        key_cache: Indexed as (block, kv head, head chunk, block offset,
            chunk element); the last two head axes reshape to ``head_size``.
        value_cache: Indexed as (block, kv head, head size, block offset).
        block_tables: Per-sequence list of cache block numbers.
        seq_lens: Per-sequence number of cached tokens.
        scale: Multiplier applied to the attention scores.
        alibi_slopes: Optional per-query-head slopes; when given, a
            position-dependent bias is added to the scores.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    # Move the index data to host lists once, so the per-token loop below
    # does plain Python indexing.
    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables_lst[i]
        seq_len = int(seq_lens_lst[i])

        keys_lst: List[torch.Tensor] = []
        values_lst: List[torch.Tensor] = []
        for j in range(seq_len):
            # Token j lives in logical block j // block_size at slot
            # j % block_size; the block table maps that to a cache block.
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys_lst.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values_lst.append(v)
        keys = torch.stack(keys_lst, dim=0)
        values = torch.stack(values_lst, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            # The bias is slope * (position - (seq_len - 1)): zero at the
            # last token and increasingly negative further back.
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


117
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Compare the paged attention kernels (v1 and v2) with the reference.

    Random queries, block tables and KV caches are generated, the selected
    kernel is run, and its output is compared against
    ``ref_single_query_cached_kv_attention``.

    Args:
        kv_cache_factory: Callable returning ``(key_caches, value_caches)``;
            presumably a pytest fixture defined outside this file.
        version: Which kernel to run, "v1" or "v2".
        num_seqs: Number of sequences (one query each).
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Size of each attention head.
        use_alibi: Whether to pass random ALiBi slopes to the kernel.
        block_size: Number of token slots per cache block.
        dtype: dtype of the query and output tensors.
        kv_cache_dtype: KV cache storage type, "auto" or "fp8".
        seed: Seed passed to ``seed_everything``.
        device: Device made the torch default for the test.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()

    seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # Random lengths, with the last sequence pinned to the maximum so the
    # longest supported length is always exercised.
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    # NOTE(review): the literal 1 is presumably the number of layers, since
    # only element [0] of each returned list is used -- confirm.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        # NOTE(review): the trailing literals (0, 0, 0, 64, 0) are extra op
        # arguments not passed in the ops.* call above; confirm their
        # meaning against the op schema.
        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]))

    elif version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        # `output` has the same shape as `query`, so the scratch buffers are
        # sized from the values already in scope instead of re-unpacking
        # output.shape (which used to rebind the `num_heads` tuple).
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        # NOTE(review): see the v1 branch regarding the trailing literals.
        opcheck(torch.ops._C.paged_attention_v2,
                (output, exp_sums, max_logits, tmp_output, query, key_cache,
                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                 k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        # x is the number of `dtype` elements that fit in 16 bytes; it is the
        # size of the innermost key-cache axis.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    # NOTE(review): platform-default tolerances used to be computed here for
    # HIP but were unconditionally overwritten before use; that dead
    # computation is removed and the effective tolerances are unchanged.
    atol, rtol = 1e-3, 1e-5
    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
296
297


298
299
300
301
302
303
304
305
306
def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal attention over several sequences packed together.

    Each sequence occupies rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of
    ``query``, ``key`` and ``value``. Attention is computed per sequence with
    a causal mask and the results are concatenated along the token axis.

    Args:
        cu_seq_lens: Cumulative sequence lengths, starting at 0; it has one
            more entry than there are sequences.
        query: Packed queries; the first axis is the token axis.
        key: Packed keys; the first axis is the token axis.
        value: Packed values; the first axis is the token axis.
        scale: Multiplier applied to the attention scores.
        dtype: dtype used to build the causal mask.

    Returns:
        The per-sequence attention outputs concatenated along dim 0.
    """
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: List[torch.Tensor] = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        # Entries strictly above the diagonal (future positions) get the most
        # negative finite value of `dtype`; all others stay 0.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)

    return torch.cat(ref_outputs, dim=0)
329
330


331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
@pytest.mark.parametrize("version", ["rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64, 128])  # only test 64 128
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(not is_hip(), reason="only for rocm")
def test_paged_attention_rocm(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Compare the ROCm paged attention kernel with the reference.

    Random queries, block tables and KV caches are generated,
    ``ops.paged_attention_rocm`` is run, and its output is compared against
    ``ref_single_query_cached_kv_attention``. Skipped unless ``is_hip()``.

    Args:
        kv_cache_factory: Callable returning ``(key_caches, value_caches)``;
            presumably a pytest fixture defined outside this file.
        version: Kernel selector; only "rocm" is accepted.
        num_seqs: Number of sequences (one query each).
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Size of each attention head.
        use_alibi: Whether to pass random ALiBi slopes to the kernel.
        block_size: Number of token slots per cache block.
        dtype: dtype of the query and output tensors.
        kv_cache_dtype: KV cache storage type; only "auto" is parametrized.
        seed: Seed passed to ``seed_everything``.
        device: Device made the torch default for the test.
    """
    seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # Random lengths, with the last sequence pinned to the maximum so the
    # longest supported length is always exercised.
    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    context_lens[-1] = MAX_SEQ_LEN
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int)

    # Create the KV caches.
    # NOTE(review): the literal 1 is presumably the number of layers, since
    # only element [0] of each returned list is used -- confirm.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # TODO(charlifu) enable fp8 kv cache
    # Using default kv_scale
    # kv_scale = 1.0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    PARTITION_SIZE_ROCM = 256
    num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
                      PARTITION_SIZE_ROCM)
    assert PARTITION_SIZE_ROCM % block_size == 0
    # `output` has the same shape as `query`, so the scratch buffers are
    # sized from the values already in scope instead of re-unpacking
    # output.shape (which used to rebind the `num_heads` tuple).
    tmp_output = torch.empty(
        size=(num_seqs, num_query_heads, num_partitions, head_size),
        dtype=output.dtype,
    )
    exp_sums = torch.empty(
        size=(num_seqs, num_query_heads, num_partitions),
        dtype=torch.float32,
    )
    max_logits = torch.empty_like(exp_sums)
    if version == "rocm":
        ops.paged_attention_rocm(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    # This branch is not reached with the current parametrization
    # (kv_cache_dtype is always "auto"); it is kept for when fp8 is enabled.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        # x is the number of `dtype` elements that fit in 16 bytes; it is the
        # size of the innermost key-cache axis.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        # Destination first, source second, matching the call in
        # test_paged_attention (the arguments were previously swapped here).
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    # NOTE(review): platform-default tolerances used to be computed here but
    # were unconditionally overwritten before use; that dead computation is
    # removed and the effective tolerances are unchanged.
    atol, rtol = 1e-4, 1e-5
    if dtype == torch.bfloat16:
        atol, rtol = 2e-4, 1e-5
    if use_alibi:
        if dtype == torch.half:
            atol, rtol = 5e-4, 1e-5
        if dtype == torch.bfloat16:
            atol, rtol = 1e-3, 1e-5
    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    # assert_close applies the same |actual - expected| <= atol + rtol *
    # |expected| check as allclose but reports the mismatch on failure, and
    # matches the other tests in this file.
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


487
# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(is_hip(), reason="skip for rocm")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare xformers block-diagonal causal attention with the reference.

    Several sequences of random length are packed into one token axis,
    attention is run through ``xops.memory_efficient_attention_forward`` with
    a ``BlockDiagonalCausalMask``, and the result is compared against
    ``ref_multi_query_kv_attention``. Skipped when ``is_hip()``.

    Args:
        num_seqs: Number of sequences packed together.
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Size of each attention head.
        dtype: dtype of the query/key/value tensors.
        seed: Seed passed to ``seed_everything``.
        device: Device made the torch default for the test.
    """
    seed_everything(seed)
    torch.set_default_device(device)
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    # Distinct lengths drawn from [1, max_len).
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    # Q, K and V share one buffer, split along the head axis below.
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    # A leading axis of size 1 is added for the call and removed afterwards.
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    # Cumulative sequence boundaries, starting at 0, for the reference.
    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)