test_cache.py 35 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
import random

6
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch

9
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
from vllm import _custom_ops as ops
11
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
12
from vllm.platforms import current_platform
13
from vllm.utils.torch_utils import set_random_seed
Woosuk Kwon's avatar
Woosuk Kwon committed
14

15
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
16
DTYPES = [torch.bfloat16, torch.float]
Simon Mo's avatar
Simon Mo committed
17
NUM_TOKENS = [42]  # Arbitrary values for testing
18
NUM_LAYERS = [1]  # Arbitrary values for testing
19
NUM_HEADS = [8]  # Arbitrary values for testing
20
HEAD_SIZES = [64, 80, 256]
21
BLOCK_SIZES = [8, 16, 32]
22
CACHE_LAYOUTS = ["NHD", "HND"]
23
KV_SCALE_TYPES = ["tensor", "attn_head"]
24

25
# Parameters for MLA tests.
26
KV_LORA_RANKS = [256, 512]
27
28
29
30
31
QK_ROPE_HEAD_DIMS = [64]
NUM_TOKENS_MLA = [42]
BLOCK_SIZES_MLA = [16]
NUM_BLOCKS_MLA = [8]

32
33
34
35
# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

36
NUM_MAPPINGS = [256]  # Arbitrary values for testing
37
SEEDS = [0]
38
39
40
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
41
42

# We assume fp8 is always enabled for testing.
43
KV_CACHE_DTYPE = ["auto", "fp8"]
44

45
46
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]

47
48
49
50
51
52
53
54

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache`` with a per-token Python reference.

    Random keys/values are written into randomly chosen, distinct cache slots
    by the kernel; the same writes are replayed on a copy of the caches taken
    before the call, and the two results are compared. For ``fp8`` caches both
    sides are first converted to float16 with ``ops.convert_fp8`` and compared
    with a loose tolerance.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip("fp8 KV cache is only tested with head sizes divisible by 16.")
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.accelerator.set_device_index(device)

    # Create a random slot mapping (distinct slots, one per token).
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    # Clone the KV caches (as float16 for fp8) before the kernel mutates them.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item())
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item())
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    opcheck(
        torch.ops._C_cache_ops.reshape_and_cache,
        (
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        ),
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache, k_scale.item())
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache, v_scale.item())

    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    # The slot list already lives on the host, so the block index and offset
    # are derived from it directly instead of round-tripping through tensors.
    for i, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)
Vladimir's avatar
Vladimir committed
165
166


