# SPDX-License-Identifier: Apache-2.0

import random

import pytest
import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform

# (source, destination) device kinds exercised by test_swap_blocks.
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32]

# Parameters for MLA tests.
KV_LORA_RANKS = [512]
QK_ROPE_HEAD_DIMS = [64]
NUM_TOKENS_MLA = [42]
BLOCK_SIZES_MLA = [16]
NUM_BLOCKS_MLA = [8]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# One device name when exactly one CUDA device is reported, otherwise two.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# fp8 coverage is currently switched off: only "auto" is parametrized.
# Use ["auto", "fp8"] to exercise the fp8 paths of the tests again.
KV_CACHE_DTYPE = ["auto"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Compare ``ops.copy_blocks`` with a block-by-block ``copy_`` reference.

    Each sampled source block is mapped to two distinct destination blocks.
    The caches produced by the kernel must match clones of the initial caches
    that were updated with plain tensor copies.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks.
    assert 2 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: list[tuple[int, int]] = []
    for i, src in enumerate(src_blocks):
        block_mapping.append((src, dst_blocks[2 * i]))
        block_mapping.append((src, dst_blocks[2 * i + 1]))

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches before the kernel mutates them.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device=device).view(-1, 2)

    # opcheck is only enabled for the first head size.
    opcheck(torch.ops._C_cache_ops.copy_blocks,
            (key_caches, value_caches, block_mapping_tensor),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS,
            cond=(head_size == HEAD_SIZES[0]))
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Run the reference implementation.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        torch.testing.assert_close(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache`` with per-token indexed writes.

    Random keys and values are written into randomly chosen, distinct cache
    slots by the kernel. Clones of the initial caches receive the same writes
    through plain indexing and must match. For fp8 caches both sides are
    converted to float16 and compared with a loose tolerance.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Scales derived from the largest key / value entry.
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item())
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item())
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    opcheck(torch.ops._C_cache_ops.reshape_and_cache,
            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
             k_scale, v_scale),
            cond=(head_size == HEAD_SIZES[0]))
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache, k_scale.item())
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache, v_scale.item())

    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    for i, slot in enumerate(slot_mapping_lst):
        # A slot is (block index, offset inside the block).
        block_idx, block_offset = divmod(slot, block_size)
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(result_key_cache,
                                   cloned_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   cloned_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache_flash`` with per-token indexed writes.

    Same scheme as ``test_reshape_and_cache``, but the caches come from
    ``kv_cache_factory_flashinfer`` and the reference writes index them as
    ``[block, offset, :, :]``.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache = key_caches[0].contiguous()
    value_cache = value_caches[0].contiguous()
    # Drop the factory lists; only the two tensors above are used from here.
    del key_caches
    del value_caches

    # Scales derived from the largest key / value entry.
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache,
                        key_cache,
                        k_scale.item(),
                        kv_dtype=kv_cache_dtype)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache,
                        value_cache,
                        v_scale.item(),
                        kv_dtype=kv_cache_dtype)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
             k_scale, v_scale),
            cond=(head_size == HEAD_SIZES[0]))
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache,
                        key_cache,
                        k_scale.item(),
                        kv_dtype=kv_cache_dtype)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache,
                        value_cache,
                        v_scale.item(),
                        kv_dtype=kv_cache_dtype)

    # Run the reference implementation.
    for i, slot in enumerate(slot_mapping_lst):
        # A slot is (block index, offset inside the block).
        block_idx, block_offset = divmod(slot, block_size)
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(result_key_cache,
                                   cloned_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   cloned_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)

