test_cache.py 38.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
import random

6
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch

9
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
from vllm import _custom_ops as ops
11
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
12
from vllm.platforms import current_platform
13
from vllm.utils.torch_utils import nvfp4_kv_cache_split_views, set_random_seed
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
16
DTYPES = [torch.bfloat16, torch.float]
Simon Mo's avatar
Simon Mo committed
17
NUM_TOKENS = [42]  # Arbitrary values for testing
18
NUM_LAYERS = [1]  # Arbitrary values for testing
19
NUM_HEADS = [8]  # Arbitrary values for testing
20
HEAD_SIZES = [64, 80, 256]
21
BLOCK_SIZES = [8, 16, 32]
22
CACHE_LAYOUTS = ["NHD", "HND"]
23
KV_SCALE_TYPES = ["tensor", "attn_head"]
24

25
# Parameters for MLA tests.
26
KV_LORA_RANKS = [256, 512]
27
28
29
30
31
QK_ROPE_HEAD_DIMS = [64]
NUM_TOKENS_MLA = [42]
BLOCK_SIZES_MLA = [16]
NUM_BLOCKS_MLA = [8]

32
33
34
35
# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

36
NUM_MAPPINGS = [256]  # Arbitrary values for testing
37
SEEDS = [0]
38
39
40
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
41
42

# We assume fp8 is always enabled for testing.
43
KV_CACHE_DTYPE = ["auto", "fp8"]
44

45
46
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]

