# test_attention.py 17.6 KB
# Newer Older
1
import random
2
from typing import List, Optional, Tuple
3

4
import pytest
5
6
import torch

7
from vllm import _custom_ops as ops
8
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
9

10
from .allclose_default import get_default_atol, get_default_rtol
11
from .utils import torch_version
12

13
14
15
16
if not is_hip():
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

17
18
19
20
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
# An earlier ROCm list also contained bfloat16: [torch.half, torch.bfloat16]
DTYPES = ([torch.half, torch.bfloat16, torch.float]
          if not is_hip() else [torch.half])
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
# Each entry is (num_query_heads, num_kv_heads).
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
# NOTE(review): the list below also contains 192 and 256, which exceed the
# limit stated above -- confirm this is intended on every platform.
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"] if not is_hip() else ["auto"]
SEEDS = [0]
# One device when exactly one GPU is reported, otherwise the first two.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
44

45
46
47
48
49
50
51
52

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference scaled dot-product attention with an optional additive mask.

    Args:
        query: Tensor laid out as (num_queries, num_heads, head_size).
        key: Tensor laid out as (num_keys, num_heads, head_size).
        value: Tensor laid out as (num_keys, num_heads, head_size).
        scale: Multiplier applied to the raw query-key scores.
        attn_mask: Optional bias added to the scores before the softmax; it
            must broadcast against (num_heads, num_queries, num_keys).

    Returns:
        Tensor laid out as (num_queries, num_heads, head_size), in
        ``value``'s dtype.
    """
    # Scores and softmax are computed in float32 regardless of input dtype.
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    # Cast back so the weighted sum runs in the dtype of `value`.
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference decode attention over a paged KV cache, written to `output`.

    For every sequence, the keys and values of all its tokens are gathered
    one token at a time from the block-structured caches, and the single
    query of that sequence attends over them.

    Args:
        output: Destination tensor; row ``i`` receives the
            (num_query_heads, head_size) result for sequence ``i``.
        query: One query per sequence, (num_seqs, num_query_heads, head_size).
        num_queries_per_kv: Number of query heads sharing each KV head;
            values above 1 trigger head replication (MQA/GQA).
        key_cache: 5-D cache indexed as
            ``[block, kv_head, :, block_offset, :]``; each per-token slice is
            reshaped to (num_kv_heads, head_size).
        value_cache: Cache laid out as
            (num_blocks, num_kv_heads, head_size, block_size).
        block_tables: Per-sequence lists of cache block numbers.
        seq_lens: Per-sequence token counts.
        scale: Multiplier applied to the attention scores.
        alibi_slopes: Optional per-head ALiBi slopes, or ``None``.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    # Move the indexing metadata to Python lists once, outside the loops.
    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables_lst[i]
        seq_len = int(seq_lens_lst[i])

        keys_lst: List[torch.Tensor] = []
        values_lst: List[torch.Tensor] = []
        for j in range(seq_len):
            # Token j lives in slot (j % block_size) of its logical block.
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys_lst.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values_lst.append(v)
        keys = torch.stack(keys_lst, dim=0)
        values = torch.stack(values_lst, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            position_ids = torch.arange(seq_len).int()
            # Distance of each key from the last position (0 for the last).
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


117
# NOTE: the "rocm" kernel variant is not parametrized at the moment; an earlier
# list was `["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"]`.
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Compare a paged-attention kernel against the reference implementation.

    Random queries, sequence lengths, block tables and KV caches are built,
    the kernel selected by ``version`` ("v1", "v2" or "rocm") is run, and its
    output is checked against ``ref_single_query_cached_kv_attention``.

    The kernel is launched on every torch version. The additional ``opcheck``
    validation only runs when ``torch_version`` starts with "2.4".

    Args:
        kv_cache_factory: Fixture returning ``(key_caches, value_caches)``.
        version: Kernel variant to exercise.
        num_seqs: Number of sequences in the batch.
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Size of each attention head.
        use_alibi: Whether random ALiBi slopes are passed to the kernel.
        block_size: Tokens per KV cache block.
        dtype: Dtype of the query and output tensors.
        kv_cache_dtype: KV cache storage format ("auto" or "fp8").
        seed: Seed passed to ``seed_everything``.
        device: Device set as the torch default for this test.
    """
    if ((kv_cache_dtype == "fp8" and head_size % 16)
            or (version == "rocm" and head_size not in (64, 128))):
        pytest.skip()

    seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # The last sequence is pinned to the maximum length so that the longest
    # supported context is always covered.
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # opcheck is only exercised on torch 2.4.x, so it is imported lazily and
    # only there. It is restricted to the first head/block size combination.
    run_opcheck = torch_version.startswith("2.4")
    if run_opcheck:
        from tests.kernels.utils import opcheck
    opcheck_cond = (head_size == HEAD_SIZES[0]
                    and block_size == BLOCK_SIZES[0])
    # Extra trailing arguments the raw `_C` ops take in addition to the
    # wrapper arguments -- presumably the tp_rank / blocksparse defaults;
    # TODO confirm against the op schema.
    raw_op_extra_args = (0, 0, 0, 64, 0)

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        v1_args = (
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
        ops.paged_attention_v1(*v1_args)
        if run_opcheck:
            opcheck(torch.ops._C.paged_attention_v1,
                    v1_args + raw_op_extra_args,
                    cond=opcheck_cond)

    elif version in ("v2", "rocm"):
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        partitioned_args = (
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
        if version == "v2":
            ops.paged_attention_v2(*partitioned_args)
            if run_opcheck:
                opcheck(torch.ops._C.paged_attention_v2,
                        partitioned_args + raw_op_extra_args,
                        cond=opcheck_cond)
        else:
            ops.paged_attention_rocm(*partitioned_args)
            if run_opcheck:
                opcheck(torch.ops._rocm_C.paged_attention,
                        partitioned_args,
                        cond=opcheck_cond)

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    # The platform-specific defaults are kept as computed here; they are only
    # replaced for the fp8 cache below.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
400
401


402
403
404
405
406
407
408
409
410
def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal attention over several sequences packed along dim 0.

    Args:
        cu_seq_lens: Cumulative sequence lengths; sequence ``i`` occupies rows
            ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of the packed tensors.
        query: Packed query tensor, sliced along dim 0 per sequence.
        key: Packed key tensor, sliced the same way.
        value: Packed value tensor, sliced the same way.
        scale: Multiplier applied to the attention scores.
        dtype: Dtype used to build the causal mask.

    Returns:
        The per-sequence attention outputs concatenated along dim 0.
    """
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: List[torch.Tensor] = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        # Entries strictly above the diagonal (future positions) get the most
        # negative finite value of `dtype`; all other entries stay 0.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)

    return torch.cat(ref_outputs, dim=0)
433
434


435
# TODO(woosuk): Add tests for USE_ALIBI=True.
436
437
438
439
440
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(is_hip(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare xformers prefill attention against the reference implementation.

    Several sequences of distinct random lengths are packed into one tensor,
    run through ``xops.memory_efficient_attention_forward`` with a
    block-diagonal causal mask, and checked against
    ``ref_multi_query_kv_attention``.

    Args:
        num_seqs: Number of sequences packed into the batch.
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Size of each attention head.
        dtype: Dtype of the packed QKV tensor.
        seed: Seed passed to ``seed_everything``.
        device: Device set as the torch default for this test.
    """
    seed_everything(seed)
    torch.set_default_device(device)
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    # Q, K and V are allocated as one tensor and split along the head axis.
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
        (is_hip()) else None,
    )
    output = output.squeeze(0)

    # Cumulative sequence boundaries, starting at 0.
    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)