test_attention.py 17.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import random
4
from typing import Optional
5

6
import pytest
7
8
import torch

9
from tests.kernels.utils import opcheck
10
from vllm import _custom_ops as ops
11
from vllm.platforms import current_platform
12
from vllm.utils import get_max_shared_memory_bytes
13

14
15
from .allclose_default import get_default_atol, get_default_rtol

16
if not current_platform.is_rocm():
17
18
19
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

20
21
    from vllm.attention.backends.xformers import _make_alibi_bias

22
23
24
25
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# Longest context length exercised by the paged-attention tests. It is
# derived from get_max_shared_memory_bytes(), so it changes with the compute
# capability; 512 is subtracted as a safety buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough GPU memory due to the large NUM_BLOCKS.
# Reduce NUM_BLOCKS when that happens.
NUM_BLOCKS = 4321  # Arbitrary value for testing
# Partition sizes used to size the v2 / ROCm scratch buffers.
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
if current_platform.is_rocm():
    DTYPES = [torch.half, torch.bfloat16]
else:
    DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
# (num_query_heads, num_kv_heads) pairs: one MHA and one GQA configuration.
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# This should be kept in sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
# A single device when exactly one GPU is visible; otherwise the first two.
CUDA_DEVICES = [
    f"cuda:{idx}"
    for idx in range(1 if torch.cuda.device_count() == 1 else 2)
]
50

51
52
53
54
55
56
57
58

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Plain softmax attention used as a numerical reference.

    Args:
        query: ``(num_queries, num_heads, head_size)``.
        key: ``(num_keys, num_heads, head_size)``.
        value: ``(num_keys, num_heads, head_size)``.
        scale: Multiplier applied to the raw query-key scores.
        attn_mask: Optional additive mask/bias, broadcastable to
            ``(num_heads, num_queries, num_keys)``; it is added in float32.

    Returns:
        ``(num_queries, num_heads, head_size)`` tensor in ``value.dtype``.
    """
    # Scores are computed and normalized in float32 for a stable reference.
    scores = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        scores = scores + attn_mask.float()
    probs = scores.softmax(dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference decode attention over a paged KV cache, written in place.

    For each sequence ``i``, the keys/values of its first ``seq_lens[i]``
    tokens are gathered token by token from the paged caches via
    ``block_tables[i]``; the single query row for that sequence then attends
    over them with ``ref_masked_attention``.

    Args:
        output: Destination; ``output[i]`` receives a
            ``(num_query_heads, head_size)`` result for sequence ``i``.
        query: One query per sequence; dims 0 and 1 are read as
            ``num_seqs`` and ``num_query_heads``.
        num_queries_per_kv: Query heads per KV head; when > 1 the gathered
            KV heads are repeated to match (MQA/GQA).
        key_cache: Indexed as ``[block, :, :, offset, :]`` and reshaped to
            ``(num_kv_heads, head_size)`` per token.
        value_cache: Indexed as ``[block, :, :, offset]``; dims 1, 2 and 3
            are read as num_kv_heads, head_size and block_size.
        block_tables: Per-sequence physical block numbers.
        seq_lens: Number of cached tokens per sequence.
        scale: Multiplier applied to the query-key scores.
        alibi_slopes: Optional per-query-head ALiBi slopes.
    """
    num_seqs, num_query_heads = query.shape[0], query.shape[1]
    num_kv_heads, head_size, block_size = value_cache.shape[1:4]

    tables = block_tables.cpu().tolist()
    lengths = seq_lens.cpu().tolist()

    for seq_idx in range(num_seqs):
        table = tables[seq_idx]
        ctx_len = int(lengths[seq_idx])

        gathered_k: list[torch.Tensor] = []
        gathered_v: list[torch.Tensor] = []
        for pos in range(ctx_len):
            # Logical token position -> (logical block, offset in block).
            logical_block, offset = divmod(pos, block_size)
            phys = int(table[logical_block])
            gathered_k.append(key_cache[phys, :, :, offset, :].reshape(
                num_kv_heads, head_size))
            gathered_v.append(value_cache[phys, :, :, offset])
        keys = torch.stack(gathered_k, dim=0)
        values = torch.stack(gathered_v, dim=0)

        if num_queries_per_kv > 1:
            # Expand KV heads so every query head has a partner (MQA/GQA).
            keys = keys.repeat_interleave(num_queries_per_kv, dim=1)
            values = values.repeat_interleave(num_queries_per_kv, dim=1)

        bias = None
        if alibi_slopes is not None:
            # ALiBi bias slope * (pos - ctx_len + 1): zero at the last token.
            rel = (torch.arange(ctx_len).int() - ctx_len + 1).float()
            bias = alibi_slopes.view(-1, 1, 1) * rel.view(1, 1, -1)

        attended = ref_masked_attention(query[seq_idx].unsqueeze(0), keys,
                                        values, scale, bias)
        output[seq_idx].copy_(attended.view(num_query_heads, head_size),
                              non_blocking=True)


