test_attention.py 14.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import random

6
import pytest
7
8
import torch

9
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
10
from tests.kernels.utils import opcheck
11
from vllm import _custom_ops as ops
12
from vllm.attention.layer import Attention, MultiHeadAttention
13
from vllm.platforms import current_platform
14
from vllm.utils.mem_utils import get_max_shared_memory_bytes
15

zhuwenwen's avatar
zhuwenwen committed
16
17
18
if current_platform.is_rocm():
    from flash_attn import vllm_flash_attn_with_kvcache

19
20
21
22
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# Number of float32 values that fit in the maximum shared memory, minus 512
# kept as a buffer. This will change depending on the compute capability.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough GPU memory due to the large NUM_BLOCKS.
# Reduce NUM_BLOCKS when that happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256
DTYPES = [torch.bfloat16]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# This should be kept in sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 80, 128, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
# fp8 KV caches are only exercised on non-ROCm platforms.
KV_CACHE_DTYPE = ["auto", "fp8"] if not current_platform.is_rocm() else ["auto"]
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
42

43
44
45
46
47
48

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
49
    attn_mask: torch.Tensor | None = None,
50
) -> torch.Tensor:
51
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
52
    if attn_mask is not None:
53
54
55
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
56
57
58
59
60
61
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: torch.Tensor | None,
) -> None:
    """Reference decode-time attention over a paged KV cache.

    For every sequence, the cached keys/values of its first ``seq_len``
    tokens are gathered through the block table and a single query token is
    attended against them with ``ref_masked_attention``. The result is
    written in place into ``output``.

    Args:
        output: Destination tensor; row ``i`` receives a
            (num_query_heads, head_size) result for sequence ``i``.
        query: One query per sequence, indexed (seq, query head, head dim).
        num_queries_per_kv: Query heads sharing one KV head; when > 1 the
            gathered keys/values are repeated along the head axis.
        key_cache: 5-D cache indexed
            (block, kv head, head_size // x, block offset, x).
        value_cache: 4-D cache indexed
            (block, kv head, head dim, block offset).
        block_tables: Per-sequence table mapping logical block index to a
            physical block number.
        seq_lens: Number of cached tokens per sequence.
        scale: Softmax scale forwarded to ``ref_masked_attention``.
        alibi_slopes: Optional per-query-head ALiBi slopes.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]
    cache_device = key_cache.device

    # Move the bookkeeping tensors to the host once, up front.
    block_tables_cpu = block_tables.cpu().long()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        seq_len = int(seq_lens_lst[i])

        # Physical (block, offset) coordinates of every cached token. A single
        # advanced-indexing gather replaces a per-token Python loop, which
        # would issue two tensor ops per token.
        token_idx = torch.arange(seq_len, device="cpu")
        logical_block = torch.div(token_idx, block_size, rounding_mode="floor")
        block_numbers = block_tables_cpu[i, logical_block].to(cache_device)
        block_offsets = (token_idx % block_size).to(cache_device)

        # The two index tensors are separated by slices, so the gathered token
        # axis comes first: (seq_len, kv head, head_size // x, x).
        keys = key_cache[block_numbers, :, :, block_offsets, :]
        keys = keys.reshape(seq_len, num_kv_heads, head_size)
        # Likewise (seq_len, kv head, head dim).
        values = value_cache[block_numbers, :, :, block_offsets]

        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel:
            # slope * (key position - last position), i.e. 0 for the newest
            # token. Built on the slopes' device rather than whatever the
            # global default device happens to be.
            position_ids = torch.arange(seq_len, device=alibi_slopes.device).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


114
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Compare the paged attention kernels against the reference path.

    Random queries, block tables and KV caches are generated, the selected
    kernel (``"v1"``, ``"v2"`` or ``"rocm"``) is run, and its output is
    checked with ``torch.testing.assert_close`` against
    ``ref_single_query_cached_kv_attention``.

    Args:
        kv_cache_factory: Fixture called as ``(NUM_BLOCKS, block_size, 1,
            num_kv_heads, head_size, kv_cache_dtype, dtype, seed, device)``
            and returning ``(key_caches, value_caches)`` sequences.
        version: Kernel variant to exercise.
        num_seqs: Number of sequences (one query token each).
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Per-head hidden size.
        use_alibi: Whether random ALiBi slopes are passed to the kernel.
        block_size: Tokens per KV cache block.
        dtype: Query/output dtype.
        kv_cache_dtype: KV cache storage dtype string (``"auto"``/``"fp8"``).
        seed: RNG seed.
        device: Device string the test runs on.
    """
    if (kv_cache_dtype == "fp8" and head_size % 16) or (
        version == "rocm" and head_size not in (64, 128)
    ):
        pytest.skip()

    if (
        version == "rocm"
        and current_platform.is_navi()
        and (
            kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi
        )
    ):
        pytest.skip()

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # The last sequence always has the maximum length so the longest case is
    # exercised on every run.
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        opcheck(
            torch.ops._C.paged_attention_v1,
            (
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                0,
                0,
                0,
                64,
                0,
            ),
            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
        )

    elif version in ("v2", "rocm"):
        # Pick the partition size per test case. The module-level
        # PARTITION_SIZE must not be reassigned here: doing so would leak the
        # ROCm value into every later parametrized case.
        if current_platform.is_rocm() and version == "rocm":
            partition_size = PARTITION_SIZE_ROCM
        else:
            partition_size = PARTITION_SIZE

        num_partitions = (max_seq_len + partition_size - 1) // partition_size
        assert partition_size % block_size == 0
        # Per-partition scratch buffers consumed by the partitioned kernels.
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(
                torch.ops._C.paged_attention_v2,
                (
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                    0,
                    0,
                    0,
                    64,
                    0,
                ),
                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
            )

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                None,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(
                torch.ops._rocm_C.paged_attention,
                (
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    None,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                ),
                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
            )

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
        dequantized_key_cache = torch.empty(
            size=key_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(
            size=value_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    # On ROCm the dtype-dependent defaults are used; they must not be
    # overwritten unconditionally afterwards.
    if current_platform.is_rocm():
        atol = get_default_atol(output)
        rtol = get_default_rtol(output)
    else:
        atol, rtol = 1e-3, 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
406
407


408
def ref_multi_query_kv_attention(
    cu_seq_lens: list[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    alibi_bias: list[torch.Tensor] | None,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference prefill attention over several packed sequences.

    ``query``/``key``/``value`` hold all sequences concatenated along dim 0;
    sequence ``i`` occupies rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]``. Each
    sequence is attended independently with ``ref_masked_attention`` and the
    per-sequence outputs are concatenated back along dim 0.

    Args:
        cu_seq_lens: Cumulative sequence boundaries (``num_seqs + 1`` items).
        query: Packed queries.
        key: Packed keys.
        value: Packed values.
        scale: Softmax scale forwarded to ``ref_masked_attention``.
        alibi_bias: Optional per-sequence additive bias; when given (and
            non-empty) it is used as the mask instead of a causal mask.
        dtype: Dtype used to build the causal mask.

    Returns:
        The per-sequence attention outputs concatenated along dim 0.
    """
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: list[torch.Tensor] = []

    if alibi_bias:
        assert len(alibi_bias) == num_seqs

    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask. ALiBi already includes a tril causal mask.
        if alibi_bias:
            attn_mask = alibi_bias[i]
        else:
            # Strictly-upper-triangular ones scaled to the most negative
            # finite value of `dtype`, so future positions vanish in softmax.
            attn_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1
            )
            attn_mask = attn_mask * torch.finfo(dtype).min
            attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)

    return torch.cat(ref_outputs, dim=0)
446
447


448
449
450
451
452
453
454
455
456
457
458
459
460
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
    """Constructing an attention layer with 16 heads and 5 KV heads must
    raise ``AssertionError`` (16 is not a multiple of 5)."""
    head_dim = 64
    layer_kwargs = {
        "num_heads": 16,
        "head_size": head_dim,
        "scale": float(1.0 / (head_dim**0.5)),
        "num_kv_heads": 5,
    }
    with pytest.raises(AssertionError):
        attention_cls(**layer_kwargs)