167
168
169
170
171
172
173
174
175
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("kv_scale_type", KV_SCALE_TYPES)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
    kv_cache_layout: str,
    kv_scale_type: str,
    implementation: str,
) -> None:
    """Compare the flash-layout reshape-and-cache kernels with a reference.

    Runs either ``ops.reshape_and_cache_flash`` (``"cuda"``) or
    ``triton_reshape_and_cache_flash`` (``"triton"``), then replays the same
    per-token writes on a contiguous copy of the caches taken before the call.
    For ``fp8`` caches both sides are dequantized to float16 with the
    per-tensor or per-head scale and compared with a loose tolerance.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.accelerator.set_device_index(device)
    assert implementation in ["cuda", "triton"]
    if implementation == "triton" and kv_cache_layout == "HND":
        pytest.skip("Triton implementation only supports NHD layout.")

    if kv_scale_type == "attn_head" and implementation != "cuda":
        pytest.skip("Only CUDA implementation supports attn_head scaling.")

    # fp8 conversion requires contiguous memory buffer. Reduce the number of
    # blocks and tokens to consume less memory.
    num_tokens = num_tokens // 2
    num_blocks = num_blocks // 2

    # Create a random slot mapping (distinct slots, one per token).
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
        cache_layout=kv_cache_layout,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    del key_caches
    del value_caches

    if kv_scale_type == "tensor":
        k_scale = (key.amax() / 64.0).to(torch.float32)
        v_scale = (value.amax() / 64.0).to(torch.float32)
    else:  # "attn_head"
        k_scale = (key.amax(dim=(0, 2)) / 64.0).to(torch.float32)
        v_scale = (value.amax(dim=(0, 2)) / 64.0).to(torch.float32)

    def permute_and_compact(x):
        # For HND, swap dims 1 and 2 before making the copy contiguous.
        y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
        return y.contiguous()

    key_cache_compact = permute_and_compact(key_cache)
    value_cache_compact = permute_and_compact(value_cache)

    def convert_fp8_local(output, cache, scale):
        # Dequantize `cache` (reinterpreted as the platform fp8 dtype) into
        # `output`, using either a single scale or one scale per head.
        fp8_cache = cache.view(current_platform.fp8_dtype())
        if scale.numel() == 1:  # per-tensor
            result = scaled_dequantize(
                fp8_cache.flatten(0, 2), scale, group_shape=None, out_dtype=output.dtype
            ).reshape(*cache.shape)
        else:  # per-head: broadcast scale along the head dimension
            # In the compacted cache, heads sit on dim 2 for NHD, dim 1 for HND.
            if kv_cache_layout == "NHD":
                result = fp8_cache.to(output.dtype) * scale.view(1, 1, -1, 1)
            else:
                result = fp8_cache.to(output.dtype) * scale.view(1, -1, 1, 1)
        output.copy_(result)

    # Clone the KV caches (as float16 for fp8) before the kernel mutates them.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
        convert_fp8_local(cloned_key_cache, key_cache_compact, k_scale)
        cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
        convert_fp8_local(cloned_value_cache, value_cache_compact, v_scale)
    else:
        cloned_key_cache = key_cache_compact.clone()
        cloned_value_cache = value_cache_compact.clone()

    # Call the reshape_and_cache kernel.
    if implementation == "cuda":
        opcheck(
            torch.ops._C_cache_ops.reshape_and_cache_flash,
            (
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                kv_cache_dtype,
                k_scale,
                v_scale,
            ),
            cond=(head_size == HEAD_SIZES[0]),
        )
        ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
    elif implementation == "triton":
        from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
            triton_reshape_and_cache_flash,
        )

        triton_reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    # Re-compact after the kernel call so the result reflects its writes.
    key_cache_compact = permute_and_compact(key_cache)
    value_cache_compact = permute_and_compact(value_cache)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
        convert_fp8_local(result_key_cache, key_cache_compact, k_scale)
        result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
        convert_fp8_local(result_value_cache, value_cache_compact, v_scale)

    # Run the reference implementation.
    # The slot list already lives on the host, so the block index and offset
    # are derived from it directly instead of round-tripping through tensors.
    for i, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        if kv_cache_layout == "NHD":
            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
            cloned_value_cache[block_idx, block_offset, :, :] = value[i]
        else:
            cloned_key_cache[block_idx, :, block_offset, :] = key[i]
            cloned_value_cache[block_idx, :, block_offset, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache_compact, cloned_key_cache)
        torch.testing.assert_close(value_cache_compact, cloned_value_cache)
351
352


Vladimir's avatar
Vladimir committed
353
354
355
356
357
358
359
360
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``ops.swap_blocks`` copies the mapped blocks unchanged.

    ``direction`` selects the (source, destination) device kinds. After the
    call, every destination block named in the mapping must equal the
    corresponding source block as it was before the call.
    """
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip("fp8 KV cache is only tested for cuda-to-cuda swaps.")
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip("fp8 KV cache is only tested with head sizes divisible by 16.")

    set_random_seed(seed)

    src_device = device if direction[0] == "cuda" else "cpu"
    dst_device = device if direction[1] == "cuda" else "cpu"

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        src_device,
    )

    # Create the KV caches on the second device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        dst_device,
    )

    # Snapshot the source caches so the check does not depend on whether the
    # kernel leaves them untouched.
    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    do_opcheck = head_size == HEAD_SIZES[0]
    src_cache = src_key_caches[0]
    # Bytes per block, derived from the key cache; the same size is passed
    # for the value cache below.
    block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (
            src_key_caches[0],
            dst_key_caches[0],
            block_size_in_bytes,
            block_mapping_tensor,
        ),
        cond=do_opcheck,
    )
    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (
            src_value_caches[0],
            dst_value_caches[0],
            block_size_in_bytes,
            block_mapping_tensor,
        ),
        cond=do_opcheck,
    )

    ops.swap_blocks(
        src_key_caches[0],
        dst_key_caches[0],
        block_size_in_bytes,
        block_mapping_tensor,
    )
    ops.swap_blocks(
        src_value_caches[0],
        dst_value_caches[0],
        block_size_in_bytes,
        block_mapping_tensor,
    )

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_key_caches_clone[src].cpu(), dst_key_caches[0][dst].cpu()
        )
        torch.testing.assert_close(
            src_value_caches_clone[src].cpu(), dst_value_caches[0][dst].cpu()
        )
