test_attention.py 14.7 KB
Newer Older
1
import random
2
from typing import List, Optional, Tuple
3

4
import pytest
5
6
import torch

7
from tests.kernels.utils import opcheck
8
from vllm import _custom_ops as ops
9
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
10

11
12
from .allclose_default import get_default_atol, get_default_rtol

13
14
15
16
if not is_hip():
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

17
18
19
20
# Size of one float32 element in bytes.
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# Longest sequence exercised by the tests: the number of float32 values that
# fit in the reported shared-memory budget, minus 512 as a buffer.
# This will change depending on the compute capability.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
# Partition length used to size the scratch buffers of the v2/rocm kernels.
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
          ] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
# (num_query_heads, num_kv_heads) pairs.
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
# NOTE(review): the list below also contains 192 and 256, so the comment
# above no longer describes the values actually tested -- confirm intent.
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
# One device when exactly one GPU is visible, otherwise the first two.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
43

44
45
46
47
48
49
50
51

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute scaled dot-product attention with an optional additive mask.

    Args:
        query: Tensor indexed as (query position, head, head dim).
        key: Tensor indexed as (key position, head, head dim).
        value: Tensor indexed as (key position, head, head dim).
        scale: Multiplier applied to the raw query-key scores.
        attn_mask: Optional bias added to the scores before the softmax. It
            must broadcast against a (head, query, key) tensor.

    Returns:
        The attention output indexed as (query position, head, head dim).
    """
    # Scores and softmax are evaluated in float32 regardless of input dtype.
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    # Cast the probabilities back to the value dtype before the final
    # contraction.
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference single-query attention over a paged KV cache.

    For every sequence the keys and values are gathered token by token from
    the paged caches, then attended to by that sequence's single query via
    ``ref_masked_attention``. The result is written into ``output`` in place.

    Args:
        output: Destination tensor; row ``i`` receives the result for
            sequence ``i`` as (num_query_heads, head_size).
        query: One query per sequence, (num_seqs, num_query_heads, head_size).
        num_queries_per_kv: Number of query heads sharing each KV head; when
            greater than 1 the gathered keys/values are repeated per head.
        key_cache: 5-D paged key cache, indexed as
            ``[block, :, :, block_offset, :]`` per token.
        value_cache: 4-D paged value cache, read as
            (block, num_kv_heads, head_size, block_size).
        block_tables: Per-sequence list of cache block numbers.
        seq_lens: Per-sequence number of cached tokens.
        scale: Multiplier applied to the attention scores.
        alibi_slopes: Optional per-head ALiBi slopes; ``None`` disables the
            bias.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    # Convert once to Python lists so the inner loop does plain indexing.
    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables_lst[i]
        seq_len = int(seq_lens_lst[i])

        keys_lst: List[torch.Tensor] = []
        values_lst: List[torch.Tensor] = []
        for j in range(seq_len):
            # Token j lives at offset (j % block_size) of its logical block.
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            # Merge the two trailing key-cache axes back into head_size.
            k = k.reshape(num_kv_heads, head_size)
            keys_lst.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values_lst.append(v)
        keys = torch.stack(keys_lst, dim=0)
        values = torch.stack(values_lst, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel:
            # slope * (position - seq_len + 1), i.e. zero for the last token.
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


116
117
@pytest.mark.parametrize(
    "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Compare a paged attention kernel against the Python reference.

    Builds random queries, block tables and KV caches, runs the kernel
    selected by ``version`` ("v1", "v2" or "rocm"), and checks its output
    against ``ref_single_query_cached_kv_attention``.

    Args:
        kv_cache_factory: Fixture returning (key_caches, value_caches).
        version: Which kernel entry point to exercise.
        num_seqs: Number of sequences (one query each).
        num_heads: (num_query_heads, num_kv_heads).
        head_size: Per-head hidden size.
        use_alibi: Whether to pass random ALiBi slopes.
        block_size: Number of tokens per cache block.
        dtype: Dtype of the query/output tensors.
        kv_cache_dtype: KV cache storage format ("auto" or "fp8").
        seed: Seed passed to ``seed_everything``.
        device: Device made the torch default for the test.
    """
    # Combinations that are not exercised: fp8 with a head size that is not
    # a multiple of 16, and the rocm kernel with head sizes other than
    # 64 / 128.
    if ((kv_cache_dtype == "fp8" and head_size % 16)
            or (version == "rocm" and head_size not in (64, 128))):
        pytest.skip()

    seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # Random lengths, with the last sequence pinned to the maximum so the
    # longest supported length is always covered.
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        # The op check only runs for the first head size.
        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]))

    elif version in ("v2", "rocm"):
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        # NOTE: this rebinds num_heads from the (query, kv) tuple to the
        # integer number of query heads taken from the output shape.
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._C.paged_attention_v2,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                    cond=(head_size == HEAD_SIZES[0]))

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(torch.ops._rocm_C.paged_attention,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale),
                    cond=(head_size == HEAD_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test. The fp8 values are applied
    # as a floor so the platform-specific tolerances chosen above are kept
    # instead of being unconditionally overwritten.
    if kv_cache_dtype == "fp8":
        atol, rtol = max(atol, 1e-2), max(rtol, 1e-5)
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
326
327


328
329
330
331
332
333
334
335
336
def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal attention over several sequences packed back to back.

    Args:
        cu_seq_lens: Cumulative token offsets; sequence ``i`` occupies rows
            ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of the packed tensors.
        query: Packed queries, sliced along dim 0 per sequence.
        key: Packed keys, sliced along dim 0 per sequence.
        value: Packed values, sliced along dim 0 per sequence.
        scale: Multiplier applied to the attention scores.
        dtype: Floating-point dtype used to build the causal mask.

    Returns:
        Per-sequence attention outputs concatenated along dim 0.
    """
    ref_outputs: List[torch.Tensor] = []
    # Walk consecutive (start, end) offset pairs, one pair per sequence.
    for start_idx, end_idx in zip(cu_seq_lens, cu_seq_lens[1:]):
        seq_len = end_idx - start_idx

        # Create attention mask: the strictly upper triangle (future
        # positions) gets the most negative finite value of `dtype`, the
        # rest stays zero.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)

    return torch.cat(ref_outputs, dim=0)
359
360


361
# TODO(woosuk): Add tests for USE_ALIBI=True.
362
363
364
365
366
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(is_hip(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare xformers block-diagonal causal attention with the reference.

    Packs ``num_seqs`` random-length sequences into one token dimension, runs
    ``xops.memory_efficient_attention_forward`` with a
    ``BlockDiagonalCausalMask``, and checks the result against
    ``ref_multi_query_kv_attention``.

    Args:
        num_seqs: Number of sequences to pack.
        num_heads: (num_query_heads, num_kv_heads).
        head_size: Per-head hidden size.
        dtype: Dtype of the packed QKV tensor.
        seed: Seed passed to ``seed_everything``.
        device: Device made the torch default for the test.
    """
    seed_everything(seed)
    torch.set_default_device(device)
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    # Query, key and value heads are allocated together and split along the
    # head dimension.
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    # A leading dimension of size 1 is added for the call and removed from
    # the result afterwards.
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    # Cumulative offsets delimiting each sequence in the packed tensors.
    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)