test_cache.py 33.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
import random

6
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch

9
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
from vllm import _custom_ops as ops
11
from vllm.platforms import current_platform
Woosuk Kwon's avatar
Woosuk Kwon committed
12

13
# Parametrization tables shared by the tests in this module.

# (source, destination) device kinds exercised by the swap_blocks test.
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
DTYPES = [torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 256]
BLOCK_SIZES = [8, 16, 32]
CACHE_LAYOUTS = ["NHD", "HND"]

# Parameters for MLA tests.
KV_LORA_RANKS = [512]
QK_ROPE_HEAD_DIMS = [64]
NUM_TOKENS_MLA = [42]
BLOCK_SIZES_MLA = [16]
NUM_BLOCKS_MLA = [8]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# One device when exactly one GPU is visible, otherwise the first two.
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]

RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]

42
43
44
45
46
47
48
49

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache`` with a per-token PyTorch reference.

    Random keys/values are written into a single-layer KV cache at randomly
    sampled, distinct slots. The reference applies the same writes to a
    snapshot of the caches taken before the kernel ran. For ``"fp8"`` both
    sides are converted to float16 with ``ops.convert_fp8`` and compared
    with loose tolerances; otherwise the caches are compared directly.
    """
    # The fp8 variant is only exercised for head sizes divisible by 16.
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    # Create a random slot mapping (random.sample yields distinct slots).
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Derive the scales from the inputs: max element divided by 64.
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    # Snapshot the KV caches before the kernel mutates them.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item())
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item())
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    # opcheck only runs for the first head size.
    opcheck(
        torch.ops._C_cache_ops.reshape_and_cache,
        (
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        ),
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache, k_scale.item())
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache, v_scale.item())

    # Run the reference implementation.
    # Reshape each key to the shape of one key-cache slot.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i, (block_idx, block_offset) in enumerate(
        zip(block_indices_lst, block_offsets_lst)
    ):
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)
Vladimir's avatar
Vladimir committed
160
161


162
163
164
165
166
167
168
169
170
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
    kv_cache_layout: str,
    implementation: str,
) -> None:
    """Compare reshape_and_cache_flash (CUDA or Triton) with a reference.

    Random keys/values are written into a single-layer cache at distinct
    random slots. The caches are normalised to a contiguous
    ``[block, offset, head, dim]`` view for ``"NHD"`` or a permuted view for
    ``"HND"`` before and after the kernel call, and the reference applies
    the same per-token writes to the pre-kernel snapshot.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    assert implementation in ["cuda", "triton"]
    if implementation == "triton" and kv_cache_layout == "HND":
        pytest.skip("Triton implementation only supports NHD layout.")

    # fp8 conversion requires contiguous memory buffer. Reduce the number of
    # blocks and tokens to consume less memory.
    num_tokens = num_tokens // 2
    num_blocks = num_blocks // 2

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
        cache_layout=kv_cache_layout,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    del key_caches
    del value_caches

    # Derive the scales from the inputs: max element divided by 64.
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    def permute_and_compact(x):
        """Return a contiguous copy, swapping dims 1 and 2 unless NHD."""
        y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
        return y.contiguous()

    key_cache_compact = permute_and_compact(key_cache)
    value_cache_compact = permute_and_compact(value_cache)

    # Snapshot the KV caches before the kernel mutates them.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
        ops.convert_fp8(
            cloned_key_cache,
            key_cache_compact,
            k_scale.item(),
            kv_dtype=kv_cache_dtype,
        )
        cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
        ops.convert_fp8(
            cloned_value_cache,
            value_cache_compact,
            v_scale.item(),
            kv_dtype=kv_cache_dtype,
        )
    else:
        cloned_key_cache = key_cache_compact.clone()
        cloned_value_cache = value_cache_compact.clone()

    # Call the reshape_and_cache kernel.
    if implementation == "cuda":
        # opcheck only runs for the first head size.
        opcheck(
            torch.ops._C_cache_ops.reshape_and_cache_flash,
            (
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                kv_cache_dtype,
                k_scale,
                v_scale,
            ),
            cond=(head_size == HEAD_SIZES[0]),
        )
        ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
    elif implementation == "triton":
        from vllm.attention.ops.triton_reshape_and_cache_flash import (
            triton_reshape_and_cache_flash,
        )

        triton_reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    # Re-derive the compact views so they reflect the kernel's writes.
    key_cache_compact = permute_and_compact(key_cache)
    value_cache_compact = permute_and_compact(value_cache)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
        ops.convert_fp8(
            result_key_cache, key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype
        )
        result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
        ops.convert_fp8(
            result_value_cache,
            value_cache_compact,
            v_scale.item(),
            kv_dtype=kv_cache_dtype,
        )

    # Run the reference implementation.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i, (block_idx, block_offset) in enumerate(
        zip(block_indices_lst, block_offsets_lst)
    ):
        if kv_cache_layout == "NHD":
            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
            cloned_value_cache[block_idx, block_offset, :, :] = value[i]
        else:
            cloned_key_cache[block_idx, :, block_offset, :] = key[i]
            cloned_value_cache[block_idx, :, block_offset, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache_compact, cloned_key_cache)
        torch.testing.assert_close(value_cache_compact, cloned_value_cache)
