test_attention.py 14.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import random

6
import pytest
7
8
import torch

9
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
10
from tests.kernels.utils import opcheck
11
from vllm import _custom_ops as ops
12
from vllm.attention.layer import Attention
13
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
14
from vllm.platforms import current_platform
15
from vllm.utils.mem_utils import get_max_shared_memory_bytes
16
from vllm.utils.torch_utils import set_random_seed
17

zhuwenwen's avatar
zhuwenwen committed
18
19
20
if current_platform.is_rocm():
    # ROCm-only dependency.
    # NOTE(review): not referenced anywhere else in this file — confirm it is
    # still needed.
    from flash_attn import vllm_flash_attn_with_kvcache

# Size of one float32 element, in bytes.
FLOAT32_BYTES = torch.finfo(torch.float32).bits // 8
# Longest sequence exercised by the tests: how many float32 values fit in the
# device's max shared-memory size (varies with compute capability), minus 512
# kept as a safety buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512

# Number of KV-cache blocks allocated per test (arbitrary). Lower it if the
# GPU runs out of memory.
NUM_BLOCKS = 4321
# Partition lengths used to size the v2 / rocm scratch buffers.
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256

# Parametrization grids (arbitrary values for testing).
DTYPES = [torch.bfloat16]
NUM_GEN_SEQS = [7]
NUM_PREFILL_SEQS = [3]
# (num_query_heads, num_kv_heads): one query head per KV head, and 8 query
# heads per KV head.
NUM_HEADS = [(40, 40), (64, 8)]
# Keep in sync with get_supported_head_sizes() in
# vllm.v1.attention.ops.paged_attn.PagedAttention.
HEAD_SIZES = [32, 80, 128, 256]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
# The fp8 KV cache is only exercised off ROCm.
KV_CACHE_DTYPE = ["auto"] if current_platform.is_rocm() else ["auto", "fp8"]
SEEDS = [0]
# One device when exactly one is visible, otherwise the first two.
CUDA_DEVICES = [
    f"cuda:{idx}" for idx in range(1 if torch.cuda.device_count() == 1 else 2)
]
44

45
46
47
48
49
50

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
51
    attn_mask: torch.Tensor | None = None,
52
) -> torch.Tensor:
53
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
54
    if attn_mask is not None:
