test_attention.py 15 KB
Newer Older
1
import random
2
from typing import List, Optional, Tuple
3

4
import pytest
5
6
import torch

7
from tests.kernels.utils import opcheck
8
from vllm import _custom_ops as ops
9
from vllm.platforms import current_platform
10
from vllm.utils import get_max_shared_memory_bytes
11

12
13
from .allclose_default import get_default_atol, get_default_rtol

14
if not current_platform.is_rocm():
15
16
17
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

18
19
20
21
# Size of one float32 element, in bytes.
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# Longest sequence length used by the tests. The value depends on the
# compute capability (via the shared-memory size); 512 is subtracted as a
# buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512

# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512

# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
if current_platform.is_rocm():
    DTYPES = [torch.half, torch.bfloat16]
else:
    DTYPES = [torch.half, torch.bfloat16, torch.float]

NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
# Each entry is a (num_query_heads, num_kv_heads) pair.
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
# NOTE(review): the list below contains 256, which exceeds the limit quoted
# above -- confirm whether this comment still applies to these tests.
HEAD_SIZES = [64, 80, 120, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]

# Two devices are listed unless exactly one CUDA device is reported.
CUDA_DEVICES = [f"cuda:{i}" for i in range(2)]
if torch.cuda.device_count() == 1:
    CUDA_DEVICES = CUDA_DEVICES[:1]
45