327
328


Vladimir's avatar
Vladimir committed
329
330
331
332
333
334
335
336
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``ops.swap_blocks`` copies the mapped blocks src -> dst.

    ``direction`` selects whether source and destination caches live on the
    parametrized CUDA device or on the CPU. After the call, every mapped
    destination block must equal the corresponding source block as it was
    before the call.
    """
    # fp8 caches are not exercised when either side is on the CPU.
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    # The fp8 variant is only exercised for head sizes divisible by 16.
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()

    current_platform.seed_everything(seed)

    src_device = device if direction[0] == "cuda" else "cpu"
    dst_device = device if direction[1] == "cuda" else "cpu"

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    # The mapping tensor handed to the op is built on the CPU, one
    # (src, dst) pair per row.
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        src_device,
    )

    # Create the KV caches on the second device.
    dist_key_caches, dist_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        dst_device,
    )

    # Keep pre-call copies of the sources to compare against afterwards.
    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    # opcheck only runs for the first head size.
    do_opcheck = head_size == HEAD_SIZES[0]
    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
        cond=do_opcheck,
    )
    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
        cond=do_opcheck,
    )

    ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
    ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu()
        )
        torch.testing.assert_close(
            src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()
        )
428
429
430
431
432
433
434
435
436
437


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a tensor through ``ops.convert_fp8`` and back.

    Values are drawn uniformly from [-224, 224], converted into a uint8
    buffer, converted back to ``dtype`` and compared with the original
    using loose tolerances (atol=0.001, rtol=0.1).
    """
    current_platform.seed_everything(seed)

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    # Forward conversion: dtype -> fp8 stored in a uint8 tensor.
    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache_fp8, cache)

    # Backward conversion: fp8 bytes -> original dtype.
    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(converted_cache, cache_fp8)

    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
462
463
464
465
466
467
468
469
470
471
472


