# test_prefix_prefill.py: tests for the Triton prefix-prefill attention kernel
# (vllm.attention.ops.prefix_prefill.context_attention_fwd).
import math
import random
import time

import pytest
import torch

from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import is_hip
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything

# xformers supplies the reference attention used to check the Triton kernel.
# It is only imported off ROCm; the tests below skip the reference comparison
# when is_hip() is true.
if not is_hip():
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
    from vllm.attention.backends.xformers import _make_alibi_bias

NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
HEAD_SIZES = [128, 96, 24]
DTYPES = [torch.float16]
# One device when exactly one GPU is visible, otherwise the first two.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
# fp8 KV-cache variants are only parametrized off ROCm.
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] if not is_hip() else ["auto"]


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ``context_attention_fwd`` (causal / sliding window) against xformers.

    A batch of ``BS`` requests is built, each with a random number of new
    (query) tokens and a random number of already-cached context tokens.
    The context K/V are scattered into a paged cache through a random block
    table; the new-token K/V are packed back to back in ``k``/``v``. The
    Triton kernel is run once to warm up and once timed. Off ROCm the output
    is compared with xformers' ``memory_efficient_attention_forward`` using a
    bottom-right causal mask, restricted to ``sliding_window`` tokens when it
    is positive. On ROCm only the Triton kernel is exercised.
    """
    seed_everything(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    # K/V for every token of every sequence (context tokens followed by the
    # new tokens), with the sequences stored back to back.
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    # K/V of the new tokens only, packed request after request.
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    # A random permutation of all cache blocks, split into one row of
    # max_block_per_request physical block ids per request.
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                            dtype=torch.long),
                               dim=0)
    max_input_len = MAX_SEQ_LEN

    # Split each sequence's K/V: the trailing query_len tokens go to the packed
    # k/v tensors (one slice copy per request), and the leading ctx_len tokens
    # are written into the paged cache one block_size chunk at a time.
    # Offsets are tracked as host ints (same values as the cumsums) so no
    # device sync is needed per comparison.
    k_cache_flat = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_flat = v_cache.view(-1, num_kv_heads, head_size)
    block_ids = block_table.tolist()
    query_start = 0
    seq_start = 0
    for req, (query_len, ctx_len, seq_len) in enumerate(
            zip(query_lens, ctx_lens, seq_lens)):
        new_tokens = slice(seq_start + ctx_len, seq_start + seq_len)
        k[query_start:query_start + query_len].copy_(key[new_tokens])
        v[query_start:query_start + query_len].copy_(value[new_tokens])
        for blk, ctx_start in enumerate(range(0, ctx_len, block_size)):
            ctx_end = min(ctx_start + block_size, ctx_len)
            slot = block_ids[req][blk] * block_size
            src = slice(seq_start + ctx_start, seq_start + ctx_end)
            dst = slice(slot, slot + ctx_end - ctx_start)
            k_cache_flat[dst].copy_(key[src])
            v_cache_flat[dst].copy_(value[src])
        query_start += query_len
        seq_start += seq_len

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()

    # Bound once as a tuple so both kernel calls receive identical arguments.
    triton_args = (query, k, v, output, kv_cache_dtype, k_cache, v_cache,
                   block_table, b_start_loc, b_seq_len, b_ctx_len,
                   max_input_len)
    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    context_attention_fwd(*triton_args, sliding_window=sliding_window)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    context_attention_fwd(*triton_args, sliding_window=sliding_window)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")

    if is_hip():
        # xformers is not imported on ROCm, so there is no reference to
        # compare against; only the Triton kernel is exercised above.
        return

    scale = float(1.0 / (head_size**0.5))
    attn_op = xops.fmha.cutlass.FwOp()

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
                           query.shape[-1])
        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                        num_queries_per_kv, key.shape[-1])
        value = value[:, :, None, :].expand(value.shape[0], num_kv_heads,
                                            num_queries_per_kv,
                                            value.shape[-1])
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        query_lens, seq_lens)
    if sliding_window > 0:
        attn_bias = attn_bias.make_local_attention_from_bottomright(
            sliding_window)
    ref_kwargs = dict(attn_bias=attn_bias, p=0.0, scale=scale, op=attn_op)

    # Warm-up call, then a timed call.
    xops.memory_efficient_attention_forward(query, key, value, **ref_kwargs)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    output_ref = xops.memory_efficient_attention_forward(
        query, key, value, **ref_kwargs)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
    output_ref = output_ref.reshape(output.shape)

    # fp8 caches quantize K/V, so a looser tolerance is used for them.
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ``context_attention_fwd`` with ALiBi slopes against xformers.

    A batch of ``BS`` requests is built, each with a random number of new
    (query) tokens and a random number of already-cached context tokens.
    The context K/V are scattered into a paged cache through a random block
    table; the new-token K/V are packed back to back in ``k``/``v``. The
    Triton kernel receives per-head ALiBi slopes and is run once to warm up
    and once timed. Off ROCm the reference runs xformers one request at a
    time with the bias from ``_make_alibi_bias``. For that, queries are
    left-padded with zero rows so each request's query has as many rows as
    its keys, and only the trailing query rows of each output are kept. On
    ROCm only the Triton kernel is exercised.
    """
    seed_everything(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
        # Fork from: vllm/vllm/model_executor/models/bloom.py#L44
        closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
        base = torch.tensor(
            2**(-(2**-(math.log2(closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
        slopes = torch.pow(base, powers)

        if closest_power_of_2 != total_num_heads:
            extra_base = torch.tensor(
                2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
                dtype=torch.float32,
            )
            num_remaining_heads = min(closest_power_of_2,
                                      total_num_heads - closest_power_of_2)
            extra_powers = torch.arange(start=1,
                                        end=1 + 2 * num_remaining_heads,
                                        step=2,
                                        dtype=torch.int32)
            slopes = torch.cat(
                [slopes, torch.pow(extra_base, extra_powers)], dim=0)
        return slopes

    alibi_slopes = _get_alibi_slopes(num_heads).to(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    # K/V for every token of every sequence (context tokens followed by the
    # new tokens), with the sequences stored back to back.
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    if kv_cache_dtype == "auto":
        cache_dtype = dtype
    else:
        cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=cache_dtype)
    # K/V of the new tokens only, packed request after request.
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    # A random permutation of all cache blocks, split into one row of
    # max_block_per_request physical block ids per request.
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                            dtype=torch.long),
                               dim=0)
    max_input_len = MAX_SEQ_LEN

    # Split each sequence's K/V: the trailing query_len tokens go to the packed
    # k/v tensors (one slice copy per request), and the leading ctx_len tokens
    # are written into the paged cache one block_size chunk at a time.
    # Offsets are tracked as host ints (same values as the cumsums) so no
    # device sync is needed per comparison.
    k_cache_flat = k_cache.view(-1, num_kv_heads, head_size)
    v_cache_flat = v_cache.view(-1, num_kv_heads, head_size)
    block_ids = block_table.tolist()
    query_start = 0
    seq_start = 0
    for req, (query_len, ctx_len, seq_len) in enumerate(
            zip(query_lens, ctx_lens, seq_lens)):
        new_tokens = slice(seq_start + ctx_len, seq_start + seq_len)
        k[query_start:query_start + query_len].copy_(key[new_tokens])
        v[query_start:query_start + query_len].copy_(value[new_tokens])
        for blk, ctx_start in enumerate(range(0, ctx_len, block_size)):
            ctx_end = min(ctx_start + block_size, ctx_len)
            slot = block_ids[req][blk] * block_size
            src = slice(seq_start + ctx_start, seq_start + ctx_end)
            dst = slice(slot, slot + ctx_end - ctx_start)
            k_cache_flat[dst].copy_(key[src])
            v_cache_flat[dst].copy_(value[src])
        query_start += query_len
        seq_start += seq_len

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()

    # Bound once as a tuple so both kernel calls receive identical arguments.
    triton_args = (query, k, v, output, kv_cache_dtype, k_cache, v_cache,
                   block_table, b_start_loc, b_seq_len, b_ctx_len,
                   max_input_len)
    # Warm up the Triton kernel by calling it once before actually measuring
    # generation time
    context_attention_fwd(*triton_args, alibi_slopes=alibi_slopes)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    context_attention_fwd(*triton_args, alibi_slopes=alibi_slopes)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")

    if is_hip():
        # xformers is not imported on ROCm, so there is no reference to
        # compare against; only the Triton kernel is exercised above.
        return

    scale = float(1.0 / (head_size**0.5))

    # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
    # we have to pad query tensor before MQA/GQA expanding.
    if query.shape[0] != key.shape[0]:
        # Each request gets ctx_len zero rows followed by its real query rows,
        # so query and key line up row for row.
        query_pad = torch.zeros(sum(seq_lens),
                                num_heads,
                                head_size,
                                dtype=dtype)
        seq_start = 0
        query_start = 0
        for query_len, seq_len in zip(query_lens, seq_lens):
            pad_len = seq_len - query_len
            query_pad[seq_start + pad_len:seq_start + seq_len] = (
                query[query_start:query_start + query_len])
            seq_start += seq_len
            query_start += query_len
        query = query_pad

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
                           query.shape[-1])
        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                        num_queries_per_kv, key.shape[-1])
        value = value[:, :, None, :].expand(value.shape[0], num_kv_heads,
                                            num_queries_per_kv,
                                            value.shape[-1])

    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
    output_ref = torch.empty_like(output)

    # Drain queued GPU work (padding, bias construction) so it is not counted
    # in the xformers timing below.
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    # Attention with alibi slopes.
    # FIXME(DefTruth): Because xformers does not support dynamic sequence
    # lengths with custom attention bias, we process each prompt one by
    # one. This is inefficient, especially when we have many short prompts.
    # modified from: vllm/attention/backends/xformers.py#L343
    seq_start = 0
    query_start = 0
    for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
        seq_end = seq_start + seq_len
        query_end = query_start + query_len
        out = xops.memory_efficient_attention_forward(
            query[:, seq_start:seq_end],
            key[:, seq_start:seq_end],
            value[:, seq_start:seq_end],
            attn_bias=attn_bias[i],
            p=0.0,
            scale=scale)
        out = out.view_as(query[:, seq_start:seq_end]).view(
            seq_len, num_heads, head_size)
        # Keep only the rows of the real (non-padding) queries.
        output_ref[query_start:query_end, ...].copy_(
            out[seq_len - query_len:, ...])
        seq_start += seq_len
        query_start += query_len
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")

    # fp8 caches quantize K/V, so a looser tolerance is used for them.
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)