@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``ops.swap_blocks`` copies source blocks to the destination.

    ``direction`` selects, for source and destination, between the
    parametrized CUDA ``device`` and the CPU. After the call every mapped
    destination block must equal the pre-call content of its source block.
    """
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()

    current_platform.seed_everything(seed)

    src_device = device if direction[0] == "cuda" else 'cpu'
    dst_device = device if direction[1] == "cuda" else 'cpu'

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    # The mapping tensor is built on the CPU regardless of the direction.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device="cpu").view(-1, 2)

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)

    # Create the KV caches on the second device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)

    # Snapshot the source content to compare against after the swap.
    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    do_opcheck = (head_size == HEAD_SIZES[0])
    opcheck(torch.ops._C_cache_ops.swap_blocks,
            (src_key_caches[0], dst_key_caches[0], block_mapping_tensor),
            cond=do_opcheck)
    opcheck(torch.ops._C_cache_ops.swap_blocks,
            (src_value_caches[0], dst_value_caches[0], block_mapping_tensor),
            cond=do_opcheck)

    ops.swap_blocks(src_key_caches[0], dst_key_caches[0],
                    block_mapping_tensor)
    ops.swap_blocks(src_value_caches[0], dst_value_caches[0],
                    block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(src_key_caches_clone[src].cpu(),
                                   dst_key_caches[0][dst].cpu())
        torch.testing.assert_close(src_value_caches_clone[src].cpu(),
                                   dst_value_caches[0][dst].cpu())


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="FP8 is not supported on ROCm.")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a tensor through ``ops.convert_fp8`` and compare.

    Values are drawn uniformly from [-224, 224], converted into a uint8
    buffer and back, and must match the input within ``atol=0.001`` and
    ``rtol=0.1``.
    """
    current_platform.seed_everything(seed)

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    # Forward conversion into a uint8 buffer.
    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache_fp8, cache)

    # Backward conversion into the original dtype.
    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(converted_cache, cache_fp8)

    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)