def _create_mla_cache(
    num_blocks: int,
    block_size: int,
    entry_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> torch.Tensor:
    cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
473
474
475
    return torch.zeros(
        num_blocks, block_size, entry_size, dtype=cache_dtype, device=device
    )
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511


def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
    rand_dtype = torch.float16 if kv_cache_dtype == "fp8" else cache.dtype

    vals = torch.randn(*cache.shape, device=cache.device, dtype=rand_dtype)
    if kv_cache_dtype == "fp8":
        temp = torch.zeros_like(cache)
        ops.convert_fp8(temp, vals, 1.0, kv_dtype=kv_cache_dtype)
        vals = temp
    cache.copy_(vals)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_concat_and_cache_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.concat_and_cache_mla`` with a per-token reference.

    Each token's ``kv_c`` (first ``kv_lora_rank`` entries) and ``k_pe``
    (remaining entries) are expected to land in the cache entry addressed
    by its slot. For ``"fp8"`` both the kernel output and the reference are
    converted to float16 and compared with loose tolerances.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
    entry_size = kv_lora_rank + qk_rope_head_dim

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )

    # Build the reference in the unquantized dtype first.
    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)

    for i in range(num_tokens):
        slot = slot_mapping[i].item()
        block_idx = slot // block_size
        block_offset = slot % block_size
        ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
        ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]

    if kv_cache_dtype == "fp8":
        # Quantize the reference with the same scale passed to the kernel.
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = ref_temp

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)

    if kv_cache_dtype == "fp8":
        # Compare in float16 rather than on the raw cache bytes.
        result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_temp, kv_cache.contiguous(), scale.item(), kv_dtype=kv_cache_dtype
        )
        expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
        ops.convert_fp8(
            expected_temp, ref_kv_cache, scale.item(), kv_dtype=kv_cache_dtype
        )
        torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1)
    else:
        torch.testing.assert_close(kv_cache, ref_kv_cache)


