test_attention.py 12.9 KB
Newer Older
1
import random
2
from typing import List, Optional, Tuple
3

4
import pytest
5
import torch
6
7
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
8

9
from vllm import _custom_ops as ops
10
from vllm.utils import get_max_shared_memory_bytes, is_hip
11

12
13
from .allclose_default import get_default_atol, get_default_rtol

14
15
16
17
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# Number of fp32 values that fit in the device's shared memory (this varies
# with compute capability), minus 512 held back as a safety buffer. Used as the
# longest sequence length exercised by the tests below.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough GPU memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
          ] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
              ] if not is_hip() else [64, 80, 96, 112, 128]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
# Run on cuda:0 alone when exactly one GPU is visible; otherwise on both
# cuda:0 and cuda:1.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
41

42
43
44
45
46
47
48
49

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Naive reference implementation of scaled dot-product attention.

    The einsum subscripts fix the layouts: ``query`` is indexed
    ``[q, h, d]``, ``key`` and ``value`` are indexed ``[k, h, d]``, and the
    result is indexed ``[q, h, d]``. Scores are formed and normalised in
    float32, then cast to ``value.dtype`` before weighting ``value``.

    Args:
        query: Query tensor, ``[q, h, d]``.
        key: Key tensor, ``[k, h, d]``.
        value: Value tensor, ``[k, h, d]``.
        scale: Multiplier applied to the raw query/key dot products.
        attn_mask: Optional additive bias, converted to float32 and
            broadcast against the ``[h, q, k]`` score tensor.

    Returns:
        The attention output, ``[q, h, d]``, in ``value.dtype``.
    """
    logits = torch.einsum("qhd,khd->hqk", query, key).float() * scale
    if attn_mask is not None:
        logits = logits + attn_mask.float()
    probs = logits.softmax(dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference paged attention for a single query token per sequence.

    For each sequence, the keys and values of its first ``seq_len`` tokens
    are gathered from the paged caches through its block table. They are
    expanded for MQA/GQA when needed and optionally biased with ALiBi. The
    sequence's query is then attended over them with ``ref_masked_attention``.
    Results are written into ``output`` in place.

    Cache layouts, as the indexing below assumes:
        key_cache:   [num_blocks, num_kv_heads, head_size // x, block_size, x]
        value_cache: [num_blocks, num_kv_heads, head_size, block_size]

    Args:
        output: Destination, ``[num_seqs, num_query_heads, head_size]``;
            row ``i`` is overwritten with sequence ``i``'s result.
        query: ``[num_seqs, num_query_heads, head_size]``.
        num_queries_per_kv: Query heads per KV head (1 means plain MHA).
        key_cache: Paged key cache (layout above).
        value_cache: Paged value cache (layout above).
        block_tables: ``[num_seqs, max_blocks_per_seq]`` physical block ids.
        seq_lens: Number of cached tokens for each sequence.
        scale: Softmax scale forwarded to ``ref_masked_attention``.
        alibi_slopes: Optional per-query-head ALiBi slopes, or ``None``.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]
    cache_device = key_cache.device

    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        seq_len = int(seq_lens_lst[i])

        # Map every token position to its (physical block, in-block slot)
        # and gather all tokens in one indexing op rather than one per token.
        positions = torch.arange(seq_len, device=cache_device)
        block_table = torch.tensor(block_tables_lst[i],
                                   dtype=torch.long,
                                   device=cache_device)
        block_numbers = block_table[positions // block_size]
        block_offsets = positions % block_size

        # The two index tensors are separated by slices, so the gathered
        # token dimension moves to the front of the result:
        #   keys:   [seq_len, num_kv_heads, head_size // x, x] -> folded below
        #   values: [seq_len, num_kv_heads, head_size]
        keys = key_cache[block_numbers, :, :, block_offsets, :].reshape(
            seq_len, num_kv_heads, head_size)
        values = value_cache[block_numbers, :, :, block_offsets]

        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel:
            # slope * (j - seq_len + 1), i.e. 0 for the newest token.
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            # [num_query_heads, 1, seq_len]
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


114
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Compare the paged-attention decode kernels against a Python reference.

    Each of ``num_seqs`` sequences has one query token and a random-length
    KV history, read through a random block table. The last sequence is
    pinned to ``MAX_SEQ_LEN``. The v1 or v2 kernel output is checked against
    ``ref_single_query_cached_kv_attention``. For fp8 KV caches, the caches
    are first converted back to ``dtype`` so the reference sees the same
    (quantized) values the kernel read.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip("fp8 KV cache is only tested with head_size % 16 == 0")

    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # NOTE: this sets the process-wide default device and is not restored.
    torch.set_default_device(device)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    # Random lengths; the last one is pinned so the longest supported
    # sequence is always exercised.
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
    elif version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        # Per-partition scratch buffers required by the v2 kernel. `output`
        # has the same shape as `query`: [num_seqs, num_query_heads,
        # head_size].
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype. `x` is the number of `dtype`
        # elements in 16 bytes, i.e. the key cache's innermost dimension.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test. Only ever loosen here so
    # the platform-specific tolerances above are not discarded.
    if kv_cache_dtype == "fp8":
        atol, rtol = max(atol, 1e-2), max(rtol, 1e-5)
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
280
281


282
283
284
285
286
287
288
289
290
def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal attention over sequences packed along dim 0.

    Sequence ``i`` occupies rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of
    ``query``, ``key`` and ``value``. Each sequence is attended independently
    with a causal mask via ``ref_masked_attention``. The per-sequence
    outputs are concatenated back along dim 0.

    Args:
        cu_seq_lens: Cumulative sequence boundaries (``len == num_seqs + 1``).
        query: Packed queries.
        key: Packed keys.
        value: Packed values.
        scale: Softmax scale forwarded to ``ref_masked_attention``.
        dtype: Dtype of the causal mask; its most negative finite value is
            used as the bias for masked-out (future) positions.

    Returns:
        The concatenated attention outputs for all sequences.
    """
    per_seq_outputs: List[torch.Tensor] = []
    for start, end in zip(cu_seq_lens[:-1], cu_seq_lens[1:]):
        length = end - start
        # Strictly-upper-triangular entries (future tokens) get the most
        # negative finite value of `dtype`; all other entries are zero.
        causal_mask = torch.ones(length, length, dtype=dtype).triu(diagonal=1)
        causal_mask = (causal_mask * torch.finfo(dtype).min).to(dtype=dtype)
        per_seq_outputs.append(
            ref_masked_attention(
                query[start:end],
                key[start:end],
                value[start:end],
                scale,
                attn_mask=causal_mask,
            ))
    return torch.cat(per_seq_outputs, dim=0)
313
314


315
# TODO(woosuk): Add tests for USE_ALIBI=True.
316
317
318
319
320
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare xformers' packed causal attention against a Python reference.

    ``num_seqs`` sequences of distinct random lengths are packed into one
    batch. They are attended with ``BlockDiagonalCausalMask`` via
    ``xops.memory_efficient_attention_forward`` and checked against
    ``ref_multi_query_kv_attention``. KV heads are repeated first when
    ``num_query_heads > num_kv_heads`` (MQA/GQA).
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # NOTE: this sets the process-wide default device and is not restored.
    torch.set_default_device(device)

    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    # Same precondition as test_paged_attention: without it the integer
    # division below would silently truncate the GQA group size.
    assert num_query_heads % num_kv_heads == 0
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )

    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)