46
47
48
49
50
51
52
53

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference scaled dot-product attention with an optional additive mask.

    The einsum subscripts treat ``query`` as ``(q, h, d)`` and ``key`` /
    ``value`` as ``(k, h, d)``; the result is laid out as ``(q, h, d)``.

    Args:
        query: Query tensor.
        key: Key tensor.
        value: Value tensor; its dtype is the dtype of the softmax weights
            used for the final contraction.
        scale: Multiplier applied to the raw query-key scores.
        attn_mask: Optional tensor added (as float32) to the scaled scores
            before the softmax. It must broadcast against ``(h, q, k)``.

    Returns:
        The attention output tensor.
    """
    # Scores are accumulated in float32 regardless of the input dtype.
    scores = torch.einsum("qhd,khd->hqk", query, key).float()
    scores = scale * scores
    if attn_mask is not None:
        scores = scores + attn_mask.float()
    # Normalise over the key axis, then drop back to the value dtype.
    probs = torch.softmax(scores, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference decode attention over a paged KV cache, written into output.

    For every sequence the cached keys and values are gathered token by
    token through the sequence's block table, optionally repeated for
    MQA/GQA, and fed to ``ref_masked_attention`` with an optional ALiBi bias.

    Args:
        output: Destination tensor; row ``i`` receives the result for
            sequence ``i``.
        query: One query per sequence; dimension 0 is the sequence index and
            dimension 1 the query head.
        num_queries_per_kv: How many query heads share one KV head.
        key_cache: Paged key cache, indexed here as
            ``[block, :, :, offset, :]``.
        value_cache: Paged value cache, indexed here as
            ``[block, :, :, offset]``; dimensions 1..3 are read as
            (num_kv_heads, head_size, block_size).
        block_tables: Per-sequence list of cache block numbers.
        seq_lens: Per-sequence number of cached tokens.
        scale: Multiplier applied to the attention scores.
        alibi_slopes: Optional per-head ALiBi slopes; ``None`` disables the
            bias.
    """
    num_seqs, num_query_heads = query.shape[0], query.shape[1]
    num_kv_heads, head_size, block_size = value_cache.shape[1:4]

    # Move the bookkeeping tensors to host lists once, up front.
    tables_host = block_tables.cpu().tolist()
    lens_host = seq_lens.cpu().tolist()

    for seq_idx in range(num_seqs):
        table = tables_host[seq_idx]
        seq_len = int(lens_host[seq_idx])

        # (index into the block table, offset inside that block) per token.
        slots = [divmod(pos, block_size) for pos in range(seq_len)]
        keys = torch.stack(
            [
                key_cache[int(table[blk]), :, :, off, :].reshape(
                    num_kv_heads, head_size) for blk, off in slots
            ],
            dim=0,
        )
        values = torch.stack(
            [value_cache[int(table[blk]), :, :, off] for blk, off in slots],
            dim=0,
        )

        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        if alibi_slopes is None:
            alibi_bias = None
        else:
            # Create the ALiBi bias used in the paged attention kernel:
            # slope * (position - (seq_len - 1)) for every cached position.
            positions = torch.arange(seq_len).int()
            distances = (positions - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * distances.view(
                1, 1, -1)

        single_query = query[seq_idx].unsqueeze(0)
        attended = ref_masked_attention(single_query, keys, values, scale,
                                        alibi_bias)
        output[seq_idx].copy_(attended.view(num_query_heads, head_size),
                              non_blocking=True)


118
@pytest.mark.parametrize(
    "version",
    ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Check the paged attention kernels against the reference implementation.

    The test builds random queries, block tables and a KV cache, runs the
    kernel selected by ``version`` ("v1", "v2" or "rocm"), and compares its
    output with ``ref_single_query_cached_kv_attention``.

    Args:
        kv_cache_factory: Fixture that creates the key/value caches.
        version: Which kernel entry point to exercise.
        num_seqs: Number of sequences in the batch.
        num_heads: ``(num_query_heads, num_kv_heads)`` pair.
        head_size: Size of each attention head.
        use_alibi: Whether random ALiBi slopes are passed to the kernel.
        block_size: Number of token slots per KV-cache block.
        dtype: dtype of the query and output tensors.
        kv_cache_dtype: KV-cache dtype string ("auto" or "fp8").
        seed: Random seed.
        device: Device string used as the default torch device.
    """
    if ((kv_cache_dtype == "fp8" and head_size % 16)
            or (version == "rocm" and head_size not in (64, 128))):
        pytest.skip()

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # The last sequence is pinned to MAX_SEQ_LEN so the longest supported
    # length is always covered.
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        # opcheck is only run for the first head size / block size.
        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]
                      and block_size == BLOCK_SIZES[0]))

    elif version in ("v2", "rocm"):
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        # `output` has shape (num_seqs, num_query_heads, head_size), so the
        # scratch buffers are sized from those values directly instead of
        # rebinding the test parameters from `output.shape`.
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._C.paged_attention_v2,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._rocm_C.paged_attention,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale),
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    # On ROCm the per-dtype defaults are used. They are no longer overwritten
    # by an unconditional reset to (1e-3, 1e-5) before the fp8 check.
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
332
333


334
335
336
337
338
339
340
341
342
def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal attention over a batch of packed sequences.

    Args:
        cu_seq_lens: Cumulative sequence boundaries; sequence ``i`` occupies
            rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of the packed tensors.
        query: Packed queries, sliced along dimension 0.
        key: Packed keys, sliced along dimension 0.
        value: Packed values, sliced along dimension 0.
        scale: Multiplier applied to the attention scores.
        dtype: dtype in which the causal mask is built.

    Returns:
        The per-sequence attention outputs concatenated along dimension 0.
    """
    # Most negative finite value of `dtype`, used to mask future positions.
    mask_value = torch.finfo(dtype).min

    per_seq_outputs: List[torch.Tensor] = []
    for begin, end in zip(cu_seq_lens[:-1], cu_seq_lens[1:]):
        length = end - begin

        # Create attention mask: ones strictly above the diagonal mark the
        # future positions, which are then scaled to `mask_value`.
        future = torch.triu(torch.ones(length, length, dtype=dtype),
                            diagonal=1)
        causal_mask = (future * mask_value).to(dtype=dtype)

        rows = slice(begin, end)
        per_seq_outputs.append(
            ref_masked_attention(
                query[rows],
                key[rows],
                value[rows],
                scale,
                attn_mask=causal_mask,
            ))

    return torch.cat(per_seq_outputs, dim=0)
365
366


367
# TODO(woosuk): Add tests for USE_ALIBI=True.
368
369
370
371
372
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare xformers prefill attention with the reference implementation.

    Random packed q/k/v tensors are generated for ``num_seqs`` sequences of
    distinct random lengths, run through
    ``xops.memory_efficient_attention_forward`` with a block-diagonal causal
    mask, and checked against ``ref_multi_query_kv_attention``.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    length_cap = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, length_cap), num_seqs)
    total_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads

    # Query, key and value heads are packed side by side along dim 1.
    packed_shape = (total_tokens, num_query_heads + 2 * num_kv_heads,
                    head_size)
    qkv = torch.empty(packed_shape, dtype=dtype)
    qkv.uniform_(-scale, scale)
    head_split = [num_query_heads, num_kv_heads, num_kv_heads]
    query, key, value = qkv.split(head_split, dim=1)

    group_size = num_query_heads // num_kv_heads
    if group_size > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, group_size, dim=1)
        value = torch.repeat_interleave(value, group_size, dim=1)

    causal_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    batched_output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=causal_bias,
        p=0.0,
        scale=scale,
    )
    output = batched_output.squeeze(0)

    # Running totals of the sequence lengths, starting at 0.
    cu_seq_lens = [0]
    running_total = 0
    for length in seq_lens:
        running_total += length
        cu_seq_lens.append(running_total)

    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )

    if current_platform.is_rocm():
        atol = get_default_atol(output)
        rtol = get_default_rtol(output)
    else:
        atol, rtol = 1e-3, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)