474
475
476
477
478
479
480
481
482
483


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random cache through ``ops.convert_fp8`` and back.

    Values drawn uniformly from [-224, 224) are converted into a uint8 buffer
    and then back into the original dtype; the result must match the input
    within a loose tolerance.
    """
    set_random_seed(seed)

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    # Forward conversion: `dtype` -> uint8-backed buffer.
    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache_fp8, cache)

    # Backward conversion: uint8-backed buffer -> `dtype`.
    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(converted_cache, cache_fp8)

    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
508
509
510
511
512
513
514
515
516
517
518


def _create_mla_cache(
    num_blocks: int,
    block_size: int,
    entry_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> torch.Tensor:
    """Allocate a zero-filled cache of shape (num_blocks, block_size, entry_size).

    The element type is ``torch.uint8`` when ``kv_cache_dtype`` is ``"fp8"``
    and ``dtype`` for any other value of ``kv_cache_dtype``.
    """
    cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
    return torch.zeros(
        num_blocks, block_size, entry_size, dtype=cache_dtype, device=device
    )
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555


def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
    """Overwrite ``cache`` in place with standard-normal random values.

    For ``"fp8"`` the samples are drawn as float16 and passed through
    ``ops.convert_fp8`` (scale 1.0) into a buffer shaped like ``cache`` before
    being copied in; otherwise they are drawn directly in ``cache.dtype``.
    """
    is_fp8 = kv_cache_dtype == "fp8"
    sample_dtype = torch.float16 if is_fp8 else cache.dtype
    samples = torch.randn(*cache.shape, device=cache.device, dtype=sample_dtype)

    if not is_fp8:
        cache.copy_(samples)
        return

    quantized = torch.zeros_like(cache)
    ops.convert_fp8(quantized, samples, 1.0, kv_dtype=kv_cache_dtype)
    cache.copy_(quantized)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_concat_and_cache_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.concat_and_cache_mla`` with a per-token reference.

    The reference stores ``kv_c[i]`` followed by ``k_pe[i]`` in the cache entry
    selected by each token's slot. For ``fp8`` the reference is quantized with
    ``ops.convert_fp8`` and both sides are converted back to float16 before
    being compared with a loose tolerance.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.accelerator.set_device_index(device)

    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
    entry_size = kv_lora_rank + qk_rope_head_dim

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)

    # Build the reference from the host-side slot list; this avoids a device
    # sync per token (`slot_mapping[i].item()`) for the same slot values.
    for i, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
        ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]

    if kv_cache_dtype == "fp8":
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = ref_temp

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)

    if kv_cache_dtype == "fp8":
        # Compare in float16 rather than on the raw uint8 storage.
        result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_temp, kv_cache.contiguous(), scale.item(), kv_dtype=kv_cache_dtype
        )
        expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
        ops.convert_fp8(
            expected_temp, ref_kv_cache, scale.item(), kv_dtype=kv_cache_dtype
        )
        torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1)
    else:
        torch.testing.assert_close(kv_cache, ref_kv_cache)


609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_concat_and_cache_ds_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare ``ops.concat_and_cache_mla`` in ``fp8_ds_mla`` mode with a reference.

    Each uint8 cache entry built by the reference holds, in order:
    ``kv_lora_rank`` bytes written by ``ops.convert_fp8`` in 128-element tiles,
    one float32 scale per tile, and the ``k_pe`` values stored in ``dtype``.
    The three regions are compared separately for every written slot.
    """
    if current_platform.is_rocm():
        pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
    if dtype.itemsize != 2:
        pytest.skip("ds_mla only supports 16-bit input")
    if kv_lora_rank != 512:
        pytest.skip("fp8_ds_mla requires kv_lora_rank == 512")
    kv_cache_dtype = "fp8_ds_mla"
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.accelerator.set_device_index(device)

    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
    # Entry size in bytes: kv_lora_rank one-byte values, four 4-byte scales,
    # and qk_rope_head_dim two-byte values.
    entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)

    scale = torch.tensor(1.0, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(
        num_blocks,
        block_size,
        entry_size,
        dtype=torch.uint8,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
    )

    ref_cache = torch.zeros_like(kv_cache)
    tile_size = 128
    num_tiles = kv_lora_rank // tile_size
    tile_data = torch.zeros(tile_size, dtype=dtype, device=device)
    # Offsets of the scale and rope regions in float32 / 16-bit element units.
    scale_start = kv_lora_rank // 4
    rope_start = kv_lora_rank // 2 + 8

    # The slot list already lives on the host; using it avoids a device sync
    # per token for the same slot values.
    for i, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)

        # Three typed views over the same uint8 entry.
        ref_cache_slice = ref_cache[block_idx, block_offset]
        ref_cache_16bit = ref_cache_slice.view(dtype)
        ref_cache_32bit = ref_cache_slice.view(torch.float32)

        kv_c_data = kv_c[i]
        for tile_idx in range(num_tiles):
            tile_start = tile_idx * tile_size
            tile_end = (tile_idx + 1) * tile_size
            tile_data[:] = kv_c_data[tile_start:tile_end]

            # tile_scale = tile_data.amax().to(torch.float32) / 448.
            # NOTE: Using torch's amax() gives different results,
            # so this must be manually computed.
            tile_data_float = tile_data.to(torch.float32)
            manual_max = abs(tile_data_float[0])
            for j in range(1, tile_size):
                manual_max = max(manual_max, abs(tile_data_float[j]))
            tile_scale = manual_max / 448.0

            ref_cache_32bit[scale_start + tile_idx] = tile_scale

            ops.convert_fp8(
                ref_cache_slice[tile_start:tile_end],
                tile_data,
                tile_scale.item(),
                kv_dtype="fp8",
            )

        # Copy the whole rope vector at once; equivalent to assigning
        # k_pe[i, j] element by element for j in range(qk_rope_head_dim).
        ref_cache_16bit[rope_start : rope_start + qk_rope_head_dim] = k_pe[i]

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)

    for slot in slot_mapping_lst:
        block_idx, block_offset = divmod(slot, block_size)
        kv_cache_slice = kv_cache[block_idx, block_offset]
        ref_cache_slice = ref_cache[block_idx, block_offset]

        kv_nope = kv_cache_slice[:kv_lora_rank]
        ref_nope = ref_cache_slice[:kv_lora_rank]
        kv_scales = kv_cache_slice.view(torch.float32)[scale_start : scale_start + 4]
        ref_scales = ref_cache_slice.view(torch.float32)[
            scale_start : scale_start + 4
        ]
        kv_rope = kv_cache_slice.view(dtype)[rope_start:]
        ref_rope = ref_cache_slice.view(dtype)[rope_start:]

        torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
        torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
        torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)