def _create_mla_cache(
    num_blocks: int,
    block_size: int,
    entry_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> torch.Tensor:
    cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
457
458
459
460
461
    return torch.zeros(num_blocks,
                       block_size,
                       entry_size,
                       dtype=cache_dtype,
                       device=device)
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513


def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
    rand_dtype = torch.float16 if kv_cache_dtype == "fp8" else cache.dtype

    vals = torch.randn(*cache.shape, device=cache.device, dtype=rand_dtype)
    if kv_cache_dtype == "fp8":
        temp = torch.zeros_like(cache)
        ops.convert_fp8(temp, vals, 1.0, kv_dtype=kv_cache_dtype)
        vals = temp
    cache.copy_(vals)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_concat_and_cache_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.concat_and_cache_mla`` with a slice-assignment reference.

    For every token the reference stores ``kv_c`` in the first
    ``kv_lora_rank`` entries of its slot and ``k_pe`` in the remaining ones.
    With an fp8 cache, kernel output and reference are both converted to
    float16 before a loose comparison.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens,
                       qk_rope_head_dim,
                       dtype=dtype,
                       device=device)
    entry_size = kv_lora_rank + qk_rope_head_dim

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                 kv_cache_dtype, device)
    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)

    # Build the reference from the host-side slot list.
    for i, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
        ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]

    if kv_cache_dtype == "fp8":
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache,
                        ref_temp,
                        scale.item(),
                        kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = ref_temp

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
                             kv_cache_dtype, scale)

    if kv_cache_dtype == "fp8":
        result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
        ops.convert_fp8(result_temp,
                        kv_cache.contiguous(),
                        scale.item(),
                        kv_dtype=kv_cache_dtype)
        expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
        ops.convert_fp8(expected_temp,
                        ref_kv_cache,
                        scale.item(),
                        kv_dtype=kv_cache_dtype)
        torch.testing.assert_close(result_temp,
                                   expected_temp,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(kv_cache, ref_kv_cache)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    num_layers: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.copy_blocks_mla`` with a block-by-block ``copy_`` reference.

    One randomly filled cache is created per layer. Each sampled source block
    is mapped to two distinct destination blocks, and the kernel result must
    match clones updated with plain tensor copies.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim

    kv_caches = []
    for _ in range(num_layers):
        kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                     kv_cache_dtype, device)
        _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
        kv_caches.append(kv_cache)

    ref_caches = [kv_cache.clone() for kv_cache in kv_caches]

    # At most two source blocks, each with two distinct destinations.
    num_mappings = min(2, num_blocks // 2)
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining, 2 * num_mappings)
    block_mapping = []
    for i, src in enumerate(src_blocks):
        block_mapping.append((src, dst_blocks[2 * i]))
        block_mapping.append((src, dst_blocks[2 * i + 1]))
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device=device).view(-1, 2)

    # Reference result, computed on the clones.
    for src, dst in block_mapping:
        for ref_cache in ref_caches:
            ref_cache[dst].copy_(ref_cache[src])

    opcheck(
        torch.ops._C_cache_ops.copy_blocks_mla,
        (kv_caches, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    ops.copy_blocks_mla(kv_caches, block_mapping_tensor)

    for kv_cache, ref_cache in zip(kv_caches, ref_caches):
        torch.testing.assert_close(kv_cache, ref_cache)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.swap_blocks`` on MLA-shaped caches.

    Source and destination caches live on the same device and are filled
    with random data. After the call every mapped destination block must
    equal the pre-call content of its source block.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim

    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device)
    dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device)

    _fill_mla_cache(src_cache, kv_cache_dtype)
    _fill_mla_cache(dst_cache, kv_cache_dtype)

    # Snapshot of the source used for the final comparison.
    src_cache_clone = src_cache.clone()

    # Destination indices are sampled from the blocks not used as sources.
    num_mappings = min(2, num_blocks // 2)
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, num_mappings)
    block_mapping = list(zip(src_blocks, dst_blocks))
    # The mapping tensor is built on the CPU.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device="cpu").view(-1, 2)

    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (src_cache, dst_cache, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_cache_clone[src].cpu(),
            dst_cache[dst].cpu(),
            msg=f"Block {src} from src should have been swapped to block "
            f"{dst} in dst_cache.")


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
                         ["auto"])  # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
                          num_blocks, max_seq_len, batch_size, dtype,
                          kv_cache_dtype, device):
    """Compare ``ops.gather_cache`` with a concatenation-based reference.

    Each batch entry gets a random length in ``[0, max_seq_len]`` and a random
    permutation of all block indices as its block-table row. The reference
    concatenates the full blocks of every non-empty sequence plus the used
    prefix of its last block, in batch order.

    NOTE: this test does not seed the RNG itself.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device)
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    seq_len_tensor = torch.randint(0,
                                   max_seq_len + 1, (batch_size, ),
                                   device=device)
    # Host copies: the lengths drive Python loops and slice bounds below.
    seq_lens = seq_len_tensor.tolist()
    total_tokens = sum(seq_lens)

    # Cumulative sequence lengths with a leading zero.
    cu_seq_lens = torch.empty((batch_size + 1),
                              dtype=torch.int32,
                              device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    block_table = torch.empty((batch_size, num_blocks),
                              dtype=torch.int32,
                              device=device)
    for b in range(batch_size):
        block_table[b, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size),
                      dtype=src_cache.dtype,
                      device=device)

    expected_batches = []
    for b, seq_len in enumerate(seq_lens):
        if seq_len == 0:
            continue
        # Blocks needed for this sequence (ceiling division).
        num_seq_blocks = (seq_len + block_size - 1) // block_size
        blocks = block_table[b, :num_seq_blocks].tolist()

        # Every block but the last is taken whole; the last only up to the
        # number of tokens it actually holds.
        gathered_rows = [src_cache[block] for block in blocks[:-1]]
        tail_len = seq_len - (num_seq_blocks - 1) * block_size
        gathered_rows.append(src_cache[blocks[-1], :tail_len, :])

        expected_batches.append(torch.cat(gathered_rows, dim=0))
    expected = torch.cat(expected_batches, dim=0)

    opcheck(
        torch.ops._C_cache_ops.gather_cache,
        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
    torch.testing.assert_close(dst, expected)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """CPU variant of the ``ops.concat_and_cache_mla`` test.

    Device and KV cache dtype are fixed to ``"cpu"`` and ``"auto"``. The
    kernel output must match a reference that stores ``kv_c`` followed by
    ``k_pe`` in each token's slot.
    """
    device = "cpu"
    kv_cache_dtype = "auto"
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Pick distinct random slots for the tokens.
    slot_mapping_lst = random.sample(range(num_blocks * block_size),
                                     num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens,
                       qk_rope_head_dim,
                       dtype=dtype,
                       device=device)

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(num_blocks, block_size,
                                 kv_lora_rank + qk_rope_head_dim, dtype,
                                 kv_cache_dtype, device)
    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)

    # Reference: write every token into its (block, offset) slot.
    for token_idx, slot in enumerate(slot_mapping_lst):
        blk, off = divmod(slot, block_size)
        ref_temp[blk, off, :kv_lora_rank] = kv_c[token_idx]
        ref_temp[blk, off, kv_lora_rank:] = k_pe[token_idx]

    if kv_cache_dtype != "fp8":
        ref_kv_cache = ref_temp
    else:
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache,
                        ref_temp,
                        scale.item(),
                        kv_dtype=kv_cache_dtype)

    kernel_args = (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        kernel_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(*kernel_args)
    torch.testing.assert_close(kv_cache, ref_kv_cache)