"vllm/vscode:/vscode.git/clone" did not exist on "4ae3fc0419f311554ee7d35272fb5931b5a42121"
test_attention.py 15.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import random
4
from typing import Optional
5

6
import pytest
7
8
import torch

9
from tests.kernels.utils import opcheck
10
from vllm import _custom_ops as ops
11
from vllm.platforms import current_platform
12
from vllm.utils import get_max_shared_memory_bytes
13

14
15
from .allclose_default import get_default_atol, get_default_rtol

16
if not current_platform.is_rocm():
17
18
19
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

# Width of a single float32 element, in bytes.
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# How many float32 values fit in the device's shared memory, minus 512 as a
# safety buffer. The result varies with the GPU's compute capability.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512

# Arbitrary block count for testing. Lower it if the GPU runs out of memory
# while the KV caches are allocated.
NUM_BLOCKS = 4321
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256

# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
if current_platform.is_rocm():
    DTYPES = [torch.half, torch.bfloat16]
else:
    DTYPES = [torch.half, torch.bfloat16, torch.float]

# Arbitrary values for testing.
NUM_GEN_SEQS = [7]
NUM_PREFILL_SEQS = [3]
# (num_query_heads, num_kv_heads): one equal-heads config, one GQA config.
NUM_HEADS = [(40, 40), (64, 8)]

# Keep this list in sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention.
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
# Only cuda:0 when exactly one GPU is visible; otherwise cuda:0 and cuda:1.
CUDA_DEVICES = (["cuda:0"]
                if torch.cuda.device_count() == 1 else ["cuda:0", "cuda:1"])

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
57
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
58
    if attn_mask is not None:
59
60
61
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
62
63
64
65
66
67
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference single-query (decode) attention over a paged KV cache.

    For each sequence ``i`` the first ``seq_lens[i]`` cached tokens are
    gathered through ``block_tables[i]`` and the attention of ``query[i]``
    over them is written into ``output[i]``.

    Cache layouts, as implied by the indexing below (``x`` is the key
    cache's vector width):
        key_cache:   ``(num_blocks, num_kv_heads, head_size // x, block_size, x)``
        value_cache: ``(num_blocks, num_kv_heads, head_size, block_size)``

    Args:
        output: ``(num_seqs, num_query_heads, head_size)``; filled in place.
        query: ``(num_seqs, num_query_heads, head_size)``.
        num_queries_per_kv: Query heads served by each KV head (>1 for
            MQA/GQA).
        block_tables: ``(num_seqs, max_blocks_per_seq)`` physical block ids.
        seq_lens: ``(num_seqs,)`` number of cached tokens per sequence.
        scale: Logit scale forwarded to ``ref_masked_attention``.
        alibi_slopes: Optional per-query-head ALiBi slopes.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    # Gather every token of a sequence with one advanced-indexing op instead
    # of a Python loop per token; index tensors must be long and live on the
    # cache's device.
    cache_device = key_cache.device
    block_tables_idx = block_tables.to(device=cache_device, dtype=torch.long)
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        seq_len = int(seq_lens_lst[i])

        positions = torch.arange(seq_len, device=cache_device)
        block_numbers = block_tables_idx[i, positions // block_size]
        block_offsets = positions % block_size

        # Advanced indices separated by slices are moved to the front, so
        # keys come out as (seq_len, num_kv_heads, head_size // x, x) and
        # values as (seq_len, num_kv_heads, head_size).
        keys = key_cache[block_numbers, :, :, block_offsets, :].reshape(
            seq_len, num_kv_heads, head_size)
        values = value_cache[block_numbers, :, :, block_offsets]

        if num_queries_per_kv > 1:
            # Handle MQA and GQA: give every query head its own KV copy.
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # ALiBi bias used in the paged attention kernel:
            # slope * (position - (seq_len - 1)), i.e. zero at the last token.
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


def _dequantize_fp8_cache(cache: torch.Tensor, dtype: torch.dtype,
                          device: str) -> torch.Tensor:
    """Return a ``dtype`` copy of the fp8 KV cache ``cache``.

    The copy has exactly ``cache``'s shape, so it can be indexed with the
    same layout as the cache the kernel consumed.
    """
    dequantized = torch.empty(size=cache.shape, dtype=dtype, device=device)
    ops.convert_fp8(dequantized, cache)
    return dequantized


@pytest.mark.parametrize(
    "version",
    ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Check the paged-attention decode kernels against a Python reference.

    Runs ``paged_attention_v1``, ``paged_attention_v2`` or (ROCm only)
    ``paged_attention_rocm`` on a randomly filled paged KV cache, validates
    the op registration with ``opcheck`` for one representative
    configuration, and compares the output with
    ``ref_single_query_cached_kv_attention``.
    """
    # Skipped combinations: fp8 caches with head_size not a multiple of 16,
    # and the ROCm kernel with head sizes other than 64 and 128.
    if ((kv_cache_dtype == "fp8" and head_size % 16)
            or (version == "rocm" and head_size not in (64, 128))):
        pytest.skip()

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # Random lengths, with the last sequence pinned to the maximum.
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables (random physical blocks per logical block).
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: list[list[int]] = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # opcheck is only run for one representative configuration.
    run_opcheck = (head_size == HEAD_SIZES[0]
                   and block_size == BLOCK_SIZES[0])

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=run_opcheck)

    elif version in ("v2", "rocm"):
        # Local on purpose: rebinding the module-level PARTITION_SIZE here
        # would leak the ROCm value into every later test case.
        partition_size = PARTITION_SIZE
        if current_platform.is_rocm() and version == "rocm":
            partition_size = PARTITION_SIZE_ROCM

        num_partitions = ((max_seq_len + partition_size - 1) //
                          partition_size)
        assert partition_size % block_size == 0
        # output is empty_like(query): (num_seqs, num_query_heads, head_size).
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._C.paged_attention_v2,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                    cond=run_opcheck)

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._rocm_C.paged_attention,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale),
                    cond=run_opcheck)

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype, keeping each cache's own shape.
        # Recomputing the key layout with x = 16 // dtype.element_size()
        # would assume the fp8 cache uses the model dtype's vector width; a
        # 1-byte cache presumably packs 16 elements per vector (consistent
        # with the head_size % 16 skip above).
        # NOTE(review): confirm against kv_cache_factory.
        key_cache = _dequantize_fp8_cache(key_cache, dtype, device)
        value_cache = _dequantize_fp8_cache(value_cache, dtype, device)

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    if current_platform.is_rocm():
        atol, rtol = get_default_atol(output), get_default_rtol(output)
    else:
        atol, rtol = 1e-3, 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    if kv_cache_dtype == "fp8":
        atol = max(atol, 1e-2)
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
    cu_seq_lens: list[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal self-attention over sequences packed along dim 0.

    ``cu_seq_lens`` holds cumulative token boundaries: sequence ``i`` spans
    rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of ``query``, ``key`` and
    ``value``. Each sequence is attended independently under a causal mask,
    and the per-sequence outputs are concatenated back along dim 0.

    Args:
        cu_seq_lens: Cumulative sequence boundaries, starting at 0.
        query, key, value: Packed ``(num_tokens, num_heads, head_size)``
            tensors (the layout ``ref_masked_attention`` expects).
        scale: Logit scale forwarded to ``ref_masked_attention``.
        dtype: dtype in which the causal mask is built.
    """
    ref_outputs: list[torch.Tensor] = []
    for start_idx, end_idx in zip(cu_seq_lens[:-1], cu_seq_lens[1:]):
        seq_len = end_idx - start_idx

        # Causal mask: positions above the diagonal (future tokens) get the
        # most negative finite value of dtype; the rest stay 0.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = (attn_mask * torch.finfo(dtype).min).to(dtype=dtype)

        ref_outputs.append(
            ref_masked_attention(
                query[start_idx:end_idx],
                key[start_idx:end_idx],
                value[start_idx:end_idx],
                scale,
                attn_mask=attn_mask,
            ))

    return torch.cat(ref_outputs, dim=0)


# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check xformers' block-diagonal causal attention against a reference.

    Packs ``num_seqs`` random-length prefill sequences along the token
    dimension, runs ``xops.memory_efficient_attention_forward`` with a
    ``BlockDiagonalCausalMask``, and compares the result with
    ``ref_multi_query_kv_attention``.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    # Each KV head must serve a whole number of query heads; otherwise the
    # repeat_interleave below yields mismatched head counts.
    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads

    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    # Cumulative token offsets: sequence i spans
    # rows cu_seq_lens[i]:cu_seq_lens[i + 1].
    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)

    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)