563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_concat_and_cache_ds_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare ``ops.concat_and_cache_mla`` in ``"fp8_ds_mla"`` mode with a reference.

    The reference packs each uint8 cache entry as built below:
    ``kv_lora_rank`` bytes of fp8 values (four tiles of 128, each with its
    own scale), then four float32 tile scales, then ``qk_rope_head_dim``
    rope values stored in ``dtype``. The three regions are compared
    separately for every written slot.
    """
    if current_platform.is_rocm():
        pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
    if dtype.itemsize != 2:
        pytest.skip("ds_mla only supports 16-bit input")
    kv_cache_dtype = "fp8_ds_mla"
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
    # Bytes per entry: fp8 values + 4 float32 scales + 2-byte rope values.
    entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)

    scale = torch.tensor(1.0, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(
        num_blocks,
        block_size,
        entry_size,
        dtype=torch.uint8,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
    )

    ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
    tile_data = torch.zeros(128, dtype=dtype, device=device)

    for i in range(num_tokens):
        slot = slot_mapping[i].item()
        block_idx = slot // block_size
        block_offset = slot % block_size

        # Reinterpret the same entry bytes as 16-bit and 32-bit elements.
        ref_cache_slice = ref_cache[block_idx, block_offset]
        ref_cache_16bit = ref_cache_slice.view(dtype)
        ref_cache_32bit = ref_cache_slice.view(torch.float32)

        kv_c_data = kv_c[i]
        # NOTE: four tiles of 128 cover kv_c only when kv_lora_rank == 512.
        for tile_idx in range(4):
            tile_start = tile_idx * 128
            tile_end = (tile_idx + 1) * 128
            tile_data[:] = kv_c_data[tile_start:tile_end]

            # tile_scale = tile_data.amax().to(torch.float32) / 448.
            # NOTE: Using torch's amax() gives different results,
            # so this must be manually computed.
            tile_data_float = tile_data.to(torch.float32)
            manual_max = abs(tile_data_float[0])
            for j in range(1, 128):
                manual_max = max(manual_max, abs(tile_data_float[j]))
            tile_scale = manual_max / 448.0

            # Scales sit right after the fp8 region, indexed as float32.
            ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale

            ops.convert_fp8(
                ref_cache_slice[tile_start:tile_end],
                tile_data,
                tile_scale.item(),
                kv_dtype="fp8",
            )

        # Rope values follow the 16 scale bytes (8 elements in 16-bit units).
        for j in range(qk_rope_head_dim):
            ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)

    for i in range(num_tokens):
        slot = slot_mapping[i].item()
        block_idx = slot // block_size
        block_offset = slot % block_size
        kv_cache_slice = kv_cache[block_idx, block_offset]
        ref_cache_slice = ref_cache[block_idx, block_offset]

        kv_nope = kv_cache_slice[:kv_lora_rank]
        ref_nope = ref_cache_slice[:kv_lora_rank]
        kv_scales = kv_cache_slice.view(torch.float32)[
            kv_lora_rank // 4 : kv_lora_rank // 4 + 4
        ]
        ref_scales = ref_cache_slice.view(torch.float32)[
            kv_lora_rank // 4 : kv_lora_rank // 4 + 4
        ]
        kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8 :]
        ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8 :]

        torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
        torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
        torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)


679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.swap_blocks`` on MLA-shaped caches on a single device.

    Source and destination caches are filled with random data, a small
    non-overlapping block mapping is applied, and every mapped destination
    block must equal the source block as it was before the call.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim

    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    dst_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )

    _fill_mla_cache(src_cache, kv_cache_dtype)
    _fill_mla_cache(dst_cache, kv_cache_dtype)

    # Keep a pre-call copy of the source to compare against afterwards.
    src_cache_clone = src_cache.clone()

    # At most two mappings; destinations are drawn from blocks not used as
    # sources so the two sets never overlap.
    num_mappings = min(2, num_blocks // 2)
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, num_mappings)
    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (src_cache, dst_cache, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_cache_clone[src].cpu(),
            dst_cache[dst].cpu(),
            msg=f"Block {src} from src should have been swapped to block "
            f"{dst} in dst_cache.",
        )
740
741
742
743
744
745
746
747
748


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(
    kv_lora_rank,
    qk_rope_head_dim,
    block_size,
    num_blocks,
    max_seq_len,
    batch_size,
    dtype,
    kv_cache_dtype,
    device,
):
    """Compare ``ops.gather_and_maybe_dequant_cache`` with a PyTorch gather.

    For every sequence the reference concatenates the cache blocks listed
    in its block-table row (the last block truncated to the remaining
    tokens), dequantizing each block with ``ops.convert_fp8`` when the
    cache dtype is ``"fp8"``. The concatenation over all sequences must
    match the kernel's ``dst`` output.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    # The half-open range [max_seq_len, max_seq_len + 1) makes every
    # sequence exactly max_seq_len tokens long.
    seq_len_tensor = torch.randint(
        max_seq_len, max_seq_len + 1, (batch_size,), device=device
    )

    total_tokens = seq_len_tensor.sum()
    # Cumulative sequence lengths with a leading zero.
    cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
    # Sequence index of every token: index b repeated seq_len[b] times.
    token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
    token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)
    print("seq_len_tensor", seq_len_tensor)

    # Number of blocks per sequence (ceiling division).
    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
    block_table = torch.empty(
        (batch_size, num_blocks), dtype=torch.int32, device=device
    )

    # Each row is an independent random permutation of all block ids.
    for b in range(batch_size):
        perm = torch.randperm(num_blocks, device=device)
        block_table[b, :] = perm

    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

    expected_batches = []
    for b in range(batch_size):
        s = seq_len_tensor[b]
        if s == 0:
            continue
        tot = tot_blocks_tensor[b]
        blocks = block_table[b, :tot].tolist()

        gathered_rows = []
        # All blocks but the last are taken in full.
        for i in range(tot - 1):
            block_data = src_cache[blocks[i]]
            if kv_cache_dtype == "fp8":
                dequantized_block = torch.empty_like(block_data, dtype=dtype)
                ops.convert_fp8(dequantized_block, block_data, scale.item())
                gathered_rows.append(dequantized_block)
            else:
                gathered_rows.append(block_data)
        # The last block contributes only the tokens that remain.
        remaining = s - (tot - 1) * block_size
        last_block_data = src_cache[blocks[-1], :remaining, :]
        if kv_cache_dtype == "fp8":
            dequantized_last_block = torch.empty_like(last_block_data, dtype=dtype)
            ops.convert_fp8(dequantized_last_block, last_block_data, scale.item())
            gathered_rows.append(dequantized_last_block)
        else:
            gathered_rows.append(last_block_data)

        batch_expected = torch.cat(gathered_rows, dim=0)
        expected_batches.append(batch_expected)
    expected = torch.cat(expected_batches, dim=0)

    opcheck(
        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
        (
            src_cache,
            dst,
            block_table,
            cu_seq_lens,
            token_to_seq,
            total_tokens,
            kv_cache_dtype,
            scale,
            None,
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.gather_and_maybe_dequant_cache(
        src_cache,
        dst,
        block_table,
        cu_seq_lens,
        token_to_seq,
        total_tokens,
        kv_cache_dtype,
        scale,
        None,
    )
    torch.testing.assert_close(dst, expected)
Thien Tran's avatar
Thien Tran committed
851
852


853
854
855
856
857
858
859
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "kv_cache_dtype", ["auto"]
)  # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_cp_gather_cache_mla(
    kv_lora_rank,
    qk_rope_head_dim,
    block_size,
    num_blocks,
    max_seq_len,
    batch_size,
    dtype,
    kv_cache_dtype,
    device,
):
    """Compare ``ops.cp_gather_cache`` with a pure-PyTorch reference gather.

    Each sequence gets a random length in ``[0, max_seq_len]`` and a random
    permutation of all cache blocks as its block-table row. The reference
    output is the first ``seq_len`` cache rows read through that row, with the
    sequences concatenated in batch order.
    """
    # Seed every RNG so the drawn lengths and permutations are reproducible
    # when a failure has to be investigated.
    current_platform.seed_everything(SEEDS[0])

    entry_size = kv_lora_rank + qk_rope_head_dim
    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
    # Read the lengths back once as Python ints; the reference loop below then
    # needs no per-element device-to-host transfers.
    seq_lens = seq_len_tensor.tolist()
    total_tokens = sum(seq_lens)

    # Cumulative sequence lengths with a leading zero.
    cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    block_table = torch.empty(
        (batch_size, num_blocks), dtype=torch.int32, device=device
    )
    for b in range(batch_size):
        block_table[b, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size), dtype=src_cache.dtype, device=device)

    expected_batches = []
    for b, seq_len in enumerate(seq_lens):
        if seq_len == 0:
            continue
        # Number of blocks needed to hold seq_len tokens (ceiling division).
        num_seq_blocks = (seq_len + block_size - 1) // block_size
        blocks = block_table[b, :num_seq_blocks].tolist()
        # Flatten the selected blocks into token rows and drop the unused tail
        # of the last (possibly partial) block.
        rows = src_cache[blocks].reshape(-1, entry_size)[:seq_len]
        expected_batches.append(rows)

    if expected_batches:
        expected = torch.cat(expected_batches, dim=0)
    else:
        # Every sequence is empty: torch.cat rejects an empty list, and the
        # expected output is simply the (empty) destination buffer.
        expected = torch.empty_like(dst)

    opcheck(
        torch.ops._C_cache_ops.cp_gather_cache,
        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
    torch.testing.assert_close(dst, expected)


Thien Tran's avatar
Thien Tran committed
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Compare ``ops.concat_and_cache_mla`` on CPU with a PyTorch reference.

    Each token's ``kv_c`` and ``k_pe`` vectors are expected to land side by
    side in the cache entry selected by its slot: ``slot // block_size`` picks
    the block and ``slot % block_size`` the row inside it.
    """
    device = "cpu"
    # Only the unquantized cache layout is exercised here, so the reference
    # below needs no fp8 conversion step.
    kv_cache_dtype = "auto"
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    total_slots = num_blocks * block_size
    # random.sample draws without replacement, so every token gets its own slot.
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
    entry_size = kv_lora_rank + qk_rope_head_dim

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )

    # Reference cache: write all tokens in one indexed assignment. The slots
    # are distinct, so no two tokens target the same entry.
    ref_kv_cache = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
    block_idx = slot_mapping // block_size
    block_offset = slot_mapping % block_size
    ref_kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c
    ref_kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    torch.testing.assert_close(kv_cache, ref_kv_cache)