".github/ISSUE_TEMPLATE/400-bug-report.yml" did not exist on "35ee2ad6b9a850a25d94cab582de19de5bca6fbd"
test_attention.py 17.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import random
4
from typing import Optional
5

6
import pytest
7
8
import torch

9
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
10
from tests.kernels.utils import opcheck
11
from vllm import _custom_ops as ops
12
from vllm.platforms import current_platform
13
from vllm.utils import get_max_shared_memory_bytes
14

15
if not current_platform.is_rocm():
16
17
18
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

19
20
    from vllm.attention.backends.xformers import _make_alibi_bias

21
22
23
24
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
25
26
27
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
28
PARTITION_SIZE = 512
29
PARTITION_SIZE_ROCM = 256
30
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
31
32
33
DTYPES = [
    torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
34
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
35
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
36
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
37

38
39
40
# This should be sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
41

42
BLOCK_SIZES = [16, 32]
43
USE_ALIBI = [False, True]
zhuwenwen's avatar
zhuwenwen committed
44
KV_CACHE_DTYPE = ["auto", "fp8"] if not current_platform.is_rocm() else ["auto"]
45
SEEDS = [0]
46
47
48
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
49

50
51
52
53
54
55
56
57

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
58
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
59
    if attn_mask is not None:
60
61
62
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
63
64
65
66
67
68
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
69
    num_queries_per_kv: int,
70
71
72
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
73
    seq_lens: torch.Tensor,
74
75
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
76
) -> None:
77
78
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
79
80
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
81
    num_seqs = query.shape[0]
82

83
84
    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
85
    for i in range(num_seqs):
86
        q = query[i].unsqueeze(0)
87
88
        block_table = block_tables_lst[i]
        seq_len = int(seq_lens_lst[i])
89

90
91
        keys_lst: list[torch.Tensor] = []
        values_lst: list[torch.Tensor] = []
92
        for j in range(seq_len):
93
94
95
96
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
97
            k = k.reshape(num_kv_heads, head_size)
98
            keys_lst.append(k)
99
100

            v = value_cache[block_number, :, :, block_offset]
101
102
103
            values_lst.append(v)
        keys = torch.stack(keys_lst, dim=0)
        values = torch.stack(values_lst, dim=0)
104
105
106
107
108
109
110
111
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
112
113
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
114
115
116
117
118
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
119
120
121
        output[i].copy_(out, non_blocking=True)


122
@pytest.mark.parametrize(
123
    "version",
zhuwenwen's avatar
zhuwenwen committed
124
    ["v1", "v2"] if current_platform.is_rocm() else ["v1", "v2", "rocm"])
125
126
127
128
129
130
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
131
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
132
@pytest.mark.parametrize("seed", SEEDS)
133
@pytest.mark.parametrize("device", CUDA_DEVICES)
134
def test_paged_attention(
135
    kv_cache_factory,
136
    version: str,
137
    num_seqs: int,
138
    num_heads: tuple[int, int],
139
    head_size: int,
140
    use_alibi: bool,
141
142
    block_size: int,
    dtype: torch.dtype,
143
    kv_cache_dtype: str,
144
    seed: int,
145
    device: str,
146
) -> None:
147
148
    if ((kv_cache_dtype == "fp8" and head_size % 16)
            or (version == "rocm" and head_size not in (64, 128))):
Joe's avatar
Joe committed
149
        pytest.skip()
150

151
152
    global PARTITION_SIZE

153
    current_platform.seed_everything(seed)
154
    torch.set_default_device(device)
155
156
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
157
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
158
159
160
161
162
163
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
164
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
165

166
167
168
169
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)
170

171
    # Create the block tables.
172
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
173
    block_tables_lst: list[list[int]] = []
174
    for _ in range(num_seqs):
175
        block_table = [
176
            random.randint(0, NUM_BLOCKS - 1)
177
178
            for _ in range(max_num_blocks_per_seq)
        ]
179
180
181
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
182

183
184
    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
185
186
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
187
                                                device)
188
    key_cache, value_cache = key_caches[0], value_caches[0]
Tao Peng's avatar
Tao Peng committed
189

190
    # Using default kv_scale
191
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
192

193
194
    # Call the paged attention kernel.
    output = torch.empty_like(query)
195
    if version == "v1":
