test_cache.py 38.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
import random

6
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch

9
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
from vllm import _custom_ops as ops
11
from vllm.platforms import current_platform
Woosuk Kwon's avatar
Woosuk Kwon committed
12

13
# (source, destination) device pairs exercised by test_swap_blocks.
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
# Input dtypes; torch.float32 is the same dtype object as torch.float.
DTYPES = [torch.bfloat16, torch.float32]

# Attention cache configuration. The NUM_* entries are arbitrary values for
# testing.
NUM_TOKENS = [42]
NUM_LAYERS = [1]
NUM_HEADS = [8]
HEAD_SIZES = [64, 80, 256]
BLOCK_SIZES = [8, 16, 32]
CACHE_LAYOUTS = ["NHD", "HND"]

# Parameters for MLA tests.
KV_LORA_RANKS = [512]
QK_ROPE_HEAD_DIMS = [64]
NUM_TOKENS_MLA = [42]
BLOCK_SIZES_MLA = [16]
NUM_BLOCKS_MLA = [8]

# Arbitrary values for testing. Don't make them too large: e.g.
# [1024, 36000] will OOM.
NUM_BLOCKS = [1024, 10000]

# test_copy_blocks draws one source and two distinct destination blocks per
# mapping, so 3 * num_mappings must not exceed the smallest NUM_BLOCKS entry.
NUM_MAPPINGS = [256]
SEEDS = [0]
# A single device when exactly one GPU is visible, otherwise cuda:0 and cuda:1.
CUDA_DEVICES = (
    ["cuda:0"] if torch.cuda.device_count() == 1 else ["cuda:0", "cuda:1"]
)

# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]

