# test_attention.py
1
import random
2
from typing import List, Optional, Tuple
3

4
import pytest
5
6
import torch

7
from tests.kernels.utils import opcheck
8
from vllm import _custom_ops as ops
9
from vllm.utils import get_max_shared_memory_bytes, is_hip
10

11
12
from .allclose_default import get_default_atol, get_default_rtol

13
14
15
16
if not is_hip():
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

17
18
19
20
# Size of one float32 value in bytes.
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
          ] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
              ] if not is_hip() else [64, 80, 96, 112, 128]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
# One device when exactly one GPU is reported, otherwise the first two.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
44

45
46
47
48
49
50
51
52

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute scaled dot-product attention with plain PyTorch ops.

    Args:
        query: Tensor indexed as (num_queries, num_heads, head_size).
        key: Tensor indexed as (num_keys, num_heads, head_size).
        value: Tensor indexed as (num_keys, num_heads, head_size).
        scale: Factor applied to the query/key dot products.
        attn_mask: Optional additive bias; it must broadcast against the
            (num_heads, num_queries, num_keys) score tensor.

    Returns:
        Tensor indexed as (num_queries, num_heads, head_size).
    """
    # Scores and softmax are evaluated in float32, then cast back to the
    # dtype of `value` before the weighted sum.
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference decode attention over a paged KV cache.

    For every sequence the cached keys/values are gathered through its block
    table, attention for the single query token is computed with
    ``ref_masked_attention`` and the result is copied into ``output[i]``.

    Args:
        output: Destination tensor, written in place; indexed as
            (num_seqs, num_query_heads, head_size).
        query: Tensor indexed as (num_seqs, num_query_heads, head_size).
        num_queries_per_kv: Number of query heads sharing one KV head; KV
            heads are repeated this many times when it is greater than 1.
        key_cache: Tensor indexed as
            (block, kv_head, head_size // x, block_offset, x).
        value_cache: Tensor indexed as
            (block, kv_head, head_size, block_offset).
        block_tables: Per-sequence list of cache block numbers.
        seq_lens: Number of cached tokens of each sequence.
        scale: Factor applied to the query/key dot products.
        alibi_slopes: Optional per-query-head ALiBi slopes.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        seq_len = int(seq_lens_lst[i])

        # Only the first ceil(seq_len / block_size) table entries hold tokens
        # of this sequence.
        num_used_blocks = (seq_len + block_size - 1) // block_size
        num_slots = num_used_blocks * block_size
        block_ids = torch.tensor(block_tables_lst[i][:num_used_blocks],
                                 dtype=torch.long,
                                 device=key_cache.device)

        # Gather whole blocks at once instead of indexing token by token.
        # Token j lives in table entry j // block_size at offset
        # j % block_size, so after moving the block-offset axis next to the
        # block axis and flattening both, row j is token j.
        keys = key_cache.index_select(0, block_ids)
        keys = keys.permute(0, 3, 1, 2, 4).reshape(num_slots, num_kv_heads,
                                                   head_size)[:seq_len]
        values = value_cache.index_select(0, block_ids)
        values = values.permute(0, 3, 1, 2).reshape(num_slots, num_kv_heads,
                                                    head_size)[:seq_len]

        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel:
            # slope * (position - (seq_len - 1)) for every cached position.
            # Built on the slopes' device so the product does not depend on
            # the global default device.
            relative_pos = torch.arange(1 - seq_len,
                                        1,
                                        device=alibi_slopes.device).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * relative_pos.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


117
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Check paged attention v1/v2 against the PyTorch reference.

    Random queries, sequence lengths, block tables and KV caches are
    generated from ``seed``; the kernel selected by ``version`` is run and
    its output is compared with ``ref_single_query_cached_kv_attention``.
    ``num_heads`` is a ``(num_query_heads, num_kv_heads)`` pair.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()

    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # The last sequence is pinned to MAX_SEQ_LEN so the longest supported
    # length is always exercised.
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    # NOTE(review): the literal 1 is presumably the number of layers - confirm
    # against the kv_cache_factory fixture.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        # opcheck only runs for the first head size in HEAD_SIZES.
        # NOTE(review): the trailing (0, 0, 0, 64, 0) are presumably the
        # op's remaining default arguments - confirm against the op schema.
        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]))

    elif version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        # `output` has the shape of `query`, so the scratch buffers are sized
        # from the values already known here; the `num_heads` tuple parameter
        # is deliberately not rebound.
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        opcheck(torch.ops._C.paged_attention_v2,
                (output, exp_sums, max_logits, tmp_output, query, key_cache,
                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                 k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    # Only the fp8 case overrides the values above; an unconditional reset
    # here would silently discard the ROCm-specific tolerances.
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
298
299


300
301
302
303
304
305
306
307
308
def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal self-attention over several packed sequences.

    Args:
        cu_seq_lens: Cumulative sequence lengths; sequence ``i`` occupies
            rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of the inputs.
        query: Packed query tokens, sliced along dim 0 per sequence.
        key: Packed key tokens, sliced along dim 0 per sequence.
        value: Packed value tokens, sliced along dim 0 per sequence.
        scale: Factor applied to the query/key dot products.
        dtype: dtype in which the causal mask is built.

    Returns:
        The per-sequence attention outputs concatenated along dim 0.
    """
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: List[torch.Tensor] = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask: the strict upper triangle (future
        # positions) gets the most negative finite value of `dtype`, the
        # rest stays zero. The mask is created on the query's device so it
        # does not depend on the global default device.
        attn_mask = torch.triu(torch.ones(seq_len,
                                          seq_len,
                                          dtype=dtype,
                                          device=query.device),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)

    return torch.cat(ref_outputs, dim=0)
331
332


333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
@pytest.mark.parametrize("version", ["rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64, 128])  # only test 64 128
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(not is_hip(), reason="only for rocm")
def test_paged_attention_rocm(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Check the ROCm paged attention kernel against the PyTorch reference.

    Random queries, context lengths, block tables and KV caches are
    generated from ``seed``; ``ops.paged_attention_rocm`` is run and its
    output is compared with ``ref_single_query_cached_kv_attention``.
    ``num_heads`` is a ``(num_query_heads, num_kv_heads)`` pair.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # The last sequence is pinned to MAX_SEQ_LEN so the longest supported
    # length is always exercised.
    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    context_lens[-1] = MAX_SEQ_LEN
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # TODO(charlifu) enable fp8 kv cache
    # Using default kv_scale
    # kv_scale = 1.0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    PARTITION_SIZE_ROCM = 256
    num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
                      PARTITION_SIZE_ROCM)
    assert PARTITION_SIZE_ROCM % block_size == 0
    # `output` has the shape of `query`, so the scratch buffers are sized
    # from the values already known here; the `num_heads` tuple parameter is
    # deliberately not rebound.
    tmp_output = torch.empty(
        size=(num_seqs, num_query_heads, num_partitions, head_size),
        dtype=output.dtype,
    )
    exp_sums = torch.empty(
        size=(num_seqs, num_query_heads, num_partitions),
        dtype=torch.float32,
    )
    max_logits = torch.empty_like(exp_sums)
    if version == "rocm":
        ops.paged_attention_rocm(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype. This branch is not reachable
        # with the current parametrization (kv_cache_dtype is only "auto").
        # The destination buffer is passed first, matching the call order
        # used by test_paged_attention in this file.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    # The tolerance is chosen per dtype / ALiBi setting below.
    atol, rtol = 1e-4, 1e-5
    if dtype == torch.bfloat16:
        atol, rtol = 2e-4, 1e-5
    if use_alibi:
        if dtype == torch.half:
            atol, rtol = 5e-4, 1e-5
        if dtype == torch.bfloat16:
            atol, rtol = 1e-3, 1e-5
    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


492
# TODO(woosuk): Add tests for USE_ALIBI=True.
493
494
495
496
497
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(is_hip(), reason="skip for rocm")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check xformers causal prefill attention against the reference.

    Random packed q/k/v for ``num_seqs`` sequences are generated from
    ``seed``; ``xops.memory_efficient_attention_forward`` with a
    ``BlockDiagonalCausalMask`` is compared with
    ``ref_multi_query_kv_attention``. ``num_heads`` is a
    ``(num_query_heads, num_kv_heads)`` pair.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    # q, k and v are allocated as one tensor and split along the head axis.
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    # Cumulative sequence boundaries consumed by the reference.
    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)