# test_attention.py: tests for the paged attention kernels (v1/v2) and for
# multi-query (prefill) attention computed with xformers.
import random
2
from typing import List, Optional, Tuple
3

4
import pytest
5
import torch
6
7
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
8

9
from vllm._C import ops, cache_ops
10
from vllm.utils import get_max_shared_memory_bytes
11

12
13
14
15
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# MAX_SEQ_LEN is the number of float32 values that fit in
# get_max_shared_memory_bytes(), minus 512 as a safety buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough GPU memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
# Partition length (in tokens) used to size the paged_attention_v2 buffers.
PARTITION_SIZE = 512

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
# (num_query_heads, num_kv_heads) pairs: plain multi-head and grouped-query.
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
SEEDS = [0]
# Test only cuda:0 when exactly one GPU is visible; otherwise test cuda:0
# and cuda:1.
DEVICES = list(range(1 if torch.cuda.device_count() == 1 else 2))
31

32
33
34
35
36
37
38
39

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference scaled dot-product attention written with plain einsums.

    Args:
        query: Tensor indexed as (num_queries, num_heads, head_size).
        key: Tensor indexed as (num_keys, num_heads, head_size).
        value: Tensor indexed as (num_keys, num_heads, head_size).
        scale: Multiplier applied to the query-key dot products.
        attn_mask: Optional additive bias, broadcastable to
            (num_heads, num_queries, num_keys); e.g. a causal mask filled
            with large negative values, or an ALiBi bias.

    Returns:
        Tensor of shape (num_queries, num_heads, head_size) in ``value``'s
        dtype.
    """
    # Upcast the logits to float32 before scaling, masking and softmax; the
    # probabilities are cast back to value's dtype for the weighted sum.
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference single-query attention over a paged KV cache.

    For every sequence, the keys/values of its first ``context_len`` tokens
    are gathered from the paged caches through its block table, and the
    sequence's single query attends to them with ``ref_masked_attention``.
    Results are written into ``output`` in place.

    Args:
        output: Destination indexed as (num_seqs, num_query_heads, head_size);
            row ``i`` is overwritten with sequence ``i``'s result.
        query: Queries of shape (num_seqs, num_query_heads, head_size).
        num_queries_per_kv: Query heads sharing one KV head (> 1 for MQA/GQA).
        key_cache: 5-D cache indexed as (num_blocks, num_kv_heads,
            head_size // x, block_size, x); a token's slice must flatten to
            (num_kv_heads, head_size).
        value_cache: Cache indexed as (num_blocks, num_kv_heads, head_size,
            block_size).
        block_tables: (num_seqs, max_blocks_per_seq) physical block indices.
        context_lens: (num_seqs,) number of cached tokens per sequence.
        scale: Multiplier applied to the query-key dot products.
        alibi_slopes: Optional per-query-head ALiBi slopes.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]
    cache_device = key_cache.device

    # Read the per-sequence metadata as Python ints.
    block_table_rows = block_tables.cpu().tolist()
    seq_context_lens = context_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        context_len = int(seq_context_lens[i])

        # Map every token position to (physical block, offset in block) and
        # gather all tokens at once rather than one indexing op per token.
        positions = torch.arange(context_len, device=cache_device)
        block_table = torch.tensor(block_table_rows[i],
                                   dtype=torch.long,
                                   device=cache_device)
        block_numbers = block_table[positions // block_size]
        block_offsets = positions % block_size

        # The advanced indices on dims 0 and 3 are separated by slices, so
        # the gathered token dimension comes first in the result.
        keys = key_cache[block_numbers, :, :, block_offsets, :]
        keys = keys.reshape(context_len, num_kv_heads, head_size)
        values = value_cache[block_numbers, :, :, block_offsets]

        if num_queries_per_kv > 1:
            # Handle MQA and GQA: repeat each KV head for the query heads
            # that share it.
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel:
            # slope * (position - (context_len - 1)), i.e. zero for the last
            # token of the context.
            position_ids = torch.arange(context_len, device=query.device).int()
            alibi_bias = (position_ids - context_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


104
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: int,
) -> None:
    """Compare paged_attention_v1/v2 against a pure-PyTorch reference.

    Each of ``num_seqs`` sequences has one query token and a random context
    length (the last one is pinned to MAX_SEQ_LEN) whose KV cache blocks are
    chosen at random. The kernel output must match
    ``ref_single_query_cached_kv_attention`` within a tolerance that is
    relaxed for the fp8 KV cache.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device=gpu_id)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device=gpu_id)

    context_lens_list = [
        random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)
    ]
    # Always exercise the longest supported context.
    context_lens_list[-1] = MAX_SEQ_LEN
    max_context_len = max(context_lens_list)
    context_lens = torch.tensor(context_lens_list,
                                dtype=torch.int,
                                device=gpu_id)

    # Create the block tables: every sequence gets enough randomly chosen
    # (possibly repeated) cache blocks to cover the longest context.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables_list = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables_list,
                                dtype=torch.int,
                                device=gpu_id)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                gpu_id)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
        )
    elif version == "v2":
        # v2 takes scratch buffers sized per PARTITION_SIZE-token partition
        # of the longest context; partitions must hold whole blocks.
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8_e5m2":
        # Convert cache data back to dtype. The dequantized buffers mirror the
        # fp8 caches' shapes exactly instead of recomputing the key layout
        # from ``dtype``: the fp8 key cache's innermost split ``x`` need not
        # equal 16 // dtype.element_size(), and a mismatched shape would make
        # the reference read the converted keys with the wrong layout.
        # NOTE(review): assumes convert_fp8_e5m2 converts element-wise
        # without re-laying out the data.
        dequantized_key_cache = torch.empty(size=key_cache.shape,
                                            dtype=dtype,
                                            device=gpu_id)
        cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
        key_cache = dequantized_key_cache

        dequantized_value_cache = torch.empty(size=value_cache.shape,
                                              dtype=dtype,
                                              device=gpu_id)
        cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    atol, rtol = 1e-3, 1e-5
    if kv_cache_dtype == "fp8_e5m2":
        atol, rtol = 1e-2, 1e-5
    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
