test_attention.py 18.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import random
5
from typing import Optional
6

7
import pytest
8
9
import torch

10
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
11
from tests.kernels.utils import opcheck
12
from vllm import _custom_ops as ops
13
from vllm.attention.layer import Attention, MultiHeadAttention
14
from vllm.platforms import current_platform
15
from vllm.utils import get_max_shared_memory_bytes
16

17
if not current_platform.is_rocm():
    # xformers is not imported on ROCm.
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

    from vllm.attention.backends.xformers import _make_alibi_bias

# Number of bytes in one float32 element.
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# Derived from the device's shared-memory capacity, so it varies with compute
# capability; 512 elements are held back as a safety margin.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512

# A large NUM_BLOCKS can exhaust GPU memory; lower it if that happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256

# flshattF and tritonflashattF support only {torch.float16, torch.bfloat16}.
if current_platform.is_rocm():
    DTYPES = [torch.half, torch.bfloat16]
else:
    DTYPES = [torch.half, torch.bfloat16, torch.float]

NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# Keep in sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention.
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
# One device when exactly one is visible, otherwise the first two.
CUDA_DEVICES = [
    f"cuda:{idx}"
    for idx in range(1 if torch.cuda.device_count() == 1 else 2)
]
51

52
53
54
55
56
57
58
59

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Naive scaled dot-product attention used as a test reference.

    The einsum subscripts index ``query`` as (q_tokens, heads, dim) and
    ``key``/``value`` as (kv_tokens, heads, dim). Scores are built and
    softmaxed in float32, then cast to ``value.dtype`` for the weighted sum.

    Args:
        query: Query tensor.
        key: Key tensor with the same head count as ``query``.
        value: Value tensor aligned with ``key`` along the token axis.
        scale: Multiplier applied to the raw dot-product scores.
        attn_mask: Optional additive mask, broadcast against the
            (heads, q_tokens, kv_tokens) score tensor.

    Returns:
        Attention output laid out as (q_tokens, heads, dim).
    """
    scores = torch.einsum("qhd,khd->hqk", query, key).float() * scale
    if attn_mask is not None:
        scores = scores + attn_mask.float()
    probs = scores.softmax(dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference paged attention for one decode query per sequence.

    For each sequence, the cached keys and values of its first
    ``seq_lens[i]`` tokens are gathered through its block table. The K/V
    heads are repeated ``num_queries_per_kv`` times when that is above 1.
    An optional ALiBi bias is built, and ``ref_masked_attention`` computes
    the result, which is copied into ``output[i]`` in place.

    The indexing below reads ``query`` as (num_seqs, num_query_heads,
    head_size) and ``value_cache`` as (num_blocks, kv_heads, head_size,
    block_size). ``key_cache`` is read as 5-D with the in-block slot in
    dim 3; each slot's (dim 1, dim 2, dim 4) slice is reshaped to
    (kv_heads, head_size).
    """
    num_seqs = query.shape[0]
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]

    # Pull the bookkeeping tensors to host lists once.
    tables = block_tables.cpu().tolist()
    lengths = seq_lens.cpu().tolist()

    for seq_idx in range(num_seqs):
        q = query[seq_idx].unsqueeze(0)
        table = tables[seq_idx]
        ctx_len = int(lengths[seq_idx])

        # (physical block id, slot within block) for every cached token.
        slots = [(int(table[pos // block_size]), pos % block_size)
                 for pos in range(ctx_len)]
        keys = torch.stack(
            [
                key_cache[blk, :, :, off, :].reshape(num_kv_heads, head_size)
                for blk, off in slots
            ],
            dim=0,
        )
        values = torch.stack(
            [value_cache[blk, :, :, off] for blk, off in slots],
            dim=0,
        )

        if num_queries_per_kv > 1:
            # MQA / GQA: give every query head a copy of its shared K/V head.
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        bias = None
        if alibi_slopes is not None:
            # ALiBi bias as used by the paged attention kernel:
            # slope * (position - last position).
            rel_pos = (torch.arange(ctx_len).int() - ctx_len + 1).float()
            bias = alibi_slopes.view(-1, 1, 1) * rel_pos.view(1, 1, -1)

        attn = ref_masked_attention(q, keys, values, scale, bias)
        output[seq_idx].copy_(attn.view(num_query_heads, head_size),
                              non_blocking=True)


124
@pytest.mark.parametrize(
    "version",
    ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Check the paged-attention kernels against a naive reference.

    Random queries, block tables and KV caches are fed to the selected
    kernel: ``v1``, ``v2``, or (on ROCm only) ``rocm``. The result is
    compared with ``ref_single_query_cached_kv_attention``. With an fp8 KV
    cache, the caches are dequantized via ``ops.convert_fp8`` before the
    reference runs.
    """
    if ((kv_cache_dtype == "fp8" and head_size % 16)
            or (version == "rocm" and head_size not in (64, 128))):
        pytest.skip()

    if (version == "rocm" and current_platform.is_navi()
            and (kv_cache_dtype == "fp8" or head_size != 128
                 or block_size != 16 or use_alibi)):
        pytest.skip()

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # Random lengths; the last sequence is pinned to MAX_SEQ_LEN so the
    # longest context is always exercised.
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables (random block ids, drawn with replacement).
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]
                      and block_size == BLOCK_SIZES[0]))

    elif version in ("v2", "rocm"):
        # Local partition size: mutating the module-level PARTITION_SIZE here
        # would leak the ROCm value into every later parametrization.
        if current_platform.is_rocm() and version == "rocm":
            partition_size = PARTITION_SIZE_ROCM
        else:
            partition_size = PARTITION_SIZE
        assert partition_size % block_size == 0
        num_partitions = (max_seq_len + partition_size - 1) // partition_size

        # Scratch buffers for the partitioned reduction; `output` has the
        # same shape as `query`: (num_seqs, num_query_heads, head_size).
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)

        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._C.paged_attention_v2,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                None,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._rocm_C.paged_attention,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, None, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale),
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype. The dequantized tensors reuse the
        # quantized caches' exact shapes. Recomputing the key cache's
        # innermost split from `dtype`'s element size would make the
        # reference read elements in a different layout than they were
        # written in.
        dequantized_key_cache = torch.empty(size=key_cache.shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        dequantized_value_cache = torch.empty(size=value_cache.shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test. Only loosen atol here so
    # the platform-specific tolerances above are kept rather than discarded.
    if kv_cache_dtype == "fp8":
        atol = max(atol, 1e-2)
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
349
350


351
def ref_multi_query_kv_attention(
    cu_seq_lens: list[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    alibi_bias: Optional[list[torch.Tensor]],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal prefill attention over packed variable-length sequences.

    Sequence ``i`` occupies rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of
    ``query``/``key``/``value``. Each sequence is handled separately by
    ``ref_masked_attention``, and the per-sequence results are concatenated
    along dim 0.

    Args:
        cu_seq_lens: Cumulative sequence boundaries, starting at 0.
        query: Packed query tokens.
        key: Packed key tokens.
        value: Packed value tokens.
        scale: Score multiplier forwarded to ``ref_masked_attention``.
        alibi_bias: Optional per-sequence additive biases. When non-empty, it
            must hold one entry per sequence and replaces the causal mask.
        dtype: Dtype of the generated causal mask.

    Returns:
        The concatenated attention outputs for all sequences.
    """
    num_seqs = len(cu_seq_lens) - 1
    if alibi_bias:
        assert len(alibi_bias) == num_seqs

    per_seq_outputs: list[torch.Tensor] = []
    bounds = zip(cu_seq_lens[:-1], cu_seq_lens[1:])
    for idx, (begin, finish) in enumerate(bounds):
        if alibi_bias:
            # The ALiBi bias already carries a lower-triangular causal mask.
            mask = alibi_bias[idx]
        else:
            length = finish - begin
            # Strictly-upper-triangular positions get the most negative value.
            upper = torch.triu(torch.ones(length, length, dtype=dtype),
                               diagonal=1)
            mask = (upper * torch.finfo(dtype).min).to(dtype=dtype)

        per_seq_outputs.append(
            ref_masked_attention(
                query[begin:finish],
                key[begin:finish],
                value[begin:finish],
                scale,
                attn_mask=mask,
            ))

    return torch.cat(per_seq_outputs, dim=0)
388
389
390
391
392
393
394


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    use_alibi: bool = False,
) -> None:
    """Compare xformers prefill attention with a naive per-sequence reference.

    Random packed Q/K/V tensors for ``num_seqs`` sequences are run through
    ``xops.memory_efficient_attention_forward``. This uses a block-diagonal
    causal mask, or per-sequence ALiBi biases when ``use_alibi`` is set. The
    result is checked against ``ref_multi_query_kv_attention``.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # xformers is already covered by its own tests, so a length capped at
    # 4096 is enough here.
    seq_lens = random.sample(range(1, min(MAX_SEQ_LEN, 4096)), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    q_heads, kv_heads = num_heads
    qkv = torch.empty(num_tokens,
                      q_heads + 2 * kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split([q_heads, kv_heads, kv_heads], dim=1)

    group_size = q_heads // kv_heads
    if group_size > 1:
        # MQA / GQA: expand K/V so every query head has a matching head.
        key = torch.repeat_interleave(key, group_size, dim=1)
        value = torch.repeat_interleave(value, group_size, dim=1)

    alibi_bias = None
    if not use_alibi:
        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
        output = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=scale,
        ).squeeze(0)
    else:
        alibi_slopes = torch.randn(q_heads, dtype=torch.float)
        attn_bias = _make_alibi_bias(alibi_slopes, kv_heads, dtype, seq_lens)
        output = torch.empty_like(query)
        # A custom attn_bias does not support dynamic sequence lengths, so
        # each sequence is sent through xformers on its own.
        offset = 0
        for seq_idx, seq_len in enumerate(seq_lens):
            stop = offset + seq_len
            seq_out = xops.memory_efficient_attention_forward(
                query[None, offset:stop],
                key[None, offset:stop],
                value[None, offset:stop],
                attn_bias=attn_bias[seq_idx],
                p=0.0,
                scale=scale)
            output[offset:stop].copy_(seq_out.view_as(query[offset:stop]))
            offset = stop
        # Materialize the xformers AttentionBias objects as plain tensors
        # for the reference implementation.
        alibi_bias = [
            bias.materialize(bias.shape, device=device).squeeze()
            for bias in attn_bias
        ]

    boundaries = [0]
    for seq_len in seq_lens:
        boundaries.append(boundaries[-1] + seq_len)

    ref_output = ref_multi_query_kv_attention(
        boundaries,
        query,
        key,
        value,
        scale,
        alibi_bias,
        dtype,
    )
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Run the multi-query xformers check with ALiBi biases enabled."""
    test_multi_query_kv_attention(
        num_seqs=num_seqs,
        num_heads=num_heads,
        head_size=head_size,
        dtype=dtype,
        seed=seed,
        device=device,
        use_alibi=True,
    )
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524


@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
    """An attention layer with 16 query heads and 5 KV heads must be rejected.

    16 is not a multiple of 5, so construction is expected to raise
    ``AssertionError``.
    """
    head_size = 64
    layer_kwargs = dict(
        num_heads=16,
        head_size=head_size,
        scale=float(1.0 / (head_size**0.5)),
        num_kv_heads=5,
    )
    with pytest.raises(AssertionError):
        attention_cls(**layer_kwargs)