test_attention.py 17.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import random
5
from typing import Optional
6

7
import pytest
8
9
import torch

10
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
11
from tests.kernels.utils import opcheck
12
from vllm import _custom_ops as ops
13
from vllm.platforms import current_platform
14
from vllm.utils import get_max_shared_memory_bytes
15

16
if not current_platform.is_rocm():
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

    from vllm.attention.backends.xformers import _make_alibi_bias

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# Upper bound on tested context length. It depends on the device's shared
# memory size (and so on compute capability); presumably sized so that one
# float32 value per token fits in shared memory. 512 is kept as headroom.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512

# A cache this large may not fit in GPU memory; lower NUM_BLOCKS if
# allocation fails.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256

if current_platform.is_rocm():
    # flshattF and tritonflashattF support only {torch.float16,
    # torch.bfloat16}; the fp8 KV cache is not exercised on ROCm here.
    DTYPES = [torch.half, torch.bfloat16]
    KV_CACHE_DTYPE = ["auto"]
else:
    DTYPES = [torch.half, torch.bfloat16, torch.float]
    KV_CACHE_DTYPE = ["auto", "fp8"]

NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
# (num_query_heads, num_kv_heads) pairs; arbitrary values for testing.
NUM_HEADS = [(40, 40), (64, 8)]

# This should be kept in sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]
# One device when exactly one GPU is reported, otherwise two.
CUDA_DEVICES = [
    f"cuda:{idx}"
    for idx in range(1 if torch.cuda.device_count() == 1 else 2)
]
50