196
        ops.paged_attention_v1(
197
198
199
200
            output,
            query,
            key_cache,
            value_cache,
201
            num_kv_heads,
202
203
            scale,
            block_tables,
204
            seq_lens,
205
            block_size,
206
            max_seq_len,
207
            alibi_slopes,
208
            kv_cache_dtype,
209
210
            k_scale,
            v_scale,
211
        )
212
213
214
215
216

        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
217
218
                cond=(head_size == HEAD_SIZES[0]
                      and block_size == BLOCK_SIZES[0]))
219

220
    elif version in ("v2", "rocm"):
221
222
223
        if current_platform.is_rocm() and version == "rocm":
            PARTITION_SIZE = PARTITION_SIZE_ROCM

224
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
225
226
227
228
229
230
231
232
233
234
235
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._C.paged_attention_v2,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
262
263
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._rocm_C.paged_attention,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale),
291
292
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))
293

294
    else:
295
        raise AssertionError(f"Unknown version: {version}")
296

297
    # Run the reference implementation.
298
    if kv_cache_dtype == "fp8":
299
300
301
302
303
304
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
305
                                            device=device)
306
        ops.convert_fp8(dequantized_key_cache, key_cache)
307
308
309
310
311
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
312
                                              device=device)
313
        ops.convert_fp8(dequantized_value_cache, value_cache)
314
315
        value_cache = dequantized_value_cache

316
317
318
319
    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
320
        num_queries_per_kv,
321
322
323
        key_cache,
        value_cache,
        block_tables,
324
        seq_lens,
325
326
        scale,
        alibi_slopes,
327
    )
328
329
330
331

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
332
333
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
334

335
336
    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
337
338
    atol, rtol = 1e-3, 1e-5
    if kv_cache_dtype == "fp8":
339
        atol, rtol = 1e-2, 1e-5
340
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
341
342


343
def ref_multi_query_kv_attention(
344
    cu_seq_lens: list[int],
345
346
347
348
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
349
    alibi_bias: Optional[list[torch.Tensor]],
350
351
352
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
353
    ref_outputs: list[torch.Tensor] = []
354
355
    if alibi_bias:
        assert len(alibi_bias) == num_seqs
356
357
358
359
360
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

361
362
363
364
365
366
367
368
        # Create attention mask. ALiBi already includes a tril causal mask.
        if alibi_bias:
            attn_mask = alibi_bias[i]
        else:
            attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                                   diagonal=1)
            attn_mask = attn_mask * torch.finfo(dtype).min
            attn_mask = attn_mask.to(dtype=dtype)
369
370
371
372
373
374
375
376
377

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
378
379

    return torch.cat(ref_outputs, dim=0)
380
381
382
383
384
385
386


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
387
@pytest.mark.parametrize("device", CUDA_DEVICES)
388
@pytest.mark.skipif(current_platform.is_rocm(),
389
                    reason="Xformers backend is not supported on ROCm.")
390
@torch.inference_mode()
391
def test_multi_query_kv_attention(
392
    num_seqs: int,
393
    num_heads: tuple[int, int],
394
395
    head_size: int,
    dtype: torch.dtype,
396
    seed: int,
397
    device: str,
398
    use_alibi: bool = False,
399
) -> None:
400
    current_platform.seed_everything(seed)
401
    torch.set_default_device(device)
402
403
404
405
406
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
407
408
    num_tokens = sum(seq_lens)

409
    scale = float(1.0 / (head_size**0.5))
410
    num_query_heads, num_kv_heads = num_heads
411
    qkv = torch.empty(num_tokens,
412
                      num_query_heads + 2 * num_kv_heads,
413
                      head_size,
414
                      dtype=dtype)
415
416
417
418
419
420
421
422
423
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
    alibi_bias = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
                                     seq_lens)
        output = torch.empty_like(query)
        start = 0
        # Dynamic sequence length not supported with custom attn_bias.
        for i, seq_len in enumerate(seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=scale)
            output[start:end].copy_(out.view_as(query[start:end]))
            start += seq_len
        # xformers.AttentionBias to Tensor for use in reference impl.
        alibi_bias = [
            b.materialize(b.shape, device=device).squeeze() for b in attn_bias
        ]
    else:
        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
        output = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=scale,
        )
        output = output.squeeze(0)
458

459
460
461
    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
462
463
464
465
466
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
467
        scale,
468
        alibi_bias,
469
470
        dtype,
    )
471
472
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
473
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    return test_multi_query_kv_attention(
        num_seqs,
        num_heads,
        head_size,
        dtype,
        seed,
        device,
        use_alibi=True,
    )