test_cache.py 34.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
import random

6
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch

9
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
from vllm import _custom_ops as ops
11
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
12
from vllm.platforms import current_platform
13
from vllm.utils.torch_utils import set_random_seed

COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
DTYPES = [torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 256]
BLOCK_SIZES = [8, 16, 32]
CACHE_LAYOUTS = ["NHD", "HND"]
KV_SCALE_TYPES = ["tensor", "attn_head"]

# Parameters for MLA tests.
KV_LORA_RANKS = [512]
QK_ROPE_HEAD_DIMS = [64]
NUM_TOKENS_MLA = [42]
BLOCK_SIZES_MLA = [16]
NUM_BLOCKS_MLA = [8]

# Arbitrary values for testing.
# Don't make it too large, e.g. [1024, 36000] will OOM.
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# One device when exactly one GPU is visible, otherwise "cuda:0" and "cuda:1".
# NOTE(review): this also yields two devices when no GPU is visible at all.
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

# KV-cache dtypes exercised by the tests parametrized over this list.
# Only "auto" is enabled; the previous value was ["auto", "fp8"].
# NOTE(review): fp8 was presumably disabled because the target platform lacks
# fp8 KV-cache support -- confirm before re-enabling. The "fp8" branches of
# tests parametrized over this list are not exercised. Tests that hard-code
# fp8 do not consult this list.
KV_CACHE_DTYPE = ["auto"]

RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``reshape_and_cache`` with a per-token Python reference.

    Keys and values of ``num_tokens`` random tokens are written into distinct,
    randomly chosen cache slots by the kernel. The same writes are then
    replayed on a copy of the caches taken before the kernel ran. The key
    cache is indexed as ``[block, :, :, offset, :]`` and the value cache as
    ``[block, :, :, offset]``. When ``kv_cache_dtype == "fp8"``, both sides
    are converted to float16 with ``ops.convert_fp8`` and compared with a
    loose tolerance. Otherwise the comparison is exact.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    is_fp8 = kv_cache_dtype == "fp8"

    # Distinct destination slots, one per token.
    total_slots = block_size * num_blocks
    slot_mapping = torch.tensor(
        random.sample(range(total_slots), num_tokens), dtype=torch.long
    )

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    key, value = qkv.unbind(dim=1)[1:]

    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache = key_caches[0]
    value_cache = value_caches[0]

    # Per-tensor scales derived from the data itself.
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    def dequantized(cache, scale):
        # Expand an fp8 cache into a fresh float16 tensor for comparison.
        expanded = torch.empty_like(cache, dtype=torch.float16)
        ops.convert_fp8(expanded, cache, scale.item())
        return expanded

    # Snapshot the caches before the kernel touches them.
    if is_fp8:
        expected_key = dequantized(key_cache, k_scale)
        expected_value = dequantized(value_cache, v_scale)
    else:
        expected_key = key_cache.clone()
        expected_value = value_cache.clone()

    kernel_args = (
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    opcheck(
        torch.ops._C_cache_ops.reshape_and_cache,
        kernel_args,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache(*kernel_args)

    if is_fp8:
        actual_key = dequantized(key_cache, k_scale)
        actual_value = dequantized(value_cache, v_scale)
        tolerances = {"atol": 0.001, "rtol": 0.1}
    else:
        actual_key, actual_value = key_cache, value_cache
        tolerances = {}

    # Reference: write every token into its (block, offset) slot of the copy.
    key_by_token = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    blocks = torch.div(slot_mapping, block_size, rounding_mode="floor").cpu().tolist()
    offsets = (slot_mapping % block_size).cpu().tolist()
    for tok, (blk, off) in enumerate(zip(blocks, offsets)):
        expected_key[blk, :, :, off, :] = key_by_token[tok]
        expected_value[blk, :, :, off] = value[tok]

    torch.testing.assert_close(actual_key, expected_key, **tolerances)
    torch.testing.assert_close(actual_value, expected_value, **tolerances)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("kv_scale_type", KV_SCALE_TYPES)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
    kv_cache_layout: str,
    kv_scale_type: str,
    implementation: str,
) -> None:
    """Check ``reshape_and_cache_flash`` (CUDA op or Triton kernel).

    Keys and values for randomly chosen slots are written into caches from
    ``kv_cache_factory_flashinfer``. They are compared with a Python
    reference that performs the same writes on a copy.

    Both sides are compared in a compact, contiguous layout:
    - NHD: the cache as-is, indexed ``[block, offset, head, :]``.
    - HND: the cache permuted to ``[block, head, offset, :]``.

    For fp8, both sides are dequantized to float16 first. The scale is either
    a single per-tensor value or one value per head, per ``kv_scale_type``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    assert implementation in ["cuda", "triton"]
    if implementation == "triton" and kv_cache_layout == "HND":
        pytest.skip("Triton implementation only supports NHD layout.")

    if kv_scale_type == "attn_head" and implementation != "cuda":
        pytest.skip("Only CUDA implementation supports attn_head scaling.")

    # fp8 conversion requires a contiguous memory buffer. Reduce the number of
    # blocks and tokens to consume less memory.
    num_tokens = num_tokens // 2
    num_blocks = num_blocks // 2

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
        cache_layout=kv_cache_layout,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    del key_caches
    del value_caches

    if kv_scale_type == "tensor":
        k_scale = (key.amax() / 64.0).to(torch.float32)
        v_scale = (value.amax() / 64.0).to(torch.float32)
    else:  # "attn_head": one scale per head (reduce over tokens and head_size)
        k_scale = (key.amax(dim=(0, 2)) / 64.0).to(torch.float32)
        v_scale = (value.amax(dim=(0, 2)) / 64.0).to(torch.float32)

    def permute_and_compact(x):
        # HND caches are viewed head-major so both layouts compare contiguously.
        y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
        return y.contiguous()

    key_cache_compact = permute_and_compact(key_cache)
    value_cache_compact = permute_and_compact(value_cache)

    def convert_fp8_local(output, src, scale):
        """Dequantize the fp8 bytes of compact cache *src* into *output*.

        A one-element *scale* is applied per tensor via ``scaled_dequantize``.
        Otherwise *scale* holds one value per head. It is broadcast along the
        head dimension of the compact layout: dim 2 for NHD, dim 1 for HND.
        """
        fp8_src = src.view(current_platform.fp8_dtype())
        if scale.numel() == 1:  # per-tensor
            result = scaled_dequantize(
                fp8_src.flatten(0, 2), scale, group_shape=None, out_dtype=output.dtype
            ).reshape(*src.shape)
        else:  # per-head
            if kv_cache_layout == "NHD":
                head_shape = (1, 1, -1, 1)
            else:
                head_shape = (1, -1, 1, 1)
            result = fp8_src.to(output.dtype) * scale.view(*head_shape)
        output.copy_(result)

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
        convert_fp8_local(cloned_key_cache, key_cache_compact, k_scale)
        cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
        convert_fp8_local(cloned_value_cache, value_cache_compact, v_scale)
    else:
        cloned_key_cache = key_cache_compact.clone()
        cloned_value_cache = value_cache_compact.clone()

    # Call the reshape_and_cache kernel.
    if implementation == "cuda":
        opcheck(
            torch.ops._C_cache_ops.reshape_and_cache_flash,
            (
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                kv_cache_dtype,
                k_scale,
                v_scale,
            ),
            cond=(head_size == HEAD_SIZES[0]),
        )
        ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
    elif implementation == "triton":
        from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
            triton_reshape_and_cache_flash,
        )

        triton_reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    # Re-read the caches after the kernel wrote to them.
    key_cache_compact = permute_and_compact(key_cache)
    value_cache_compact = permute_and_compact(value_cache)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
        convert_fp8_local(result_key_cache, key_cache_compact, k_scale)
        result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
        convert_fp8_local(result_value_cache, value_cache_compact, v_scale)

    # Run the reference implementation.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        if kv_cache_layout == "NHD":
            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
            cloned_value_cache[block_idx, block_offset, :, :] = value[i]
        else:
            cloned_key_cache[block_idx, :, block_offset, :] = key[i]
            cloned_value_cache[block_idx, :, block_offset, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache_compact, cloned_key_cache)
        torch.testing.assert_close(value_cache_compact, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``swap_blocks`` copies whole cache blocks src -> dst.

    ``direction`` names the source and destination sides ("cuda" maps to
    ``device``, anything else to "cpu"). Random (src, dst) block pairs are
    swapped for both the key and the value cache. Each destination block must
    then equal the corresponding source block as it was before the swap.
    When source and destination share a device, the two block sets are
    disjoint.
    """
    if kv_cache_dtype == "fp8" and ("cpu" in direction or head_size % 16):
        pytest.skip()

    set_random_seed(seed)

    src_device, dst_device = (
        device if side == "cuda" else "cpu" for side in direction
    )

    src_blocks = random.sample(range(num_blocks), num_mappings)
    if src_device == dst_device:
        # Within one device, source and destination blocks must not overlap.
        dst_pool = list(set(range(num_blocks)) - set(src_blocks))
    else:
        dst_pool = range(num_blocks)
    dst_blocks = random.sample(dst_pool, num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    def make_caches(target_device):
        return kv_cache_factory(
            num_blocks,
            block_size,
            1,
            num_heads,
            head_size,
            kv_cache_dtype,
            dtype,
            seed,
            target_device,
        )

    src_key_caches, src_value_caches = make_caches(src_device)
    dst_key_caches, dst_value_caches = make_caches(dst_device)
    src_key, src_value = src_key_caches[0], src_value_caches[0]
    dst_key, dst_value = dst_key_caches[0], dst_value_caches[0]

    src_key_before = src_key.clone()
    src_value_before = src_value.clone()

    # One block spans stride(0) elements of the key cache.
    block_size_in_bytes = src_key.element_size() * src_key.stride(0)
    do_opcheck = head_size == HEAD_SIZES[0]
    cache_pairs = ((src_key, dst_key), (src_value, dst_value))

    for from_cache, to_cache in cache_pairs:
        opcheck(
            torch.ops._C_cache_ops.swap_blocks,
            (from_cache, to_cache, block_size_in_bytes, block_mapping_tensor),
            cond=do_opcheck,
        )
    for from_cache, to_cache in cache_pairs:
        ops.swap_blocks(
            from_cache, to_cache, block_size_in_bytes, block_mapping_tensor
        )

    for src, dst in block_mapping:
        torch.testing.assert_close(src_key_before[src].cpu(), dst_key[dst].cpu())
        torch.testing.assert_close(
            src_value_before[src].cpu(), dst_value[dst].cpu()
        )


@pytest.mark.skipif(current_platform.is_rocm(), reason="FP8 is not supported on ROCm.")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a uniform random cache through ``ops.convert_fp8``.

    Values in [-224, 224] are converted to a uint8 (fp8) buffer and back
    with the default scale. The result must match the input within
    ``atol=0.001, rtol=0.1``.
    """
    set_random_seed(seed)

    original = torch.empty(
        (num_blocks, num_heads, head_size, block_size), dtype=dtype, device=device
    )
    original.uniform_(-224.0, 224.0)

    as_fp8 = torch.empty_like(original, dtype=torch.uint8)
    ops.convert_fp8(as_fp8, original)

    round_tripped = torch.empty_like(original)
    ops.convert_fp8(round_tripped, as_fp8)

    torch.testing.assert_close(original, round_tripped, atol=0.001, rtol=0.1)


def _create_mla_cache(
    num_blocks: int,
    block_size: int,
    entry_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> torch.Tensor:
    cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
519
520
521
    return torch.zeros(
        num_blocks, block_size, entry_size, dtype=cache_dtype, device=device
    )
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555


def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
    rand_dtype = torch.float16 if kv_cache_dtype == "fp8" else cache.dtype

    vals = torch.randn(*cache.shape, device=cache.device, dtype=rand_dtype)
    if kv_cache_dtype == "fp8":
        temp = torch.zeros_like(cache)
        ops.convert_fp8(temp, vals, 1.0, kv_dtype=kv_cache_dtype)
        vals = temp
    cache.copy_(vals)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_concat_and_cache_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``concat_and_cache_mla`` against a Python reference.

    Each token's cache entry must be ``[kv_c | k_pe]``, written at its slot.
    With "fp8", the reference is quantized with the same scale (0.1). Both
    caches are then converted back to float16 before a tolerant comparison.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    slots = random.sample(range(num_blocks * block_size), num_tokens)
    slot_mapping = torch.tensor(slots, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)

    entry_size = kv_lora_rank + qk_rope_head_dim
    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    is_fp8 = kv_cache_dtype == "fp8"

    kv_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )

    # Reference: concatenate [kv_c | k_pe] into each token's slot.
    expected = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
    for tok, slot in enumerate(slots):
        blk, off = divmod(slot, block_size)
        expected[blk, off, :kv_lora_rank] = kv_c[tok]
        expected[blk, off, kv_lora_rank:] = k_pe[tok]

    if is_fp8:
        ref_kv_cache = torch.empty_like(expected, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache, expected, scale.item(), kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = expected

    op_args = (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        op_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(*op_args)

    if not is_fp8:
        torch.testing.assert_close(kv_cache, ref_kv_cache)
        return

    actual_fp16 = torch.empty_like(kv_cache, dtype=torch.float16)
    ops.convert_fp8(
        actual_fp16, kv_cache.contiguous(), scale.item(), kv_dtype=kv_cache_dtype
    )
    expected_fp16 = torch.empty_like(ref_kv_cache, dtype=torch.float16)
    ops.convert_fp8(
        expected_fp16, ref_kv_cache, scale.item(), kv_dtype=kv_cache_dtype
    )
    torch.testing.assert_close(actual_fp16, expected_fp16, atol=0.001, rtol=0.1)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_concat_and_cache_ds_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check ``concat_and_cache_mla`` with the "fp8_ds_mla" byte layout.

    The reference builds each cache entry byte-by-byte, in three sections:

    1. ``kv_lora_rank`` bytes: fp8 values of ``kv_c``. These are quantized in
       tiles of 128, each with scale ``max|tile| / 448``.
    2. Four float32 tile scales (16 bytes).
    3. ``k_pe`` stored as ``qk_rope_head_dim`` 16-bit values.

    NOTE(review): the reference hard-codes 4 tiles of 128 values, i.e. it
    assumes ``kv_lora_rank == 512``.
    """
    if current_platform.is_rocm():
        pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")

    if dtype.itemsize != 2:
        pytest.skip("ds_mla only supports 16-bit input")
    kv_cache_dtype = "fp8_ds_mla"
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    slots = random.sample(range(num_blocks * block_size), num_tokens)
    slot_mapping = torch.tensor(slots, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)

    entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
    scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    kv_cache = _create_mla_cache(
        num_blocks,
        block_size,
        entry_size,
        dtype=torch.uint8,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
    )

    # Element offsets of the scale section (float32 view) and the rope
    # section (16-bit view) within one entry.
    scale_base = kv_lora_rank // 4
    rope_base = kv_lora_rank // 2 + 8

    ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
    tile_data = torch.zeros(128, dtype=dtype, device=device)

    for tok, slot in enumerate(slots):
        blk, off = divmod(slot, block_size)
        entry = ref_cache[blk, off]
        entry_16bit = entry.view(dtype)
        entry_32bit = entry.view(torch.float32)

        for tile_idx in range(4):
            lo, hi = tile_idx * 128, (tile_idx + 1) * 128
            tile_data[:] = kv_c[tok][lo:hi]

            # tile_scale = tile_data.amax().to(torch.float32) / 448.
            # NOTE: Using torch's amax() gives different results,
            # so this must be manually computed.
            manual_max = max(abs(v) for v in tile_data.to(torch.float32))
            tile_scale = manual_max / 448.0

            entry_32bit[scale_base + tile_idx] = tile_scale
            ops.convert_fp8(
                entry[lo:hi],
                tile_data,
                tile_scale.item(),
                kv_dtype="fp8",
            )

        entry_16bit[rope_base : rope_base + qk_rope_head_dim] = k_pe[tok]

    op_args = (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        op_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(*op_args)

    for slot in slots:
        blk, off = divmod(slot, block_size)
        actual = kv_cache[blk, off]
        expected = ref_cache[blk, off]

        actual_scales = actual.view(torch.float32)[scale_base : scale_base + 4]
        expected_scales = expected.view(torch.float32)[scale_base : scale_base + 4]

        torch.testing.assert_close(
            actual[:kv_lora_rank], expected[:kv_lora_rank], atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            actual_scales, expected_scales, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            actual.view(dtype)[rope_base:],
            expected.view(dtype)[rope_base:],
            atol=0.001,
            rtol=0.1,
        )


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``swap_blocks`` on two randomly filled MLA caches on one device.

    Up to two source blocks are copied to distinct destination blocks that
    do not overlap them. Each destination must then equal the source block as
    it was before the swap.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim

    def new_cache():
        return _create_mla_cache(
            num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
        )

    src_cache = new_cache()
    dst_cache = new_cache()

    for cache in (src_cache, dst_cache):
        _fill_mla_cache(cache, kv_cache_dtype)

    src_before = src_cache.clone()

    num_mappings = min(2, num_blocks // 2)
    src_blocks = random.sample(range(num_blocks), num_mappings)
    # Destination blocks are drawn from the blocks not used as sources.
    unused_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(unused_blocks, num_mappings)
    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    # One block spans stride(0) elements.
    block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
    swap_args = (src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor)
    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        swap_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.swap_blocks(*swap_args)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_before[src].cpu(),
            dst_cache[dst].cpu(),
            msg=f"Block {src} from src should have been swapped to block "
            f"{dst} in dst_cache.",
        )


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
# Follow the module-wide KV_CACHE_DTYPE setting (fp8 is disabled there)
# instead of hard-coding "fp8"; the fp8 reference path below still works
# if fp8 is re-enabled in KV_CACHE_DTYPE.
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(
    kv_lora_rank,
    qk_rope_head_dim,
    block_size,
    num_blocks,
    max_seq_len,
    batch_size,
    dtype,
    kv_cache_dtype,
    device,
):
    """Check ``gather_and_maybe_dequant_cache`` against a Python reference.

    A paged MLA cache is filled and every sequence gets a random block
    table. The op must gather each sequence's tokens into one contiguous
    ``(total_tokens, entry_size)`` tensor of ``dtype``. When
    ``kv_cache_dtype == "fp8"``, the tokens are first dequantized with
    ``scale``. The result is compared with a reference built block by block.
    The op is also run through ``opcheck``.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    # Host-side copy of the scale, used by the reference dequantization.
    scale_value = scale.item()

    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    # randint's upper bound is exclusive, so every sequence has exactly
    # max_seq_len tokens.
    seq_len_tensor = torch.randint(
        max_seq_len, max_seq_len + 1, (batch_size,), device=device
    )

    total_tokens = seq_len_tensor.sum()
    cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    # Sequence index owning each output token.
    token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
    token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)

    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size

    # Each row is an independent random permutation of all blocks; only the
    # first ceil(seq_len / block_size) entries of a row are read.
    block_table = torch.empty(
        (batch_size, num_blocks), dtype=torch.int32, device=device
    )
    for b in range(batch_size):
        block_table[b, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

    def maybe_dequant(cache_rows):
        # fp8 entries are converted to `dtype` with the same scale that is
        # handed to the op; other cache dtypes are compared as-is.
        if kv_cache_dtype != "fp8":
            return cache_rows
        converted = torch.empty_like(cache_rows, dtype=dtype)
        ops.convert_fp8(converted, cache_rows, scale_value)
        return converted

    expected_batches = []
    for b in range(batch_size):
        s = seq_len_tensor[b]
        if s == 0:
            continue
        tot = tot_blocks_tensor[b]
        blocks = block_table[b, :tot].tolist()

        # Every block but the last is full; the last one holds the remainder.
        gathered_rows = [maybe_dequant(src_cache[blk]) for blk in blocks[:-1]]
        remaining = s - (tot - 1) * block_size
        gathered_rows.append(maybe_dequant(src_cache[blocks[-1], :remaining, :]))

        expected_batches.append(torch.cat(gathered_rows, dim=0))
    expected = torch.cat(expected_batches, dim=0)

    opcheck(
        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
        (
            src_cache,
            dst,
            block_table,
            cu_seq_lens,
            token_to_seq,
            total_tokens,
            kv_cache_dtype,
            scale,
            None,
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.gather_and_maybe_dequant_cache(
        src_cache,
        dst,
        block_table,
        cu_seq_lens,
        token_to_seq,
        total_tokens,
        kv_cache_dtype,
        scale,
        None,
    )
    torch.testing.assert_close(dst, expected)


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "kv_cache_dtype", ["auto"]
)  # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_cp_gather_cache_mla(
    kv_lora_rank,
    qk_rope_head_dim,
    block_size,
    num_blocks,
    max_seq_len,
    batch_size,
    dtype,
    kv_cache_dtype,
    device,
):
    """Check ``cp_gather_cache`` against a Python reference.

    Sequences get random lengths in ``[0, max_seq_len]`` and random block
    tables. The op must copy each sequence's raw cache rows into one
    contiguous ``(total_tokens, entry_size)`` tensor with the cache's own
    dtype. There is no dequantization; the reference compares raw rows.
    The op is also run through ``opcheck``.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim

    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    # Lengths may be 0, so empty sequences are exercised too.
    seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)

    total_tokens = seq_len_tensor.sum()
    cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size

    # Each row is an independent random permutation of all blocks; only the
    # first ceil(seq_len / block_size) entries of a row are read.
    block_table = torch.empty(
        (batch_size, num_blocks), dtype=torch.int32, device=device
    )
    for b in range(batch_size):
        block_table[b, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros(
        (total_tokens, entry_size), dtype=src_cache.dtype, device=device
    )

    expected_batches = []
    for b in range(batch_size):
        s = seq_len_tensor[b]
        if s == 0:
            # A zero-length sequence owns no blocks and contributes no rows.
            continue
        tot = tot_blocks_tensor[b]
        blocks = block_table[b, :tot].tolist()

        # Every block but the last is full; the last one holds the remainder.
        gathered_rows = [src_cache[blk] for blk in blocks[:-1]]
        remaining = s - (tot - 1) * block_size
        gathered_rows.append(src_cache[blocks[-1], :remaining, :])

        expected_batches.append(torch.cat(gathered_rows, dim=0))

    # torch.cat rejects an empty list. That happens if every sequence drew
    # length 0; the expected output is then empty, matching `dst`.
    if expected_batches:
        expected = torch.cat(expected_batches, dim=0)
    else:
        expected = dst.new_empty((0, entry_size))

    opcheck(
        torch.ops._C_cache_ops.cp_gather_cache,
        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
    torch.testing.assert_close(dst, expected)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check the CPU ``concat_and_cache_mla`` op against a Python reference.

    ``num_tokens`` random ``(kv_c, k_pe)`` pairs are written to distinct
    random cache slots. Slot ``s`` lives at ``[s // block_size,
    s % block_size]``. It holds ``kv_c`` in its first ``kv_lora_rank``
    entries, followed by ``k_pe``. The op is also run through ``opcheck``.
    """
    device = "cpu"
    # Only the unquantized cache is exercised; the fp8 reference branch
    # below is kept so the test stays correct if this is ever parametrized.
    kv_cache_dtype = "auto"
    set_random_seed(seed)
    # Note: this changes torch's process-wide default device.
    torch.set_default_device(device)

    total_slots = num_blocks * block_size
    # random.sample draws without replacement, so no slot is written twice.
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
    entry_size = kv_lora_rank + qk_rope_head_dim

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )

    # Reference: scatter every [kv_c | k_pe] entry into its slot with a
    # single advanced-indexing assignment. This is equivalent to a per-token
    # loop because all slots are distinct.
    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
    ref_temp[slot_mapping // block_size, slot_mapping % block_size] = torch.cat(
        (kv_c, k_pe), dim=-1
    )

    if kv_cache_dtype == "fp8":
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = ref_temp

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    torch.testing.assert_close(kv_cache, ref_kv_cache)