test_attention.py 17.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import random
4
from typing import Optional
5

6
import pytest
7
8
import torch

9
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
10
from tests.kernels.utils import opcheck
11
from vllm import _custom_ops as ops
12
from vllm.platforms import current_platform
13
from vllm.utils import get_max_shared_memory_bytes
14

15
if not current_platform.is_rocm():
16
17
18
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

19
20
    from vllm.attention.backends.xformers import _make_alibi_bias

21
22
23
24
# Size of one float32 element in bytes (finfo reports the width in bits).
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# Longest sequence exercised by the decode tests. It is derived from the
# device's shared memory size, so it changes with the compute capability;
# 512 elements are held back as a buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512

# There may not be enough GPU memory for a large NUM_BLOCKS; reduce it if
# that happens.
NUM_BLOCKS = 4321  # Arbitrary value for testing
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256

# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
if current_platform.is_rocm():
    DTYPES = [torch.half, torch.bfloat16]
else:
    DTYPES = [torch.half, torch.bfloat16, torch.float]

NUM_GEN_SEQS = [7]  # Arbitrary value for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary value for testing
# Each entry is unpacked as (num_query_heads, num_kv_heads) by the tests.
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# This should be kept in sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]

# Test on one device when exactly one is visible, otherwise on two.
CUDA_DEVICES = [
    f"cuda:{idx}"
    for idx in range(1 if torch.cuda.device_count() == 1 else 2)
]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Unfused reference scaled dot-product attention.

    ``query`` is indexed ``(q, h, d)`` and ``key``/``value`` ``(k, h, d)``
    by the einsum subscripts below. Scores are computed in float32, the
    optional additive ``attn_mask`` (broadcast against ``(h, q, k)``) is
    applied, and the softmax probabilities are cast back to ``value.dtype``
    before the weighted sum over keys.

    Returns:
        Attention output laid out as ``(q, h, d)``.
    """
    logits = torch.einsum("qhd,khd->hqk", query, key).float() * scale
    if attn_mask is not None:
        logits = logits + attn_mask.float()
    probs = logits.softmax(dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference decode attention over a paged KV cache, written into ``output``.

    For every sequence ``i`` (one per row of ``query``), the first
    ``seq_lens[i]`` cached tokens are gathered slot by slot. Token ``t``
    lives in physical block ``block_tables[i][t // block_size]`` at offset
    ``t % block_size``. The gathered keys/values are expanded for MQA/GQA
    when ``num_queries_per_kv > 1`` and attended to with
    ``ref_masked_attention``. If ``alibi_slopes`` is given, a per-head bias
    ``slope * (t - seq_len + 1)`` is added to the scores.

    Layouts as used below:
    - ``value_cache`` is indexed ``[block, :, :, offset]``, with dims
      1/2/3 read as num_kv_heads/head_size/block_size.
    - A ``key_cache`` slot ``[block, :, :, offset, :]`` is reshaped to
      ``(num_kv_heads, head_size)``.
    - ``output[i]`` receives a ``(num_query_heads, head_size)`` result.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.size(1)
    head_size = value_cache.size(2)
    block_size = value_cache.size(3)

    tables = block_tables.cpu().tolist()
    lengths = seq_lens.cpu().tolist()

    for seq_idx in range(query.shape[0]):
        seq_len = int(lengths[seq_idx])
        table = tables[seq_idx]

        # (physical block, in-block offset) for every cached token.
        slots = [(int(table[tok // block_size]), tok % block_size)
                 for tok in range(seq_len)]
        keys = torch.stack([
            key_cache[blk, :, :, off, :].reshape(num_kv_heads, head_size)
            for blk, off in slots
        ])
        values = torch.stack(
            [value_cache[blk, :, :, off] for blk, off in slots])

        if num_queries_per_kv > 1:
            # MQA/GQA: give every query head a copy of its KV head.
            keys = keys.repeat_interleave(num_queries_per_kv, dim=1)
            values = values.repeat_interleave(num_queries_per_kv, dim=1)

        bias = None
        if alibi_slopes is not None:
            # ALiBi bias as used in the paged attention kernel: distance of
            # each key position from the last position, scaled per head.
            rel_pos = (torch.arange(seq_len).int() - seq_len + 1).float()
            bias = alibi_slopes.view(-1, 1, 1) * rel_pos.view(1, 1, -1)

        attn_out = ref_masked_attention(query[seq_idx].unsqueeze(0), keys,
                                        values, scale, bias)
        output[seq_idx].copy_(attn_out.view(num_query_heads, head_size),
                              non_blocking=True)


122
@pytest.mark.parametrize(
    "version",
    ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Compare a paged-attention decode kernel against the reference.

    Runs ``paged_attention_v1``, ``paged_attention_v2`` or (only
    parametrized on ROCm) the ROCm ``paged_attention`` kernel on random
    queries and a random paged KV cache. Each kernel is also checked with
    ``opcheck`` for one head/block size combination. The output is then
    compared with ``ref_single_query_cached_kv_attention``. For
    ``kv_cache_dtype="fp8"``, the caches are converted back to ``dtype``
    before the reference runs, and a looser ``atol`` is used.
    """
    if ((kv_cache_dtype == "fp8" and head_size % 16)
            or (version == "rocm" and head_size not in (64, 128))):
        pytest.skip()

    if (version == "rocm" and current_platform.is_navi()
            and (kv_cache_dtype == "fp8" or head_size != 128
                 or block_size != 16 or use_alibi)):
        pytest.skip()

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # Random lengths, with the last sequence pinned to MAX_SEQ_LEN so the
    # longest supported case is always exercised.
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables (random physical block ids per sequence).
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]
                      and block_size == BLOCK_SIZES[0]))

    elif version in ("v2", "rocm"):
        # Pick the partition size locally. Rebinding the module-level
        # PARTITION_SIZE here would leak into every later parametrization.
        if current_platform.is_rocm() and version == "rocm":
            partition_size = PARTITION_SIZE_ROCM
        else:
            partition_size = PARTITION_SIZE

        num_partitions = ((max_seq_len + partition_size - 1) //
                          partition_size)
        assert partition_size % block_size == 0
        # ``output`` has the shape of ``query``:
        # (num_seqs, num_query_heads, head_size).
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._C.paged_attention_v2,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                None,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._rocm_C.paged_attention,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, None, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale),
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    # Relax on top of the platform defaults. An unconditional reset to
    # (1e-3, 1e-5) here would silently discard the ROCm tolerances above.
    if kv_cache_dtype == "fp8":
        atol = max(atol, 1e-2)
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