123
@pytest.mark.parametrize(
    "version",
    ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Compare a paged-attention decode kernel against a Python reference.

    Builds random queries, block tables and KV caches, runs the selected
    kernel (``v1``, ``v2`` or, on ROCm only, ``rocm``), and compares the
    result with ``ref_single_query_cached_kv_attention``. ``opcheck`` is
    additionally run on the underlying torch op for the first
    head-size/block-size combination.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip("fp8 KV cache is only tested with head_size % 16 == 0")
    if version == "rocm" and head_size not in (64, 128):
        pytest.skip("rocm paged attention is only tested with head_size "
                    "64 or 128")

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    # Always include one sequence of the maximum length.
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables (random physical blocks, one row per sequence).
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]
                      and block_size == BLOCK_SIZES[0]))

    elif version in ("v2", "rocm"):
        # Keep the partition size local: rebinding the module-level
        # PARTITION_SIZE would leak the ROCm value into every test that runs
        # afterwards and make results depend on test order.
        partition_size = PARTITION_SIZE
        if current_platform.is_rocm() and version == "rocm":
            partition_size = PARTITION_SIZE_ROCM

        num_partitions = (max_seq_len + partition_size - 1) // partition_size
        assert partition_size % block_size == 0
        # Scratch buffers for the partitioned kernels; output has the same
        # (num_seqs, num_query_heads, head_size) shape as query.
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._C.paged_attention_v2,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._rocm_C.paged_attention,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale),
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype. The key cache splits head_size
        # into chunks of x elements, where x elements of dtype fill 16 bytes.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test. Only ever widen the
    # platform tolerance above; never overwrite it with a tighter one.
    if kv_cache_dtype == "fp8":
        atol, rtol = max(atol, 1e-2), max(rtol, 1e-5)
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
342
343


344
def ref_multi_query_kv_attention(
    cu_seq_lens: list[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    alibi_bias: Optional[list[torch.Tensor]],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal prefill attention over packed variable-length inputs.

    Sequence ``i`` occupies rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of
    ``query``/``key``/``value``. Each sequence is attended independently and
    the per-sequence outputs are concatenated along dim 0.

    Args:
        cu_seq_lens: Cumulative sequence boundaries, starting at 0.
        query: Packed queries, sliced along dim 0 per sequence.
        key: Packed keys, sliced like ``query``.
        value: Packed values, sliced like ``query``.
        scale: Multiplier applied to the query-key scores.
        alibi_bias: Optional per-sequence masks used instead of the causal
            mask (ALiBi already includes a tril causal mask); when given it
            must hold exactly one entry per sequence.
        dtype: dtype of the causal mask built when ``alibi_bias`` is falsy.
    """
    num_seqs = len(cu_seq_lens) - 1
    if alibi_bias:
        assert len(alibi_bias) == num_seqs

    pieces: list[torch.Tensor] = []
    bounds = zip(cu_seq_lens[:-1], cu_seq_lens[1:])
    for i, (start, end) in enumerate(bounds):
        length = end - start
        if alibi_bias:
            mask = alibi_bias[i]
        else:
            # Strict upper triangle (future tokens) gets the most negative
            # finite value of dtype; everything else gets 0.
            mask = torch.triu(torch.ones(length, length, dtype=dtype),
                              diagonal=1)
            mask = (mask * torch.finfo(dtype).min).to(dtype=dtype)

        pieces.append(
            ref_masked_attention(
                query[start:end],
                key[start:end],
                value[start:end],
                scale,
                attn_mask=mask,
            ))

    return torch.cat(pieces, dim=0)
381
382
383
384
385
386
387


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    use_alibi: bool = False,
) -> None:
    """Compare xformers prefill attention with ``ref_multi_query_kv_attention``.

    Packs ``num_seqs`` sequences of distinct random lengths into one QKV
    tensor, runs xformers (one call per sequence when ``use_alibi``; a single
    block-diagonal causal call otherwise) and checks it against the
    reference implementation.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

    alibi_bias = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
                                     seq_lens)
        output = torch.empty_like(query)
        start = 0
        # Dynamic sequence length not supported with custom attn_bias.
        for i, seq_len in enumerate(seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=scale)
            output[start:end].copy_(out.view_as(query[start:end]))
            start += seq_len
        # xformers.AttentionBias to Tensor for use in reference impl.
        # squeeze(0) drops only the leading (presumably size-1 batch) dim.
        # A bare squeeze() would also drop the size-1 query/key dims when a
        # sampled seq_len is 1, so the mask would broadcast into the wrong
        # shape in ref_masked_attention.
        alibi_bias = [
            b.materialize(b.shape, device=device).squeeze(0)
            for b in attn_bias
        ]
    else:
        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
        output = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=scale,
        )
        output = output.squeeze(0)

    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        alibi_bias,
        dtype,
    )
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Run ``test_multi_query_kv_attention`` with ALiBi enabled.

    Only head size 64 is parametrized for this variant.
    """
    test_multi_query_kv_attention(
        num_seqs=num_seqs,
        num_heads=num_heads,
        head_size=head_size,
        dtype=dtype,
        seed=seed,
        device=device,
        use_alibi=True,
    )