51
52
53
54
55
56
57
58

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute plain softmax attention as a numerical reference.

    Args:
        query: (num_queries, num_heads, head_size).
        key: (num_keys, num_heads, head_size).
        value: (num_keys, num_heads, head_size).
        scale: Multiplier applied to the raw q.k scores.
        attn_mask: Optional additive bias, broadcast against the
            (num_heads, num_queries, num_keys) score tensor and added in
            float32.

    Returns:
        (num_queries, num_heads, head_size) tensor in ``value.dtype``.
    """
    # Scores are upcast to float32 before masking and softmax.
    scores = torch.einsum("qhd,khd->hqk", query, key).float() * scale
    if attn_mask is not None:
        scores = scores + attn_mask.float()
    probs = torch.softmax(scores, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference single-query attention over a paged KV cache.

    For each sequence ``i``, gathers the first ``seq_lens[i]`` cached keys and
    values through ``block_tables[i]``, repeats KV heads for MQA/GQA, adds
    the optional ALiBi bias, and writes the attention result into
    ``output[i]`` in place.

    Args:
        output: Destination tensor; row ``i`` is overwritten per sequence.
        query: (num_seqs, num_query_heads, head_size) single-token queries.
        num_queries_per_kv: Query heads sharing each KV head (1 for MHA).
        key_cache: (num_blocks, num_kv_heads, head_size // x, block_size, x).
        value_cache: (num_blocks, num_kv_heads, head_size, block_size).
        block_tables: (num_seqs, max_blocks_per_seq) physical block ids.
        seq_lens: (num_seqs,) number of cached tokens for each sequence.
        scale: Multiplier applied to q.k scores.
        alibi_slopes: Optional per-query-head ALiBi slopes.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        seq_len = int(seq_lens_lst[i])

        # Map every logical token position to (physical block, offset) and
        # gather all tokens in one indexing op instead of a Python loop.
        positions = torch.arange(seq_len, device=key_cache.device)
        seq_block_table = block_tables[i].to(device=key_cache.device,
                                             dtype=torch.long)
        block_numbers = seq_block_table[positions // block_size]
        block_offsets = positions % block_size

        # The advanced indices sit on non-adjacent dims (0 and 3), so the
        # gathered token dimension moves to the front of the result.
        keys = key_cache[block_numbers, :, :, block_offsets, :]
        keys = keys.reshape(seq_len, num_kv_heads, head_size)
        values = value_cache[block_numbers, :, :, block_offsets]

        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


123
@pytest.mark.parametrize(
    "version",
    ["v1", "v2", "rocm"] if current_platform.is_rocm() else ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Check the paged-attention decode kernels against a Python reference.

    Runs ``paged_attention_v1``, ``paged_attention_v2`` or (ROCm only) the
    custom ``paged_attention_rocm`` kernel over a randomly filled paged KV
    cache and compares the result with
    ``ref_single_query_cached_kv_attention``. FP8 caches are dequantized
    back to ``dtype`` before the reference runs.
    """
    if ((kv_cache_dtype == "fp8" and head_size % 16)
            or (version == "rocm" and head_size not in (64, 128))):
        pytest.skip()

    if (version == "rocm" and current_platform.is_navi()
            and (kv_cache_dtype == "fp8" or head_size != 128
                 or block_size != 16 or use_alibi)):
        pytest.skip()

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # Random context lengths; the last sequence always uses MAX_SEQ_LEN so
    # the longest supported length is exercised.
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables: every logical block of every sequence points
    # at a random physical block of the cache.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]
                      and block_size == BLOCK_SIZES[0]))

    elif version in ("v2", "rocm"):
        # The ROCm custom kernel uses a smaller partition. Keep the choice
        # local so the module-level PARTITION_SIZE is never overwritten
        # (which would leak into later v2 cases in the same session).
        if current_platform.is_rocm() and version == "rocm":
            partition_size = PARTITION_SIZE_ROCM
        else:
            partition_size = PARTITION_SIZE

        num_partitions = (max_seq_len + partition_size - 1) // partition_size
        assert partition_size % block_size == 0
        # Scratch buffers for the partitioned (split-K style) reduction.
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._C.paged_attention_v2,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                None,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._rocm_C.paged_attention,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, None, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale),
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test. (Previously the
    # platform-specific tolerances above were unconditionally overwritten.)
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
348
349


350
def ref_multi_query_kv_attention(
    cu_seq_lens: list[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    alibi_bias: Optional[list[torch.Tensor]],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal prefill attention over packed variable-length sequences.

    Sequence ``i`` occupies token rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]``
    of ``query``/``key``/``value``. Each sequence is attended independently
    with ``ref_masked_attention`` and the results are concatenated along
    the token dimension.

    Args:
        cu_seq_lens: Cumulative sequence boundaries, starting at 0.
        query, key, value: Packed per-token tensors.
        scale: Multiplier applied to q.k scores.
        alibi_bias: Optional list with one additive bias per sequence; when
            non-empty it replaces the causal mask.
        dtype: dtype used to build the causal mask.
    """
    use_alibi = bool(alibi_bias)
    boundaries = list(zip(cu_seq_lens[:-1], cu_seq_lens[1:]))
    if use_alibi:
        assert len(alibi_bias) == len(boundaries)

    per_seq_outputs: list[torch.Tensor] = []
    for seq_idx, (start, end) in enumerate(boundaries):
        if use_alibi:
            # The ALiBi bias already includes a tril causal mask.
            mask = alibi_bias[seq_idx]
        else:
            # Strictly-upper-triangular entries get the most negative finite
            # value of dtype, hiding future tokens from each query.
            n_tokens = end - start
            future = torch.triu(torch.ones(n_tokens, n_tokens, dtype=dtype),
                                diagonal=1)
            mask = (future * torch.finfo(dtype).min).to(dtype=dtype)

        per_seq_outputs.append(
            ref_masked_attention(
                query[start:end],
                key[start:end],
                value[start:end],
                scale,
                attn_mask=mask,
            ))

    return torch.cat(per_seq_outputs, dim=0)
387
388
389
390
391
392
393


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    use_alibi: bool = False,
) -> None:
    """Compare xformers prefill attention with ``ref_multi_query_kv_attention``.

    Packs ``num_seqs`` random-length sequences into one token dimension.
    Without ALiBi they run in one call with a block-diagonal causal mask.
    With ALiBi they run one sequence at a time with the bias built by
    ``_make_alibi_bias``.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation; the
    # xformers library is already tested with its own tests, so a smaller
    # cap is enough here.
    seq_len_cap = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, seq_len_cap), num_seqs)
    cu_seq_lens = [sum(seq_lens[:k]) for k in range(len(seq_lens) + 1)]
    total_tokens = cu_seq_lens[-1]

    num_query_heads, num_kv_heads = num_heads
    scale = float(1.0 / (head_size**0.5))
    qkv = torch.empty(total_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    group_size = num_query_heads // num_kv_heads
    if group_size > 1:
        # MQA / GQA: replicate each KV head across its query group.
        key = torch.repeat_interleave(key, group_size, dim=1)
        value = torch.repeat_interleave(value, group_size, dim=1)

    if not use_alibi:
        alibi_bias = None
        output = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=BlockDiagonalCausalMask.from_seqlens(seq_lens),
            p=0.0,
            scale=scale,
        ).squeeze(0)
    else:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
                                     seq_lens)
        output = torch.empty_like(query)
        # Dynamic sequence length not supported with custom attn_bias, so
        # each sequence is run on its own.
        spans = zip(cu_seq_lens[:-1], cu_seq_lens[1:])
        for seq_idx, (start, end) in enumerate(spans):
            seq_out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[seq_idx],
                p=0.0,
                scale=scale)
            output[start:end].copy_(seq_out.view_as(query[start:end]))
        # Materialize the xformers attention biases as dense tensors for the
        # reference implementation.
        alibi_bias = [
            bias.materialize(bias.shape, device=device).squeeze()
            for bias in attn_bias
        ]

    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        alibi_bias,
        dtype,
    )
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Run ``test_multi_query_kv_attention`` with ALiBi enabled (head_size 64)."""
    return test_multi_query_kv_attention(
        num_seqs=num_seqs,
        num_heads=num_heads,
        head_size=head_size,
        dtype=dtype,
        seed=seed,
        device=device,
        use_alibi=True,
    )