55
56
57
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
58
59
60
61
62
63
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: torch.Tensor | None,
) -> None:
    """Reference single-query (decode) attention over a paged KV cache.

    Sequence ``i`` uses query row ``query[i]``. The keys/values of its first
    ``seq_lens[i]`` tokens are gathered from the caches through
    ``block_tables[i]``: token ``t`` lives in block ``table[t // block_size]``
    at offset ``t % block_size``. Layouts, as implied by the indexing below:
    ``key_cache[block, kv_head, :, offset, :]`` reshapes to
    ``(num_kv_heads, head_size)`` and ``value_cache`` has shape
    ``(num_blocks, num_kv_heads, head_size, block_size)``.

    The result for each sequence is copied into ``output[i]``; nothing is
    returned.
    """
    num_query_heads = query.shape[1]
    cache_shape = value_cache.shape
    num_kv_heads = cache_shape[1]
    head_size = cache_shape[2]
    block_size = cache_shape[3]

    tables = block_tables.cpu().tolist()
    lengths = seq_lens.cpu().tolist()

    for i in range(query.shape[0]):
        table = tables[i]
        length = int(lengths[i])

        # (block number, offset inside block) of every cached token.
        slots = [(int(table[t // block_size]), t % block_size) for t in range(length)]
        keys = torch.stack(
            [
                key_cache[blk, :, :, off, :].reshape(num_kv_heads, head_size)
                for blk, off in slots
            ],
            dim=0,
        )
        values = torch.stack([value_cache[blk, :, :, off] for blk, off in slots], dim=0)

        if num_queries_per_kv > 1:
            # MQA / GQA: expand each KV head to cover its group of query heads.
            keys = keys.repeat_interleave(num_queries_per_kv, dim=1)
            values = values.repeat_interleave(num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # ALiBi bias as used by the paged attention kernel: the offset
            # (position - seq_len + 1) is 0 for the newest token and
            # increasingly negative for older ones.
            distance = (torch.arange(length).int() - length + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * distance.view(1, 1, -1)

        out = ref_masked_attention(
            query[i].unsqueeze(0), keys, values, scale, alibi_bias
        )
        output[i].copy_(out.view(num_query_heads, head_size), non_blocking=True)


116
# Only "v1" and "v2" are parametrized; the "rocm" branches below are reachable
# only if "rocm" is added to this list.
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Compare the paged-attention decode kernels against a Python reference.

    Builds random queries, sequence lengths (the last one is always
    ``MAX_SEQ_LEN``), block tables and KV caches. It then runs
    ``paged_attention_v1`` / ``paged_attention_v2`` (or the ROCm kernel) and
    checks the result against ``ref_single_query_cached_kv_attention``. fp8
    caches are dequantized with ``ops.convert_fp8`` before the reference run.
    """
    # fp8 KV caches require head_size % 16 == 0; the "rocm" variant is only
    # exercised for head sizes 64 and 128.
    if (kv_cache_dtype == "fp8" and head_size % 16) or (
        version == "rocm" and head_size not in (64, 128)
    ):
        pytest.skip()

    if (
        version == "rocm"
        and current_platform.is_navi()
        and (
            kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi
        )
    ):
        pytest.skip()

    set_random_seed(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables (random block ids, drawn row by row).
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = [
        [random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)]
        for _ in range(num_seqs)
    ]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        opcheck(
            torch.ops._C.paged_attention_v1,
            (
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                0,
                0,
                0,
                64,
                0,
            ),
            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
        )

    elif version in ("v2", "rocm"):
        # Pick the partition size locally instead of rebinding the module-level
        # PARTITION_SIZE, so a "rocm" run cannot leak its smaller partition
        # size into later parametrizations.
        if current_platform.is_rocm() and version == "rocm":
            partition_size = PARTITION_SIZE_ROCM
        else:
            partition_size = PARTITION_SIZE

        num_partitions = (max_seq_len + partition_size - 1) // partition_size
        assert partition_size % block_size == 0
        # Scratch buffers for the partitioned kernels; `output` has shape
        # (num_seqs, num_query_heads, head_size) since it mirrors `query`.
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(
                torch.ops._C.paged_attention_v2,
                (
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                    0,
                    0,
                    0,
                    64,
                    0,
                ),
                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
            )

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                None,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(
                torch.ops._rocm_C.paged_attention,
                (
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    None,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                ),
                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
            )

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
        dequantized_key_cache = torch.empty(
            size=key_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(
            size=value_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test. Only ever loosen the
    # platform defaults above (previously they were unconditionally
    # overwritten, which made the ROCm tolerances dead code).
    if kv_cache_dtype == "fp8":
        atol, rtol = max(atol, 1e-2), max(rtol, 1e-5)
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
408
409


410
def ref_multi_query_kv_attention(
    cu_seq_lens: list[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    alibi_bias: list[torch.Tensor] | None,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal attention over several sequences packed along dim 0.

    Sequence ``i`` occupies rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of
    ``query``/``key``/``value``. If ``alibi_bias`` is non-empty, its ``i``-th
    entry is used directly as that sequence's mask. Otherwise a strictly
    upper-triangular mask filled with ``torch.finfo(dtype).min`` hides later
    tokens. Per-sequence outputs are concatenated along dim 0.
    """
    num_seqs = len(cu_seq_lens) - 1
    if alibi_bias:
        assert len(alibi_bias) == num_seqs

    outputs: list[torch.Tensor] = []
    bounds = zip(cu_seq_lens[:-1], cu_seq_lens[1:])
    for seq_idx, (start, end) in enumerate(bounds):
        seq_len = end - start
        if alibi_bias:
            # Callers are expected to pass ALiBi biases that already include
            # the (tril) causal mask.
            attn_mask = alibi_bias[seq_idx]
        else:
            causal = torch.ones(seq_len, seq_len, dtype=dtype).triu(diagonal=1)
            attn_mask = (causal * torch.finfo(dtype).min).to(dtype=dtype)

        outputs.append(
            ref_masked_attention(
                query[start:end],
                key[start:end],
                value[start:end],
                scale,
                attn_mask=attn_mask,
            )
        )

    return torch.cat(outputs, dim=0)
448
449


450
@pytest.mark.parametrize("attention_cls", [Attention, MMEncoderAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
    """An attention layer whose query-head count is not a multiple of its
    KV-head count must be rejected with an AssertionError."""
    head_size = 64
    layer_kwargs = {
        "num_heads": 16,
        "head_size": head_size,
        "scale": float(1.0 / (head_size**0.5)),
        # 16 query heads cannot be split evenly across 5 KV heads.
        "num_kv_heads": 5,
    }
    with pytest.raises(AssertionError):
        attention_cls(**layer_kwargs)