265
266


267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal self-attention over packed variable-length sequences.

    Sequence ``i`` occupies rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]`` of
    ``query``, ``key`` and ``value``. Each sequence attends causally to its
    own tokens only, via ``ref_masked_attention``.

    Args:
        cu_seq_lens: Cumulative sequence boundaries along the token dim.
        query: Packed queries indexed as (num_tokens, num_heads, head_size).
        key: Packed keys with the same head count as ``query``.
        value: Packed values with the same head count as ``query``.
        scale: Multiplier applied to the query-key dot products.
        dtype: Dtype used to build the causal mask; its most negative finite
            value marks masked positions.

    Returns:
        The per-sequence outputs concatenated along the token dimension.
    """
    ref_outputs = []
    for start_idx, end_idx in zip(cu_seq_lens[:-1], cu_seq_lens[1:]):
        seq_len = end_idx - start_idx

        # Causal mask built directly on the target device: entries above the
        # diagonal (future keys) get dtype's most negative finite value, the
        # rest are zero.
        attn_mask = torch.triu(torch.ones(seq_len,
                                          seq_len,
                                          dtype=dtype,
                                          device=query.device),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output


300
# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
    """Check xformers' block-diagonal causal attention against a reference.

    ``num_seqs`` prompts of distinct random lengths are packed along one
    token dimension and run through ``memory_efficient_attention_forward``
    with a ``BlockDiagonalCausalMask``; the result must match
    ``ref_multi_query_kv_attention`` on the same inputs.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    # The GQA head expansion below requires whole groups of query heads.
    assert num_query_heads % num_kv_heads == 0
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype,
                      device=gpu_id)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    # Token-dimension boundaries of each packed sequence.
    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)