728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``ops.swap_blocks`` copies mapped blocks between MLA caches.

    Both caches live on ``device`` and are filled with random data. After the
    call, each destination block named in the mapping must equal the matching
    source block as it was before the call.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.accelerator.set_device_index(device)

    entry_size = kv_lora_rank + qk_rope_head_dim

    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    dst_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )

    _fill_mla_cache(src_cache, kv_cache_dtype)
    _fill_mla_cache(dst_cache, kv_cache_dtype)

    # Snapshot the source so the check does not depend on whether the kernel
    # leaves it untouched.
    src_cache_clone = src_cache.clone()

    # Pick source and destination block indices from disjoint sets.
    num_mappings = min(2, num_blocks // 2)
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, num_mappings)
    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.swap_blocks(src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_cache_clone[src].cpu(),
            dst_cache[dst].cpu(),
            msg=f"Block {src} from src should have been swapped to block "
            f"{dst} in dst_cache.",
        )
790
791
792
793
794
795
796
797
798


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(
    kv_lora_rank,
    qk_rope_head_dim,
    block_size,
    num_blocks,
    max_seq_len,
    batch_size,
    dtype,
    kv_cache_dtype,
    device,
):
    """Check ``gather_and_maybe_dequant_cache`` against a PyTorch reference.

    A paged cache with ``kv_lora_rank + qk_rope_head_dim`` values per slot is
    created and filled, every sequence gets a random permutation of block ids
    as its block table, and the op output is compared with rows sliced
    directly out of the cache (converted with ``ops.convert_fp8`` when
    ``kv_cache_dtype == "fp8"``).
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    # low == high - 1, so every sequence is exactly ``max_seq_len`` tokens long.
    seq_len_tensor = torch.randint(
        max_seq_len, max_seq_len + 1, (batch_size,), device=device
    )
    # Read the lengths back to the host once: the reference loop below and the
    # integer ``total_tokens`` argument then need no further per-element
    # device reads.
    seq_lens = seq_len_tensor.tolist()
    total_tokens = sum(seq_lens)

    # Cumulative sequence lengths with a leading zero.
    cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    # One entry per token holding the index of the sequence it belongs to.
    token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
    token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)

    # Each row is an independent random permutation of all block ids.
    block_table = torch.empty(
        (batch_size, num_blocks), dtype=torch.int32, device=device
    )
    for b in range(batch_size):
        block_table[b, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

    # ``convert_fp8`` takes the scale as a Python float; fetch it once.
    scale_value = scale.item()

    def _reference_rows(rows: torch.Tensor) -> torch.Tensor:
        """Return ``rows`` as the op is expected to emit them."""
        if kv_cache_dtype != "fp8":
            return rows
        dequantized = torch.empty_like(rows, dtype=dtype)
        ops.convert_fp8(dequantized, rows, scale_value)
        return dequantized

    expected_chunks = []
    for b, seq_len in enumerate(seq_lens):
        if seq_len == 0:
            continue
        used_blocks = (seq_len + block_size - 1) // block_size
        blocks = block_table[b, :used_blocks].tolist()
        # Every block but the last contributes all of its slots.
        for block_id in blocks[:-1]:
            expected_chunks.append(_reference_rows(src_cache[block_id]))
        # The last block contributes only the slots that are still in use.
        remaining = seq_len - (used_blocks - 1) * block_size
        expected_chunks.append(
            _reference_rows(src_cache[blocks[-1], :remaining, :])
        )
    expected = torch.cat(expected_chunks, dim=0)

    opcheck(
        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
        (
            src_cache,
            dst,
            block_table,
            cu_seq_lens,
            token_to_seq,
            total_tokens,
            kv_cache_dtype,
            scale,
            None,
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.gather_and_maybe_dequant_cache(
        src_cache,
        dst,
        block_table,
        cu_seq_lens,
        token_to_seq,
        total_tokens,
        kv_cache_dtype,
        scale,
        None,
    )
    torch.testing.assert_close(dst, expected)
Thien Tran's avatar
Thien Tran committed
901
902


903
904
905
906
907
908
909
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "kv_cache_dtype", ["auto"]
)  # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_cp_gather_cache_mla(
    kv_lora_rank,
    qk_rope_head_dim,
    block_size,
    num_blocks,
    max_seq_len,
    batch_size,
    dtype,
    kv_cache_dtype,
    device,
):
    """Check ``cp_gather_cache`` against rows sliced directly from the cache.

    Sequence lengths are drawn uniformly from ``[0, max_seq_len]`` and every
    sequence gets a random permutation of block ids as its block table. The
    expected output is the concatenation, sequence by sequence, of the cache
    rows that the block table points at.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
    # Host copy of the lengths, used for the output size and the reference loop.
    seq_lens = seq_len_tensor.tolist()

    # Cumulative sequence lengths with a leading zero.
    cu_seq_lens = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
    # Lengths are random, so show them in the captured output of a failing run.
    print("seq_len_tensor", seq_len_tensor)

    # Each row is an independent random permutation of all block ids.
    block_table = torch.stack(
        [torch.randperm(num_blocks, device=device) for _ in range(batch_size)]
    ).to(torch.int32)

    dst = torch.zeros(
        (sum(seq_lens), entry_size), dtype=src_cache.dtype, device=device
    )

    expected_chunks = []
    for seq_idx, seq_len in enumerate(seq_lens):
        if seq_len == 0:
            continue
        used_blocks = (seq_len + block_size - 1) // block_size
        block_ids = block_table[seq_idx, :used_blocks].tolist()
        # Every block but the last contributes all of its slots ...
        expected_chunks.extend(src_cache[block_id] for block_id in block_ids[:-1])
        # ... and the last one only the slots that are still in use.
        tail = seq_len - (used_blocks - 1) * block_size
        expected_chunks.append(src_cache[block_ids[-1], :tail, :])
    expected = torch.cat(expected_chunks, dim=0)

    opcheck(
        torch.ops._C_cache_ops.cp_gather_cache,
        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
    torch.testing.assert_close(dst, expected)


Thien Tran's avatar
Thien Tran committed
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check ``concat_and_cache_mla`` on CPU against a hand-built reference.

    ``num_tokens`` distinct slots are sampled from the cache; for every token
    the reference writes ``kv_c`` into the first ``kv_lora_rank`` entries of
    its slot and ``k_pe`` into the rest, and the result is compared with what
    the op wrote into ``kv_cache``.
    """
    device = "cpu"
    kv_cache_dtype = "auto"
    set_random_seed(seed)
    torch.set_default_device(device)

    # Distinct flat slot ids; slot -> (slot // block_size, slot % block_size).
    slot_mapping_lst = random.sample(range(num_blocks * block_size), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(
        num_blocks,
        block_size,
        kv_lora_rank + qk_rope_head_dim,
        dtype,
        kv_cache_dtype,
        device,
    )

    # Reference cache in the compute dtype, filled one token at a time.
    ref_temp = torch.zeros(kv_cache.shape, dtype=dtype, device=device)
    for token_idx, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx]
        ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx]

    ref_kv_cache = ref_temp
    # Not taken while kv_cache_dtype is fixed to "auto" above.
    if kv_cache_dtype == "fp8":
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    torch.testing.assert_close(kv_cache, ref_kv_cache)