# Backends exercised by test_reshape_and_cache_flash.
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ``copy_blocks`` against a per-block ``Tensor.copy_`` reference.

    Each of ``num_mappings`` randomly chosen source blocks is copied to two
    distinct destination blocks, in every layer's key and value cache.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks. Sources and destinations come from disjoint pools
    # sampled without replacement, so num_mappings sources plus
    # 2 * num_mappings destinations must fit inside num_blocks.
    assert 3 * num_mappings <= num_blocks, (
        f"need 3 * num_mappings <= num_blocks, got {num_mappings=} {num_blocks=}"
    )
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: list[tuple[int, int]] = []
    for i, src in enumerate(src_blocks):
        block_mapping.append((src, dst_blocks[2 * i]))
        block_mapping.append((src, dst_blocks[2 * i + 1]))

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        num_layers,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )

    # Snapshot every layer before the kernel runs; the reference copies are
    # applied to these snapshots.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device=device
    ).view(-1, 2)

    opcheck(
        torch.ops._C_cache_ops.copy_blocks,
        (key_caches, value_caches, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Run the reference implementation.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        torch.testing.assert_close(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``reshape_and_cache`` with a Python scatter into a cache copy.

    Tokens go to random, distinct slots. For fp8 caches, the kernel output and
    the reference are both dequantized to float16 and compared with loose
    tolerances; otherwise the caches must match exactly.
    """
    is_fp8 = kv_cache_dtype == "fp8"
    if is_fp8 and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    # Pick num_tokens distinct slots out of all slots in the cache.
    slot_mapping_lst = random.sample(range(block_size * num_blocks), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # A single-layer cache; only layer 0 is used.
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Per-tensor fp32 scales derived from the maximum element of key / value.
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    # Snapshot the caches before the kernel writes to them. fp8 caches are
    # dequantized to float16 so the reference can hold plain values.
    if is_fp8:
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item())
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item())
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    op_args = (
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    opcheck(
        torch.ops._C_cache_ops.reshape_and_cache,
        op_args,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache(*op_args)

    if is_fp8:
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache, k_scale.item())
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache, v_scale.item())

    # Reference: reshape each token's key to the per-slot layout of key_cache
    # and scatter key/value into the snapshot at (slot // bs, slot % bs).
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    for token_idx, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[token_idx]
        cloned_value_cache[block_idx, :, :, block_offset] = value[token_idx]

    if is_fp8:
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
    kv_cache_layout: str,
    implementation: str,
) -> None:
    """Check the CUDA or Triton ``reshape_and_cache_flash`` against Python.

    The cache is viewed in its layout's dim order (unchanged for NHD, dims 1
    and 2 swapped for HND), made contiguous, and compared with a reference
    scatter. fp8 caches are dequantized to float16 and compared loosely.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    assert implementation in ["cuda", "triton"]
    if implementation == "triton" and kv_cache_layout == "HND":
        pytest.skip("Triton implementation only supports NHD layout.")

    # fp8 conversion requires a contiguous memory buffer. Halve the number of
    # blocks and tokens to consume less memory.
    num_tokens //= 2
    num_blocks //= 2

    # Pick num_tokens distinct slots out of all slots in the cache.
    slot_mapping_lst = random.sample(range(block_size * num_blocks), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
        cache_layout=kv_cache_layout,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    # Only layer 0 is used; drop the list references.
    del key_caches, value_caches

    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)
    is_nhd = kv_cache_layout == "NHD"

    def permute_and_compact(x):
        """Return x contiguous, with dims 1 and 2 swapped for the HND layout."""
        if not is_nhd:
            x = x.permute(0, 2, 1, 3)
        return x.contiguous()

    key_cache_compact = permute_and_compact(key_cache)
    value_cache_compact = permute_and_compact(value_cache)

    # Snapshot the caches (dequantized to float16 for fp8) before the kernel.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
        ops.convert_fp8(
            cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype
        )
        cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
        ops.convert_fp8(
            cloned_value_cache, value_cache_compact, v_scale.item(), kv_cache_dtype
        )
    else:
        cloned_key_cache = key_cache_compact.clone()
        cloned_value_cache = value_cache_compact.clone()

    op_args = (
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    if implementation == "cuda":
        opcheck(
            torch.ops._C_cache_ops.reshape_and_cache_flash,
            op_args,
            cond=(head_size == HEAD_SIZES[0]),
        )
        ops.reshape_and_cache_flash(*op_args)
    elif implementation == "triton":
        from vllm.attention.ops.triton_reshape_and_cache_flash import (
            triton_reshape_and_cache_flash,
        )

        triton_reshape_and_cache_flash(*op_args)

    # Re-read the caches after the kernel wrote into them.
    key_cache_compact = permute_and_compact(key_cache)
    value_cache_compact = permute_and_compact(value_cache)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
        ops.convert_fp8(
            result_key_cache, key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype
        )
        result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
        ops.convert_fp8(
            result_value_cache,
            value_cache_compact,
            v_scale.item(),
            kv_dtype=kv_cache_dtype,
        )

    # Reference: scatter each token into the snapshot at (slot // bs, slot % bs),
    # indexing the block-offset dim according to the layout.
    for token_idx, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        if is_nhd:
            cloned_key_cache[block_idx, block_offset, :, :] = key[token_idx]
            cloned_value_cache[block_idx, block_offset, :, :] = value[token_idx]
        else:
            cloned_key_cache[block_idx, :, block_offset, :] = key[token_idx]
            cloned_value_cache[block_idx, :, block_offset, :] = value[token_idx]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache_compact, cloned_key_cache)
        torch.testing.assert_close(value_cache_compact, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``swap_blocks`` copies blocks between caches on two devices.

    ``direction`` names the source and destination side ("cuda" maps to
    ``device``, anything else to "cpu"). After the swap, each destination
    block must equal the pre-swap contents of its source block.
    """
    if kv_cache_dtype == "fp8" and ("cpu" in direction or head_size % 16):
        pytest.skip()

    current_platform.seed_everything(seed)

    src_device, dst_device = (
        device if side == "cuda" else "cpu" for side in direction
    )

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        dst_pool = list(set(range(num_blocks)) - set(src_blocks))
    else:
        dst_pool = range(num_blocks)
    dst_blocks = random.sample(dst_pool, num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    # Single-layer caches on each side, built from identical arguments.
    factory_args = (
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
    )
    src_key_caches, src_value_caches = kv_cache_factory(*factory_args, src_device)
    dist_key_caches, dist_value_caches = kv_cache_factory(*factory_args, dst_device)

    src_key_cache, src_value_cache = src_key_caches[0], src_value_caches[0]
    dst_key_cache, dst_value_cache = dist_key_caches[0], dist_value_caches[0]
    src_key_snapshot = src_key_cache.clone()
    src_value_snapshot = src_value_cache.clone()

    cache_pairs = (
        (src_key_cache, dst_key_cache),
        (src_value_cache, dst_value_cache),
    )
    do_opcheck = head_size == HEAD_SIZES[0]
    for src_cache, dst_cache in cache_pairs:
        opcheck(
            torch.ops._C_cache_ops.swap_blocks,
            (src_cache, dst_cache, block_mapping_tensor),
            cond=do_opcheck,
        )
    for src_cache, dst_cache in cache_pairs:
        ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_key_snapshot[src].cpu(), dst_key_cache[dst].cpu()
        )
        torch.testing.assert_close(
            src_value_snapshot[src].cpu(), dst_value_cache[dst].cpu()
        )


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a uniformly random tensor through fp8 bytes and back.

    Values are drawn from [-224, 224]. ``convert_fp8`` is called without an
    explicit scale in both directions, and the round-tripped values must match
    the input within loose tolerances.
    """
    current_platform.seed_everything(seed)

    bound = 224.0
    original = torch.empty(
        (num_blocks, num_heads, head_size, block_size), dtype=dtype, device=device
    ).uniform_(-bound, bound)

    quantized = torch.empty_like(original, dtype=torch.uint8)
    ops.convert_fp8(quantized, original)

    round_tripped = torch.empty_like(original)
    ops.convert_fp8(round_tripped, quantized)

    torch.testing.assert_close(original, round_tripped, atol=0.001, rtol=0.1)


def _create_mla_cache(
    num_blocks: int,
    block_size: int,
    entry_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> torch.Tensor:
    """Allocate a zero-filled MLA cache of shape (num_blocks, block_size, entry_size).

    When ``kv_cache_dtype`` is exactly "fp8" the cache stores raw bytes
    (``torch.uint8``); for any other value it uses ``dtype`` as given.
    """
    if kv_cache_dtype == "fp8":
        storage_dtype = torch.uint8
    else:
        storage_dtype = dtype
    return torch.zeros(
        (num_blocks, block_size, entry_size), dtype=storage_dtype, device=device
    )


def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
    """Overwrite ``cache`` in place with standard-normal random values.

    For ``kv_cache_dtype == "fp8"`` the values are drawn in float16, quantized
    with scale 1.0 into a scratch tensor shaped like ``cache``, and then copied
    in. Otherwise they are drawn directly in ``cache.dtype``.
    """
    if kv_cache_dtype != "fp8":
        cache.copy_(
            torch.randn(*cache.shape, device=cache.device, dtype=cache.dtype)
        )
        return

    half_vals = torch.randn(*cache.shape, device=cache.device, dtype=torch.float16)
    quantized = torch.zeros_like(cache)
    ops.convert_fp8(quantized, half_vals, 1.0, kv_dtype=kv_cache_dtype)
    cache.copy_(quantized)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_concat_and_cache_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``concat_and_cache_mla`` for plain and fp8 MLA caches.

    Each token's ``kv_c`` followed by its ``k_pe`` is written into one cache
    row at a random, distinct slot. For fp8, the reference is quantized with
    the same scale, and both sides are dequantized to float16 before a loose
    comparison.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    slot_mapping_lst = random.sample(range(num_blocks * block_size), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)

    entry_size = kv_lora_rank + qk_rope_head_dim
    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    is_fp8 = kv_cache_dtype == "fp8"

    kv_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )

    # Reference rows, built in the unquantized input dtype.
    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
    for token_idx, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        ref_row = ref_temp[block_idx, block_offset]
        ref_row[:kv_lora_rank] = kv_c[token_idx]
        ref_row[kv_lora_rank:] = k_pe[token_idx]

    if is_fp8:
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = ref_temp

    op_args = (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        op_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    ops.concat_and_cache_mla(*op_args)

    if not is_fp8:
        torch.testing.assert_close(kv_cache, ref_kv_cache)
        return

    result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
    ops.convert_fp8(
        result_temp, kv_cache.contiguous(), scale.item(), kv_dtype=kv_cache_dtype
    )
    expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
    ops.convert_fp8(
        expected_temp, ref_kv_cache, scale.item(), kv_dtype=kv_cache_dtype
    )
    torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_concat_and_cache_ds_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check ``concat_and_cache_mla`` with the ``fp8_ds_mla`` byte layout.

    The reference builds each slot's uint8 row by hand, in this order:
    ``kv_lora_rank`` fp8 bytes (four 128-element tiles, each quantized with
    its own scale), then the four float32 tile scales (16 bytes), then
    ``k_pe`` stored unquantized in the 16-bit input dtype.
    """
    if current_platform.is_rocm():
        pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
    if dtype.itemsize != 2:
        pytest.skip("ds_mla only supports 16-bit input")
    kv_cache_dtype = "fp8_ds_mla"
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    slot_mapping_lst = random.sample(range(num_blocks * block_size), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)

    # Bytes per slot: fp8 nope part + four float32 scales + 16-bit rope part.
    entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
    # Element offsets inside one slot's byte row: the scale block in float32
    # units, and the rope block in 16-bit units.
    scales_start = kv_lora_rank // 4
    rope_start = kv_lora_rank // 2 + 8

    scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    kv_cache = _create_mla_cache(
        num_blocks,
        block_size,
        entry_size,
        dtype=torch.uint8,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
    )

    ref_cache = torch.zeros_like(kv_cache)
    # Reusable 128-element staging buffer for one kv_c tile.
    tile_data = torch.zeros(128, dtype=dtype, device=device)

    for token_idx, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        ref_row = ref_cache[block_idx, block_offset]
        ref_row_16bit = ref_row.view(dtype)
        ref_row_32bit = ref_row.view(torch.float32)
        token_kv_c = kv_c[token_idx]

        for tile_idx in range(4):
            tile_lo, tile_hi = tile_idx * 128, (tile_idx + 1) * 128
            tile_data[:] = token_kv_c[tile_lo:tile_hi]

            # Absolute maximum of the tile in float32, taken element by
            # element. NOTE: the original authors report that torch's amax()
            # gave different results here, so it is computed manually.
            tile_abs_max = max(abs(v) for v in tile_data.to(torch.float32))
            tile_scale = tile_abs_max / 448.0

            ref_row_32bit[scales_start + tile_idx] = tile_scale
            ops.convert_fp8(
                ref_row[tile_lo:tile_hi],
                tile_data,
                tile_scale.item(),
                kv_dtype="fp8",
            )

        ref_row_16bit[rope_start : rope_start + qk_rope_head_dim] = k_pe[token_idx]

    op_args = (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        op_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    ops.concat_and_cache_mla(*op_args)

    for slot in slot_mapping_lst:
        block_idx, block_offset = divmod(slot, block_size)
        got_row = kv_cache[block_idx, block_offset]
        want_row = ref_cache[block_idx, block_offset]

        got_nope = got_row[:kv_lora_rank]
        want_nope = want_row[:kv_lora_rank]
        got_scales = got_row.view(torch.float32)[scales_start : scales_start + 4]
        want_scales = want_row.view(torch.float32)[scales_start : scales_start + 4]
        got_rope = got_row.view(dtype)[rope_start:]
        want_rope = want_row.view(dtype)[rope_start:]

        torch.testing.assert_close(got_nope, want_nope, atol=0.001, rtol=0.1)
        torch.testing.assert_close(got_scales, want_scales, atol=0.001, rtol=0.1)
        torch.testing.assert_close(got_rope, want_rope, atol=0.001, rtol=0.1)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    num_layers: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``copy_blocks_mla`` against a block-by-block reference copy.

    Every sampled source block is copied to two distinct destination blocks
    in each layer's MLA cache. After the op runs, every layer's cache must
    equal the reference cache built with plain ``Tensor.copy_``.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim

    kv_caches = []
    for _ in range(num_layers):
        kv_cache = _create_mla_cache(
            num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
        )
        _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
        kv_caches.append(kv_cache)

    ref_caches = [kv_cache.clone() for kv_cache in kv_caches]

    # Each source block fans out to two destinations, and sources and
    # destinations are kept disjoint, so 3 * num_mappings distinct blocks are
    # needed. Bounding by num_blocks // 3 (not // 2) keeps random.sample from
    # raising ValueError when num_blocks is small (e.g. 2, 4 or 5).
    num_mappings = min(2, num_blocks // 3)
    if num_mappings == 0:
        pytest.skip("copy_blocks_mla test needs at least 3 cache blocks")
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining = sorted(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining, 2 * num_mappings)
    # Source i is paired with destinations 2*i and 2*i + 1.
    block_mapping = [
        (src, dst)
        for i, src in enumerate(src_blocks)
        for dst in dst_blocks[2 * i : 2 * i + 2]
    ]
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device=device
    ).view(-1, 2)

    for src, dst in block_mapping:
        for ref_cache in ref_caches:
            ref_cache[dst].copy_(ref_cache[src])

    opcheck(
        torch.ops._C_cache_ops.copy_blocks_mla,
        (kv_caches, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    ops.copy_blocks_mla(kv_caches, block_mapping_tensor)

    for kv_cache, ref_cache in zip(kv_caches, ref_caches):
        torch.testing.assert_close(kv_cache, ref_cache)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``swap_blocks`` on a pair of MLA caches.

    Each mapped source block must appear in its destination block. Beyond
    the original per-block check, this also verifies that the source cache
    and every destination block not named in the mapping are left untouched.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim

    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    dst_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype)
    _fill_mla_cache(dst_cache, kv_cache_dtype)

    # Snapshots taken before the op, used to detect unintended writes.
    src_cache_clone = src_cache.clone()
    dst_cache_clone = dst_cache.clone()

    num_mappings = min(2, num_blocks // 2)
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, num_mappings)
    block_mapping = list(zip(src_blocks, dst_blocks))
    # NOTE(review): the mapping is passed as a CPU tensor; presumably the op
    # reads it on the host. Confirm against the kernel.
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (src_cache, dst_cache, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_cache_clone[src].cpu(),
            dst_cache[dst].cpu(),
            msg=f"Block {src} from src should have been swapped to block "
            f"{dst} in dst_cache.",
        )

    # The source cache must not be modified by the op.
    torch.testing.assert_close(src_cache, src_cache_clone)

    # Destination blocks outside the mapping must keep their old contents.
    untouched = torch.ones(num_blocks, dtype=torch.bool, device=dst_cache.device)
    for blk in dst_blocks:
        untouched[blk] = False
    torch.testing.assert_close(dst_cache[untouched], dst_cache_clone[untouched])
894
895
896
897
898
899
900
901
902


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(
    kv_lora_rank,
    qk_rope_head_dim,
    block_size,
    num_blocks,
    max_seq_len,
    batch_size,
    dtype,
    kv_cache_dtype,
    device,
):
    """Check ``gather_and_maybe_dequant_cache`` against a Python reference.

    The tokens of each sequence are gathered block by block, following that
    sequence's row of ``block_table``, into the dense ``dst`` buffer. For an
    ``"fp8"`` cache the reference dequantizes with ``ops.convert_fp8``, using
    the same ``scale`` the op receives.
    """
    # Cache contents, sequence lengths and block tables are all random; seed
    # (as the sibling tests do) so failures are reproducible.
    current_platform.seed_everything(SEEDS[0])
    torch.cuda.set_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim
    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    # randint(max_seq_len, max_seq_len + 1) pins every sequence length to
    # exactly max_seq_len.
    seq_len_tensor = torch.randint(
        max_seq_len, max_seq_len + 1, (batch_size,), device=device
    )

    total_tokens = seq_len_tensor.sum()
    cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    # token_to_seq[t] is the batch index that flattened token t belongs to.
    token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
    token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)

    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
    block_table = torch.empty(
        (batch_size, num_blocks), dtype=torch.int32, device=device
    )
    for b in range(batch_size):
        block_table[b, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

    def _maybe_dequant(block_data: torch.Tensor) -> torch.Tensor:
        """Return ``block_data`` as-is, or dequantized to ``dtype`` for fp8."""
        if kv_cache_dtype != "fp8":
            return block_data
        dequantized = torch.empty_like(block_data, dtype=dtype)
        ops.convert_fp8(dequantized, block_data, scale.item())
        return dequantized

    expected_batches = []
    for b in range(batch_size):
        s = seq_len_tensor[b]
        if s == 0:
            continue
        tot = tot_blocks_tensor[b]
        blocks = block_table[b, :tot].tolist()

        # All but the last block are fully occupied.
        gathered_rows = [_maybe_dequant(src_cache[blk]) for blk in blocks[:-1]]
        # The last block may be only partially occupied.
        remaining = s - (tot - 1) * block_size
        gathered_rows.append(_maybe_dequant(src_cache[blocks[-1], :remaining, :]))

        expected_batches.append(torch.cat(gathered_rows, dim=0))
    expected = torch.cat(expected_batches, dim=0)

    opcheck(
        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
        (
            src_cache,
            dst,
            block_table,
            cu_seq_lens,
            token_to_seq,
            total_tokens,
            kv_cache_dtype,
            scale,
            None,
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.gather_and_maybe_dequant_cache(
        src_cache,
        dst,
        block_table,
        cu_seq_lens,
        token_to_seq,
        total_tokens,
        kv_cache_dtype,
        scale,
        None,
    )
    torch.testing.assert_close(dst, expected)
Thien Tran's avatar
Thien Tran committed
1005
1006


1007
1008
1009
1010
1011
1012
1013
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "kv_cache_dtype", ["auto"]
)  # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_cp_gather_cache_mla(
    kv_lora_rank,
    qk_rope_head_dim,
    block_size,
    num_blocks,
    max_seq_len,
    batch_size,
    dtype,
    kv_cache_dtype,
    device,
):
    """Check ``cp_gather_cache`` against a Python block-by-block gather.

    Each sequence's tokens are copied verbatim (no dequantization) from the
    blocks listed in its ``block_table`` row into the dense ``dst`` buffer,
    which has the cache's own dtype.
    """
    # Cache contents, sequence lengths and block tables are all random; seed
    # (as the sibling tests do) so failures are reproducible.
    current_platform.seed_everything(SEEDS[0])
    torch.cuda.set_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim
    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    # Lengths are drawn from [0, max_seq_len], so empty sequences can occur.
    seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)

    total_tokens = seq_len_tensor.sum()
    cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
    block_table = torch.empty(
        (batch_size, num_blocks), dtype=torch.int32, device=device
    )
    for b in range(batch_size):
        block_table[b, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size), dtype=src_cache.dtype, device=device)

    expected_batches = []
    for b in range(batch_size):
        s = seq_len_tensor[b]
        if s == 0:
            continue
        tot = tot_blocks_tensor[b]
        blocks = block_table[b, :tot].tolist()

        # All but the last block are fully occupied; the last may be partial.
        gathered_rows = [src_cache[blk] for blk in blocks[:-1]]
        remaining = s - (tot - 1) * block_size
        gathered_rows.append(src_cache[blocks[-1], :remaining, :])

        expected_batches.append(torch.cat(gathered_rows, dim=0))
    # torch.cat rejects an empty list. That happens when every sampled length
    # is 0, in which case dst is (0, entry_size) and so is the expectation.
    expected = (
        torch.cat(expected_batches, dim=0)
        if expected_batches
        else torch.empty_like(dst)
    )

    opcheck(
        torch.ops._C_cache_ops.cp_gather_cache,
        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
    torch.testing.assert_close(dst, expected)


Thien Tran's avatar
Thien Tran committed
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check the CPU ``concat_and_cache_mla`` op against a Python reference.

    Each token's ``kv_c`` row (first ``kv_lora_rank`` entries) and ``k_pe``
    row (remaining entries) are written to the cache slot given by
    ``slot_mapping``. The whole cache is then compared with a zero-initialized
    reference, so unmapped slots are expected to be zero.

    NOTE(review): that zero expectation presumes ``_create_mla_cache``
    zero-fills; it is not visible here, so confirm.
    """
    device = "cpu"
    # Only the unquantized cache path is exercised on CPU.
    kv_cache_dtype = "auto"
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
    entry_size = kv_lora_rank + qk_rope_head_dim

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)

    for i in range(num_tokens):
        # A flat slot index splits into (block index, offset within block).
        block_idx, block_offset = divmod(slot_mapping[i].item(), block_size)
        ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
        ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]

    # Unreachable while kv_cache_dtype is pinned to "auto" above. It is kept
    # so the reference stays correct if the test is extended to "fp8".
    if kv_cache_dtype == "fp8":
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = ref_temp

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    torch.testing.assert_close(kv_cache, ref_kv_cache)