47
48
49
50
51
52
53
54

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache`` with a token-by-token reference.

    Random keys/values are written into randomly chosen, distinct cache slots
    by the kernel. The same writes are replayed in Python on a snapshot of the
    caches taken before the kernel ran, and both results must agree. For
    ``"fp8"`` caches both sides are converted to float16 with
    ``ops.convert_fp8`` and compared with loose tolerances.
    """
    uses_fp8 = kv_cache_dtype == "fp8"
    if uses_fp8 and head_size % 16 != 0:
        pytest.skip()

    set_random_seed(seed)
    torch.set_default_device(device)
    torch.accelerator.set_device_index(device)

    # Draw one distinct destination slot per token.
    slot_mapping_lst = random.sample(range(block_size * num_blocks), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    # Only the key and value parts of the packed tensor are used.
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Build single-layer KV caches.
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache = key_caches[0]
    value_cache = value_caches[0]

    # Using default kv_scale
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    def dequantized_copy(cache, scale):
        # float16 copy of an fp8 cache, produced by ops.convert_fp8.
        converted = torch.empty_like(cache, dtype=torch.float16)
        ops.convert_fp8(converted, cache, scale.item())
        return converted

    # Snapshot the caches before the kernel touches them; the reference
    # writes below are applied to these snapshots.
    if uses_fp8:
        expected_key_cache = dequantized_copy(key_cache, k_scale)
        expected_value_cache = dequantized_copy(value_cache, v_scale)
    else:
        expected_key_cache = key_cache.clone()
        expected_value_cache = value_cache.clone()

    # Run the kernel under test (opcheck only for the first head size).
    kernel_args = (
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    opcheck(
        torch.ops._C_cache_ops.reshape_and_cache,
        kernel_args,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache(*kernel_args)

    # Reference implementation: the key cache is indexed with five
    # dimensions, so each key row is first reshaped to the per-slot shape.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    for token_idx, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        expected_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[token_idx]
        expected_value_cache[block_idx, :, :, block_offset] = value[token_idx]

    if uses_fp8:
        actual_key_cache = dequantized_copy(key_cache, k_scale)
        actual_value_cache = dequantized_copy(value_cache, v_scale)
        tolerances = {"atol": 0.001, "rtol": 0.1}
    else:
        actual_key_cache, actual_value_cache = key_cache, value_cache
        tolerances = {}

    torch.testing.assert_close(actual_key_cache, expected_key_cache, **tolerances)
    torch.testing.assert_close(
        actual_value_cache, expected_value_cache, **tolerances
    )
Vladimir's avatar
Vladimir committed
165
166


167
168
169
170
171
172
173
174
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE + ["nvfp4"])
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("kv_scale_type", KV_SCALE_TYPES)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
    kv_cache_layout: str,
    kv_scale_type: str,
    implementation: str,
) -> None:
    """Compare ``reshape_and_cache_flash`` (CUDA or Triton) with a reference.

    For ``"auto"`` and ``"fp8"`` caches, a snapshot of the caches taken before
    the kernel runs is updated slot by slot in Python and compared with the
    caches the kernel produced (fp8 is dequantised to float16 first). For
    ``"nvfp4"`` the written cache is dequantised and the slots selected by the
    slot mapping are compared against the original keys/values.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.accelerator.set_device_index(device)

    assert implementation in ["cuda", "triton"]
    is_fp8 = kv_cache_dtype == "fp8"
    is_nvfp4 = kv_cache_dtype == "nvfp4"

    if implementation == "triton" and kv_cache_layout == "HND":
        pytest.skip("Triton implementation only supports NHD layout.")

    if kv_scale_type == "attn_head" and implementation != "cuda":
        pytest.skip("Only CUDA implementation supports attn_head scaling.")

    if is_nvfp4:
        # (condition, reason) pairs, checked in order; the first hit skips.
        nvfp4_skip_rules = (
            (
                not current_platform.has_device_capability(100),
                "NVFP4 requires compute capability >= 10.0 (Blackwell).",
            ),
            (implementation != "cuda", "NVFP4 only supports CUDA implementation."),
            (kv_scale_type != "tensor", "NVFP4 only supports per-tensor scaling."),
            (head_size % 16 != 0, "NVFP4 requires head_size divisible by 16."),
            (
                (head_size // 16) % 4 != 0,
                "NVFP4 requires (head_size // 16) divisible by 4 "
                "for 4x4 block scale swizzle.",
            ),
            (block_size % 4 != 0, "NVFP4 requires block_size divisible by 4."),
            (
                dtype not in (torch.float16, torch.bfloat16),
                "NVFP4 quantization only supports fp16/bf16 input.",
            ),
        )
        for should_skip, reason in nvfp4_skip_rules:
            if should_skip:
                pytest.skip(reason)

    # fp8 conversion requires a contiguous memory buffer. Reduce the number of
    # blocks and tokens to consume less memory.
    num_tokens //= 2
    num_blocks //= 2

    # Draw one distinct destination slot per token.
    slot_mapping_lst = random.sample(range(block_size * num_blocks), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    # Only the key and value parts of the packed tensor are used.
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    # Build single-layer KV caches in the requested layout.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
        cache_layout=kv_cache_layout,
    )
    key_cache = key_caches[0]
    value_cache = value_caches[0]
    del key_caches, value_caches

    # For nvfp4, the factory returns kv[:, 0] and kv[:, 1] like all dtypes.
    # Split data/scale views are still needed for dequant verification.
    nvfp4_key_data = nvfp4_value_data = None
    key_scale_cache = value_scale_cache = None
    if is_nvfp4:
        (nvfp4_key_data,), (key_scale_cache,) = nvfp4_kv_cache_split_views(key_cache)
        (nvfp4_value_data,), (value_scale_cache,) = nvfp4_kv_cache_split_views(
            value_cache
        )

    if is_nvfp4:
        # Global scale = amax / 448 (per-tensor)
        k_scale = (key.abs().amax() / 448.0).to(torch.float32)
        v_scale = (value.abs().amax() / 448.0).to(torch.float32)
    elif kv_scale_type == "tensor":
        k_scale = (key.amax() / 64.0).to(torch.float32)
        v_scale = (value.amax() / 64.0).to(torch.float32)
    else:  # "attn_head": one scale per head
        k_scale = (key.amax(dim=(0, 2)) / 64.0).to(torch.float32)
        v_scale = (value.amax(dim=(0, 2)) / 64.0).to(torch.float32)

    def to_compact(cache):
        # Contiguous tensor; HND caches get dims 1 and 2 swapped first.
        if kv_cache_layout != "NHD":
            cache = cache.permute(0, 2, 1, 3)
        return cache.contiguous()

    def dequantize_fp8(quantized, scale):
        # float16 tensor holding `quantized` (raw fp8 storage) times `scale`.
        as_fp8 = quantized.view(current_platform.fp8_dtype())
        dequantized = torch.empty_like(quantized, dtype=torch.float16)
        if scale.numel() == 1:  # per-tensor
            values = scaled_dequantize(
                as_fp8.flatten(0, 2),
                scale,
                group_shape=None,
                out_dtype=dequantized.dtype,
            ).reshape(*quantized.shape)
        else:  # per-head: broadcast the scale along the head dimension
            # Head axis is dim 2 for NHD and dim 1 for HND.
            broadcast_shape = (
                (1, 1, -1, 1) if kv_cache_layout == "NHD" else (1, -1, 1, 1)
            )
            values = as_fp8.to(dequantized.dtype) * scale.view(*broadcast_shape)
        dequantized.copy_(values)
        return dequantized

    # Snapshot the caches before the kernel runs (reference baseline for the
    # non-nvfp4 dtypes).
    if is_fp8:
        expected_key_cache = dequantize_fp8(to_compact(key_cache), k_scale)
        expected_value_cache = dequantize_fp8(to_compact(value_cache), v_scale)
    elif not is_nvfp4:
        expected_key_cache = to_compact(key_cache).clone()
        expected_value_cache = to_compact(value_cache).clone()

    # Run the kernel under test.
    kernel_args = (
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    if implementation == "cuda":
        if not is_nvfp4:
            opcheck(
                torch.ops._C_cache_ops.reshape_and_cache_flash,
                kernel_args,
                cond=(head_size == HEAD_SIZES[0]),
            )
        ops.reshape_and_cache_flash(*kernel_args)
    elif implementation == "triton":
        from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
            triton_reshape_and_cache_flash,
        )

        triton_reshape_and_cache_flash(*kernel_args)

    if is_nvfp4:
        # Verify NVFP4 by dequantizing the entire cache and comparing the
        # written slots against the original (unquantised) keys/values.
        from tests.kernels.quantization.nvfp4_utils import (
            dequant_nvfp4_kv_cache,
        )

        def dequantized_slots(data_cache, scale_cache, global_scale):
            # The data/scale views are permuted (dims 1 and 2 swapped) for the
            # dequant utility, then swapped back and flattened to
            # [num_slots, num_heads, head_size].
            dequantized = dequant_nvfp4_kv_cache(
                data_cache.permute(0, 2, 1, 3),
                scale_cache.permute(0, 2, 1, 3),
                global_scale,
                head_size,
                block_size,
            )
            return dequantized.permute(0, 2, 1, 3).reshape(
                num_blocks * block_size, num_heads, head_size
            )

        key_slots = dequantized_slots(nvfp4_key_data, key_scale_cache, k_scale.item())
        value_slots = dequantized_slots(
            nvfp4_value_data, value_scale_cache, v_scale.item()
        )

        torch.testing.assert_close(
            key_slots[slot_mapping], key.float(), atol=1.5, rtol=0.5
        )
        torch.testing.assert_close(
            value_slots[slot_mapping], value.float(), atol=1.5, rtol=0.5
        )
        return

    # Re-read the caches now that the kernel has written into them.
    key_cache_compact = to_compact(key_cache)
    value_cache_compact = to_compact(value_cache)

    # Reference implementation: replay every write on the snapshots.
    for token_idx, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        if kv_cache_layout == "NHD":
            expected_key_cache[block_idx, block_offset, :, :] = key[token_idx]
            expected_value_cache[block_idx, block_offset, :, :] = value[token_idx]
        else:
            expected_key_cache[block_idx, :, block_offset, :] = key[token_idx]
            expected_value_cache[block_idx, :, block_offset, :] = value[token_idx]

    if is_fp8:
        actual_key_cache = dequantize_fp8(key_cache_compact, k_scale)
        actual_value_cache = dequantize_fp8(value_cache_compact, v_scale)
        tolerances = {"atol": 0.001, "rtol": 0.1}
    else:
        actual_key_cache = key_cache_compact
        actual_value_cache = value_cache_compact
        tolerances = {}

    torch.testing.assert_close(actual_key_cache, expected_key_cache, **tolerances)
    torch.testing.assert_close(
        actual_value_cache, expected_value_cache, **tolerances
    )
429
430


Vladimir's avatar
Vladimir committed
431
432
433
434
435
436
437
438
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``ops.swap_blocks`` copies the mapped blocks unchanged.

    Key and value caches are created for a source and a destination device
    (either may be the CPU, per ``direction``). After the swap, every mapped
    destination block must equal the corresponding source block as it was
    before the call.
    """
    # fp8 is skipped when a CPU side is involved or head_size is not a
    # multiple of 16.
    if kv_cache_dtype == "fp8" and ("cpu" in direction or head_size % 16 != 0):
        pytest.skip()

    set_random_seed(seed)

    src_device, dst_device = (
        device if side == "cuda" else "cpu" for side in direction
    )

    src_blocks = random.sample(range(num_blocks), num_mappings)
    if src_device == dst_device:
        # For the same device, mapping must not overlap
        dst_candidates = list(set(range(num_blocks)) - set(src_blocks))
    else:
        dst_candidates = range(num_blocks)
    dst_blocks = random.sample(dst_candidates, num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    def build_caches(target_device):
        # Single-layer key/value caches on `target_device`.
        return kv_cache_factory(
            num_blocks,
            block_size,
            1,
            num_heads,
            head_size,
            kv_cache_dtype,
            dtype,
            seed,
            target_device,
        )

    # Source caches first, then destination caches.
    src_key_caches, src_value_caches = build_caches(src_device)
    dst_key_caches, dst_value_caches = build_caches(dst_device)
    src_key_cache, src_value_cache = src_key_caches[0], src_value_caches[0]
    dst_key_cache, dst_value_cache = dst_key_caches[0], dst_value_caches[0]

    # Keep the pre-swap source contents for the final comparison.
    src_key_snapshot = src_key_cache.clone()
    src_value_snapshot = src_value_cache.clone()

    # Bytes per block, derived from the key cache's leading stride.
    block_size_in_bytes = src_key_cache.element_size() * src_key_cache.stride(0)
    swap_calls = (
        (src_key_cache, dst_key_cache, block_size_in_bytes, block_mapping_tensor),
        (src_value_cache, dst_value_cache, block_size_in_bytes, block_mapping_tensor),
    )

    # Call the swap_blocks kernel (opcheck only for the first head size).
    do_opcheck = head_size == HEAD_SIZES[0]
    for swap_args in swap_calls:
        opcheck(torch.ops._C_cache_ops.swap_blocks, swap_args, cond=do_opcheck)
    for swap_args in swap_calls:
        ops.swap_blocks(*swap_args)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_key_snapshot[src].cpu(), dst_key_cache[dst].cpu()
        )
        torch.testing.assert_close(
            src_value_snapshot[src].cpu(), dst_value_cache[dst].cpu()
        )
552
553
554
555
556
557
558
559
560
561


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a tensor through ``ops.convert_fp8`` and compare.

    Values drawn uniformly from [-224, 224) are converted into a uint8
    buffer and back to ``dtype``; the result must match the input within
    ``atol=0.001`` / ``rtol=0.1``.
    """
    set_random_seed(seed)

    original = torch.empty(
        (num_blocks, num_heads, head_size, block_size), dtype=dtype, device=device
    ).uniform_(-224.0, 224.0)

    # Forward conversion into raw uint8 storage.
    quantized = torch.empty_like(original, dtype=torch.uint8)
    ops.convert_fp8(quantized, original)

    # Backward conversion into the original dtype.
    restored = torch.empty_like(original)
    ops.convert_fp8(restored, quantized)

    torch.testing.assert_close(original, restored, atol=0.001, rtol=0.1)
586
587
588
589
590
591
592
593
594
595
596


def _create_mla_cache(
    num_blocks: int,
    block_size: int,
    entry_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> torch.Tensor:
    """Allocate a zero-filled MLA cache of shape (num_blocks, block_size, entry_size).

    The element type is ``torch.uint8`` when ``kv_cache_dtype`` is ``"fp8"``;
    for any other value it is ``dtype``.
    """
    if kv_cache_dtype == "fp8":
        storage_dtype = torch.uint8
    else:
        storage_dtype = dtype
    cache_shape = (num_blocks, block_size, entry_size)
    return torch.zeros(cache_shape, dtype=storage_dtype, device=device)
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633


def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
    """Overwrite ``cache`` in place with random content.

    For ``"fp8"`` the samples are drawn as float16 and converted with
    ``ops.convert_fp8`` (scale 1.0) before being copied in; otherwise the
    samples are drawn directly in the cache's own dtype.
    """
    if kv_cache_dtype != "fp8":
        cache.copy_(
            torch.randn(*cache.shape, device=cache.device, dtype=cache.dtype)
        )
        return

    samples = torch.randn(*cache.shape, device=cache.device, dtype=torch.float16)
    quantized = torch.zeros_like(cache)
    ops.convert_fp8(quantized, samples, 1.0, kv_dtype=kv_cache_dtype)
    cache.copy_(quantized)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_concat_and_cache_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.concat_and_cache_mla`` with a Python reference.

    The reference places ``kv_c[i]`` followed by ``k_pe[i]`` into the cache
    entry selected by the i-th slot. For ``"fp8"`` the reference is quantised
    with ``ops.convert_fp8`` and both caches are converted to float16 before
    the comparison.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.accelerator.set_device_index(device)

    # Draw one distinct destination slot per token.
    slot_mapping_lst = random.sample(range(num_blocks * block_size), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
    scale = torch.tensor(0.1, dtype=torch.float32, device=device)

    kv_cache = _create_mla_cache(
        num_blocks,
        block_size,
        kv_lora_rank + qk_rope_head_dim,
        dtype,
        kv_cache_dtype,
        device,
    )

    # Reference cache in the unquantised dtype.
    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
    for token_idx, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        entry = ref_temp[block_idx, block_offset]
        entry[:kv_lora_rank] = kv_c[token_idx]
        entry[kv_lora_rank:] = k_pe[token_idx]

    is_fp8 = kv_cache_dtype == "fp8"
    if is_fp8:
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = ref_temp

    kernel_args = (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        kernel_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    ops.concat_and_cache_mla(*kernel_args)

    if not is_fp8:
        torch.testing.assert_close(kv_cache, ref_kv_cache)
        return

    # fp8: compare both caches after converting them to float16.
    fp8_scale = scale.item()
    actual_fp16 = torch.empty_like(kv_cache, dtype=torch.float16)
    ops.convert_fp8(
        actual_fp16, kv_cache.contiguous(), fp8_scale, kv_dtype=kv_cache_dtype
    )
    expected_fp16 = torch.empty_like(ref_kv_cache, dtype=torch.float16)
    ops.convert_fp8(expected_fp16, ref_kv_cache, fp8_scale, kv_dtype=kv_cache_dtype)
    torch.testing.assert_close(actual_fp16, expected_fp16, atol=0.001, rtol=0.1)


687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_concat_and_cache_ds_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare ``ops.concat_and_cache_mla`` in ``fp8_ds_mla`` mode with a reference.

    Each reference cache entry is a uint8 row built here as:
    ``kv_lora_rank`` bytes of fp8 data quantised per 128-element tile, one
    float32 scale per tile (16 bytes are reserved for scales), then the
    ``k_pe`` values stored in the 16-bit input dtype.
    """
    if current_platform.is_rocm():
        pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
    if dtype.itemsize != 2:
        pytest.skip("ds_mla only supports 16-bit input")
    if kv_lora_rank != 512:
        pytest.skip("fp8_ds_mla requires kv_lora_rank == 512")

    kv_cache_dtype = "fp8_ds_mla"
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.accelerator.set_device_index(device)

    # Draw one distinct destination slot per token.
    slot_mapping_lst = random.sample(range(num_blocks * block_size), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)

    # Bytes per entry: fp8 payload + 4 float32 scales + 16-bit rope values.
    entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
    scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    kv_cache = _create_mla_cache(
        num_blocks,
        block_size,
        entry_size,
        dtype=torch.uint8,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
    )

    tile_width = 128
    num_tiles = kv_lora_rank // tile_width
    scales_start = kv_lora_rank // 4  # offset in float32 elements
    rope_start = kv_lora_rank // 2 + 8  # offset in 16-bit elements

    ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
    # Scratch buffer reused for every tile.
    tile_data = torch.zeros(tile_width, dtype=dtype, device=device)

    for token_idx, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)

        # Three views over the same bytes of this entry.
        ref_entry = ref_cache[block_idx, block_offset]
        ref_entry_16bit = ref_entry.view(dtype)
        ref_entry_32bit = ref_entry.view(torch.float32)

        token_kv_c = kv_c[token_idx]
        for tile_idx in range(num_tiles):
            tile_start = tile_idx * tile_width
            tile_end = tile_start + tile_width
            tile_data[:] = token_kv_c[tile_start:tile_end]

            # tile_scale = tile_data.amax().to(torch.float32) / 448.
            # NOTE: Using torch's amax() gives different results,
            # so this must be manually computed.
            magnitudes = tile_data.to(torch.float32).abs()
            running_max = magnitudes[0]
            for magnitude in magnitudes[1:]:
                running_max = max(running_max, magnitude)
            tile_scale = running_max / 448.0

            ref_entry_32bit[scales_start + tile_idx] = tile_scale

            ops.convert_fp8(
                ref_entry[tile_start:tile_end],
                tile_data,
                tile_scale.item(),
                kv_dtype="fp8",
            )

        # Rope values are stored unquantised after the scales.
        ref_entry_16bit[rope_start : rope_start + qk_rope_head_dim] = k_pe[token_idx]

    kernel_args = (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        kernel_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    ops.concat_and_cache_mla(*kernel_args)

    tolerances = {"atol": 0.001, "rtol": 0.1}
    for slot in slot_mapping_lst:
        block_idx, block_offset = divmod(slot, block_size)
        actual_entry = kv_cache[block_idx, block_offset]
        expected_entry = ref_cache[block_idx, block_offset]

        # Quantised payload, compared as raw bytes.
        torch.testing.assert_close(
            actual_entry[:kv_lora_rank], expected_entry[:kv_lora_rank], **tolerances
        )
        # Per-tile float32 scales.
        torch.testing.assert_close(
            actual_entry.view(torch.float32)[scales_start : scales_start + 4],
            expected_entry.view(torch.float32)[scales_start : scales_start + 4],
            **tolerances,
        )
        # Rope part in the 16-bit input dtype.
        torch.testing.assert_close(
            actual_entry.view(dtype)[rope_start:],
            expected_entry.view(dtype)[rope_start:],
            **tolerances,
        )


806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Verify ``swap_blocks`` on MLA caches.

    Two caches with the same geometry are created and filled, a small set of
    disjoint (source block, destination block) pairs is drawn at random, and
    after the op each destination block must equal the pre-op snapshot of the
    corresponding source block.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.accelerator.set_device_index(device)

    # Each cache entry holds the compressed KV part followed by the rope part.
    entry_size = kv_lora_rank + qk_rope_head_dim
    cache_args = (num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device)

    src_cache = _create_mla_cache(*cache_args)
    dst_cache = _create_mla_cache(*cache_args)
    _fill_mla_cache(src_cache, kv_cache_dtype)
    _fill_mla_cache(dst_cache, kv_cache_dtype)

    # Snapshot of the source taken before the op; the assertions compare
    # against this copy rather than against the live source cache.
    src_snapshot = src_cache.clone()

    # Draw source blocks first, then destinations from the blocks not used as
    # sources, so no block index appears on both sides of the mapping.
    pair_count = min(2, num_blocks // 2)
    src_blocks = random.sample(range(num_blocks), pair_count)
    unused_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(unused_blocks, pair_count)
    block_mapping = list(zip(src_blocks, dst_blocks))

    # The mapping is handed to the op as an int64 CPU tensor of shape (N, 2).
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    # Bytes per block: elements between consecutive blocks times element size.
    block_size_in_bytes = src_cache.stride(0) * src_cache.element_size()

    op_args = (src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor)
    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        op_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.swap_blocks(*op_args)

    for src, dst in block_mapping:
        expected_block = src_snapshot[src].cpu()
        actual_block = dst_cache[dst].cpu()
        torch.testing.assert_close(
            expected_block,
            actual_block,
            msg=f"Block {src} from src should have been swapped to block "
            f"{dst} in dst_cache.",
        )
868
869
870
871
872
873
874
875
876


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(
    kv_lora_rank,
    qk_rope_head_dim,
    block_size,
    num_blocks,
    max_seq_len,
    batch_size,
    dtype,
    kv_cache_dtype,
    device,
):
    """Compare ``gather_and_maybe_dequant_cache`` with a block-by-block reference.

    A paged MLA cache is filled, every sequence gets a random permutation of
    block ids as its block table, and the op output is checked against rows
    gathered (and, for ``kv_cache_dtype == "fp8"``, converted with
    ``ops.convert_fp8``) one block at a time in Python.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    # randint's upper bound is exclusive, so every sequence has exactly
    # max_seq_len tokens here.
    seq_len_tensor = torch.randint(
        max_seq_len, max_seq_len + 1, (batch_size,), device=device
    )

    total_tokens = seq_len_tensor.sum()

    # Cumulative sequence lengths with a leading zero.
    cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    # For every output token, the index of the sequence it belongs to.
    token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
    token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)

    # Number of blocks needed per sequence (ceiling division).
    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
    block_table = torch.empty(
        (batch_size, num_blocks), dtype=torch.int32, device=device
    )
    for b in range(batch_size):
        block_table[b, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

    # Pull the loop-invariant scalars to the host once instead of reading
    # them back from the device on every iteration of the reference loop.
    seq_lens = seq_len_tensor.tolist()
    tot_blocks = tot_blocks_tensor.tolist()
    scale_value = scale.item()

    def _reference_rows(rows):
        """Return ``rows`` as the reference expects them in the output."""
        if kv_cache_dtype != "fp8":
            return rows
        dequantized = torch.empty_like(rows, dtype=dtype)
        ops.convert_fp8(dequantized, rows, scale_value)
        return dequantized

    expected_batches = []
    for b, (s, tot) in enumerate(zip(seq_lens, tot_blocks)):
        if s == 0:
            continue
        blocks = block_table[b, :tot].tolist()

        # All blocks but the last contribute every row ...
        gathered_rows = [_reference_rows(src_cache[blk]) for blk in blocks[:-1]]
        # ... and the last block only the rows that are still needed.
        remaining = s - (tot - 1) * block_size
        gathered_rows.append(_reference_rows(src_cache[blocks[-1], :remaining, :]))

        expected_batches.append(torch.cat(gathered_rows, dim=0))
    expected = torch.cat(expected_batches, dim=0)

    op_args = (
        src_cache,
        dst,
        block_table,
        cu_seq_lens,
        token_to_seq,
        total_tokens,
        kv_cache_dtype,
        scale,
        None,
    )
    opcheck(
        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
        op_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.gather_and_maybe_dequant_cache(*op_args)
    torch.testing.assert_close(dst, expected)
Thien Tran's avatar
Thien Tran committed
979
980


981
982
983
984
985
986
987
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "kv_cache_dtype", ["auto"]
)  # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_cp_gather_cache_mla(
    kv_lora_rank,
    qk_rope_head_dim,
    block_size,
    num_blocks,
    max_seq_len,
    batch_size,
    dtype,
    kv_cache_dtype,
    device,
):
    """Compare ``cp_gather_cache`` with rows gathered block by block in Python.

    Sequence lengths are drawn from ``[0, max_seq_len]``; each sequence gets a
    random permutation of block ids as its block table. The destination uses
    the cache's own dtype, and the reference copies rows without conversion.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
    print("seq_len_tensor", seq_len_tensor)
    total_tokens = seq_len_tensor.sum()

    # Cumulative sequence lengths with a leading zero.
    cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    # Blocks needed per sequence (ceiling division by the block size).
    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size

    block_table = torch.empty(
        (batch_size, num_blocks), dtype=torch.int32, device=device
    )
    for row in range(batch_size):
        block_table[row, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size), dtype=src_cache.dtype, device=device)

    expected_batches = []
    for b in range(batch_size):
        s = seq_len_tensor[b]
        if s == 0:
            # Zero-length sequences contribute no rows to the output.
            continue
        tot = tot_blocks_tensor[b]
        blocks = block_table[b, :tot].tolist()

        # Every block except the last is copied whole; the last one only
        # supplies the rows left over after the full blocks.
        gathered_rows = [src_cache[block_id] for block_id in blocks[:-1]]
        remaining = s - (tot - 1) * block_size
        gathered_rows.append(src_cache[blocks[-1], :remaining, :])

        expected_batches.append(torch.cat(gathered_rows, dim=0))
    expected = torch.cat(expected_batches, dim=0)

    opcheck(
        torch.ops._C_cache_ops.cp_gather_cache,
        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
    torch.testing.assert_close(dst, expected)


Thien Tran's avatar
Thien Tran committed
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check ``concat_and_cache_mla`` on CPU against a Python-built reference.

    Each token is assigned a distinct random slot; the reference cache holds
    ``kv_c[i]`` followed by ``k_pe[i]`` at that slot and zeros elsewhere, and
    the whole cache written by the op must match it.
    """
    device = "cpu"
    # This test only exercises the unquantized cache dtype, so the reference
    # below is built directly in ``dtype`` with no fp8 conversion step.
    kv_cache_dtype = "auto"
    set_random_seed(seed)
    torch.set_default_device(device)

    # random.sample draws without replacement, so slots never collide.
    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
    entry_size = kv_lora_rank + qk_rope_head_dim

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )

    ref_kv_cache = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
    for i, slot in enumerate(slot_mapping_lst):
        # A flat slot index splits into (block, offset within the block).
        block_idx, block_offset = divmod(slot, block_size)
        ref_kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
        ref_kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[i]

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    torch.testing.assert_close(kv_cache, ref_kv_cache)