349
def ref_multi_query_kv_attention(
    cu_seq_lens: list[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    alibi_bias: Optional[list[torch.Tensor]],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal prefill attention over packed sequences.

    Sequence ``i`` occupies rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of
    ``query``, ``key`` and ``value``. Each sequence is attended to
    independently with ``ref_masked_attention``, and the per-sequence
    outputs are concatenated back along dim 0.

    Masking:
    - When ``alibi_bias`` is non-empty, it must hold one additive mask per
      sequence; each is used as-is, since it already includes a causal
      mask.
    - Otherwise a causal mask is built in ``dtype``: zero on and below the
      diagonal, ``torch.finfo(dtype).min`` above it.
    """
    num_seqs = len(cu_seq_lens) - 1
    if alibi_bias:
        assert len(alibi_bias) == num_seqs

    per_seq_outputs: list[torch.Tensor] = []
    for seq_idx in range(num_seqs):
        lo, hi = cu_seq_lens[seq_idx], cu_seq_lens[seq_idx + 1]

        if alibi_bias:
            # ALiBi already includes a tril causal mask.
            mask = alibi_bias[seq_idx]
        else:
            n_tok = hi - lo
            upper = torch.ones(n_tok, n_tok, dtype=dtype).triu(diagonal=1)
            mask = (upper * torch.finfo(dtype).min).to(dtype=dtype)

        per_seq_outputs.append(
            ref_masked_attention(query[lo:hi],
                                 key[lo:hi],
                                 value[lo:hi],
                                 scale,
                                 attn_mask=mask))

    return torch.cat(per_seq_outputs, dim=0)


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    use_alibi: bool = False,
) -> None:
    """Compare xformers prefill attention with ref_multi_query_kv_attention.

    With ``use_alibi``, each sequence is run separately under its own ALiBi
    bias from ``_make_alibi_bias``. Otherwise all sequences are packed and
    run in one call under a ``BlockDiagonalCausalMask``.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    seq_lens = random.sample(range(1, min(MAX_SEQ_LEN, 4096)), num_seqs)
    total_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads

    # Packed QKV buffer: query heads first, then key heads, then value heads.
    qkv = torch.empty(total_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    group_size = num_query_heads // num_kv_heads
    if group_size > 1:
        # MQA / GQA: expand the KV heads to match the query heads.
        key = key.repeat_interleave(group_size, dim=1)
        value = value.repeat_interleave(group_size, dim=1)

    ref_bias = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
                                     seq_lens)
        output = torch.empty_like(query)
        # Dynamic sequence length not supported with custom attn_bias.
        offset = 0
        for idx, seq_len in enumerate(seq_lens):
            window = slice(offset, offset + seq_len)
            seq_out = xops.memory_efficient_attention_forward(
                query[None, window],
                key[None, window],
                value[None, window],
                attn_bias=attn_bias[idx],
                p=0.0,
                scale=scale)
            output[window].copy_(seq_out.view_as(query[window]))
            offset += seq_len
        # xformers.AttentionBias to Tensor for use in reference impl.
        ref_bias = [
            bias.materialize(bias.shape, device=device).squeeze()
            for bias in attn_bias
        ]
    else:
        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
        output = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=scale,
        ).squeeze(0)

    # Cumulative sequence boundaries: [0, len0, len0 + len1, ...].
    cu_seq_lens = [0]
    running_total = 0
    for seq_len in seq_lens:
        running_total += seq_len
        cu_seq_lens.append(running_total)

    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        ref_bias,
        dtype,
    )

    if current_platform.is_rocm():
        atol, rtol = get_default_atol(output), get_default_rtol(output)
    else:
        atol, rtol = 1e-3, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Run test_multi_query_kv_attention with ALiBi enabled.

    Only head size 64 is parametrized here.
    """
    test_multi_query_kv_attention(
        num_seqs=num_seqs,
        num_heads=num_heads,
        head_size=head_size,
        dtype=dtype,
        seed=seed,
        device=device,
        use_alibi=True,
    )