"vscode:/vscode.git/clone" did not exist on "00e6402d56fb258e6958381b1f3ceb34217ba830"
test_cache.py 34.8 KB
Newer Older
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for vLLM's KV-cache custom ops.

Covers ``reshape_and_cache`` / ``reshape_and_cache_flash`` (CUDA and Triton),
``swap_blocks``, fp8 round-trip conversion, and the MLA cache kernels
(``concat_and_cache_mla`` including the ``fp8_ds_mla`` format, block swaps,
and gather/dequantize).
"""

import random

import pytest
import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed

# (source, destination) device kinds exercised by test_swap_blocks.
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
DTYPES = [torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 256]
BLOCK_SIZES = [8, 16, 32]
# Physical layouts of the flash KV cache: NHD = [block, offset, head, dim],
# HND = [block, head, offset, dim] (as indexed by test_reshape_and_cache_flash).
CACHE_LAYOUTS = ["NHD", "HND"]
# Granularity of the k/v scales: a single scalar, or one scale per head.
KV_SCALE_TYPES = ["tensor", "attn_head"]

# Parameters for MLA tests.
KV_LORA_RANKS = [512]
QK_ROPE_HEAD_DIMS = [64]
NUM_TOKENS_MLA = [42]
BLOCK_SIZES_MLA = [16]
NUM_BLOCKS_MLA = [8]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# A single device on single-GPU hosts, otherwise cuda:0 and cuda:1.
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]

RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``reshape_and_cache`` against a slot-by-slot Python reference.

    Random key/value tokens are scattered into distinct, randomly chosen cache
    slots. The reference writes the same tokens into a clone of the original
    cache. For fp8 caches both sides are dequantized to float16 with the same
    per-tensor scales before being compared with a loose tolerance.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    # Create a random slot mapping (random.sample => no two tokens share a slot).
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Per-tensor scales derived from the data (amax / 64), not a fixed default.
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    # Clone the KV caches (dequantized to float16 when stored as fp8).
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item())
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item())
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel (opcheck only for the first head size).
    opcheck(
        torch.ops._C_cache_ops.reshape_and_cache,
        (
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        ),
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache, k_scale.item())
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache, v_scale.item())

    # Run the reference implementation. The key cache is indexed as
    # [block, :, :, offset, :], so each key token is reshaped to one such
    # slice; the value cache is indexed as [block, :, :, offset].
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    for i, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("kv_scale_type", KV_SCALE_TYPES)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
    kv_cache_layout: str,
    kv_scale_type: str,
    implementation: str,
) -> None:
    """Check ``reshape_and_cache_flash`` (CUDA or Triton) against a reference.

    Tokens are scattered into random distinct slots of a flash-layout cache
    (NHD or HND). Scales are either per tensor or per attention head. Caches
    are compared in a compact, contiguous layout. fp8 caches are dequantized
    to float16 first.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    assert implementation in ["cuda", "triton"]
    if implementation == "triton" and kv_cache_layout == "HND":
        pytest.skip("Triton implementation only supports NHD layout.")

    if kv_scale_type == "attn_head" and implementation != "cuda":
        pytest.skip("Only CUDA implementation supports attn_head scaling.")

    # fp8 conversion requires a contiguous memory buffer. Reduce the number of
    # blocks and tokens to consume less memory.
    num_tokens = num_tokens // 2
    num_blocks = num_blocks // 2

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
        cache_layout=kv_cache_layout,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    del key_caches
    del value_caches

    # Scales derived from the data: a scalar, or one per attention head
    # (amax over tokens and head_size leaves a [num_heads] vector).
    if kv_scale_type == "tensor":
        k_scale = (key.amax() / 64.0).to(torch.float32)
        v_scale = (value.amax() / 64.0).to(torch.float32)
    else:  # "attn_head"
        k_scale = (key.amax(dim=(0, 2)) / 64.0).to(torch.float32)
        v_scale = (value.amax(dim=(0, 2)) / 64.0).to(torch.float32)

    def permute_and_compact(x):
        """Return a contiguous copy ordered [block, offset, head, dim] for NHD
        or [block, head, offset, dim] for HND (matching the reference below)."""
        y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
        return y.contiguous()

    key_cache_compact = permute_and_compact(key_cache)
    value_cache_compact = permute_and_compact(value_cache)

    def convert_fp8_local(output, packed, scale):
        """Dequantize fp8 bytes in ``packed`` into ``output`` using ``scale``.

        The platform's preferred fp8 dtype is used instead of a hard-coded
        float8_e4m3fn, which would misread caches on platforms whose fp8
        format differs (e.g. e4m3fnuz on some ROCm GPUs).
        """
        fp8_input = packed.view(current_platform.fp8_dtype())
        if scale.numel() == 1:  # per-tensor
            result = scaled_dequantize(
                fp8_input.flatten(0, 2), scale, group_shape=None, out_dtype=output.dtype
            ).reshape(*packed.shape)
        else:  # per-head: broadcast scale along the head dimension
            # The head axis is dim 2 for compact NHD and dim 1 for compact HND.
            if kv_cache_layout == "NHD":
                result = fp8_input.to(output.dtype) * scale.view(1, 1, -1, 1)
            else:
                result = fp8_input.to(output.dtype) * scale.view(1, -1, 1, 1)
        output.copy_(result)

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
        convert_fp8_local(cloned_key_cache, key_cache_compact, k_scale)
        cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
        convert_fp8_local(cloned_value_cache, value_cache_compact, v_scale)
    else:
        cloned_key_cache = key_cache_compact.clone()
        cloned_value_cache = value_cache_compact.clone()

    # Call the reshape_and_cache kernel.
    if implementation == "cuda":
        opcheck(
            torch.ops._C_cache_ops.reshape_and_cache_flash,
            (
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                kv_cache_dtype,
                k_scale,
                v_scale,
            ),
            cond=(head_size == HEAD_SIZES[0]),
        )
        ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
    elif implementation == "triton":
        from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
            triton_reshape_and_cache_flash,
        )

        triton_reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    # Re-compact the caches after the kernel wrote into them.
    key_cache_compact = permute_and_compact(key_cache)
    value_cache_compact = permute_and_compact(value_cache)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
        convert_fp8_local(result_key_cache, key_cache_compact, k_scale)
        result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
        convert_fp8_local(result_value_cache, value_cache_compact, v_scale)

    # Run the reference implementation.
    for i, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        if kv_cache_layout == "NHD":
            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
            cloned_value_cache[block_idx, block_offset, :, :] = value[i]
        else:
            cloned_key_cache[block_idx, :, block_offset, :] = key[i]
            cloned_value_cache[block_idx, :, block_offset, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache_compact, cloned_key_cache)
        torch.testing.assert_close(value_cache_compact, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``swap_blocks`` copies whole cache blocks between devices.

    ``direction`` names the (source, destination) device kinds. A random
    mapping of ``num_mappings`` source blocks to destination blocks is applied
    to both the key and the value caches. Each destination block must then
    equal a pre-swap snapshot of its source block.
    """
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()

    set_random_seed(seed)

    src_device = device if direction[0] == "cuda" else "cpu"
    dst_device = device if direction[1] == "cuda" else "cpu"

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    # [num_mappings, 2] int64 tensor of (src, dst) pairs, kept on the CPU.
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        src_device,
    )

    # Create the KV caches on the second device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        dst_device,
    )

    # Snapshot the sources before swapping.
    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    do_opcheck = head_size == HEAD_SIZES[0]
    src_cache = src_key_caches[0]
    # Bytes per block, taken from the key cache and reused for the value cache
    # (assumes both caches have the same per-block size).
    block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (
            src_key_caches[0],
            dst_key_caches[0],
            block_size_in_bytes,
            block_mapping_tensor,
        ),
        cond=do_opcheck,
    )
    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (
            src_value_caches[0],
            dst_value_caches[0],
            block_size_in_bytes,
            block_mapping_tensor,
        ),
        cond=do_opcheck,
    )

    ops.swap_blocks(
        src_key_caches[0],
        dst_key_caches[0],
        block_size_in_bytes,
        block_mapping_tensor,
    )
    ops.swap_blocks(
        src_value_caches[0],
        dst_value_caches[0],
        block_size_in_bytes,
        block_mapping_tensor,
    )

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_key_caches_clone[src].cpu(), dst_key_caches[0][dst].cpu()
        )
        torch.testing.assert_close(
            src_value_caches_clone[src].cpu(), dst_value_caches[0][dst].cpu()
        )


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random cache through ``ops.convert_fp8`` and back.

    The tensor is encoded into a uint8 (fp8-byte) buffer and decoded back,
    both calls without an explicit scale. The result must match the input
    within rtol=0.1, which absorbs fp8's coarse mantissa.
    """
    set_random_seed(seed)

    # Keep values well inside the finite fp8 e4m3 range.
    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache_fp8, cache)

    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(converted_cache, cache_fp8)

    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)


def _create_mla_cache(
    num_blocks: int,
    block_size: int,
    entry_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> torch.Tensor:
    cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
517
518
519
    return torch.zeros(
        num_blocks, block_size, entry_size, dtype=cache_dtype, device=device
    )
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553


def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
    rand_dtype = torch.float16 if kv_cache_dtype == "fp8" else cache.dtype

    vals = torch.randn(*cache.shape, device=cache.device, dtype=rand_dtype)
    if kv_cache_dtype == "fp8":
        temp = torch.zeros_like(cache)
        ops.convert_fp8(temp, vals, 1.0, kv_dtype=kv_cache_dtype)
        vals = temp
    cache.copy_(vals)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_concat_and_cache_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``concat_and_cache_mla`` writes ``[kv_c | k_pe]`` into each slot.

    The reference concatenates each token's latent (``kv_c``) and rope part
    (``k_pe``) into the slot chosen by ``slot_mapping``. For fp8 the reference
    is quantized with the same scale. Both sides are then dequantized to
    float16 before comparison.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
    entry_size = kv_lora_rank + qk_rope_head_dim

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)

    # Iterate the host-side slot list to avoid a device sync per token.
    for i, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
        ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]

    if kv_cache_dtype == "fp8":
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = ref_temp

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)

    if kv_cache_dtype == "fp8":
        result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_temp, kv_cache.contiguous(), scale.item(), kv_dtype=kv_cache_dtype
        )
        expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
        ops.convert_fp8(
            expected_temp, ref_kv_cache, scale.item(), kv_dtype=kv_cache_dtype
        )
        torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1)
    else:
        torch.testing.assert_close(kv_cache, ref_kv_cache)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_concat_and_cache_ds_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check the ``fp8_ds_mla`` entry format written by ``concat_and_cache_mla``.

    Each slot's bytes hold ``kv_lora_rank`` fp8 values quantized in 128-wide
    tiles, then one float32 scale per tile (max|x| / 448), then ``k_pe`` kept
    in the 16-bit input dtype. A reference is built tile by tile and each
    section is compared separately.
    """
    if current_platform.is_rocm():
        pytest.skip("concat_and_cache_mla doesn't support fp8_ds_mla on ROCm")
    if dtype.itemsize != 2:
        pytest.skip("ds_mla only supports 16-bit input")
    kv_cache_dtype = "fp8_ds_mla"
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
    # Bytes per entry: fp8 latent, four float32 tile scales, 16-bit rope part.
    entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)

    scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    kv_cache = _create_mla_cache(
        num_blocks,
        block_size,
        entry_size,
        dtype=torch.uint8,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
    )

    tile_size = 128
    num_tiles = 4  # 4 x 128 covers kv_lora_rank (only 512 is parametrized)
    # Offsets of the scale section (float32 units) and rope section (16-bit
    # units); the extra 8 16-bit elements skip the 16 bytes of scales.
    scale_offset = kv_lora_rank // 4
    rope_offset = kv_lora_rank // 2 + 8

    ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
    tile_data = torch.zeros(tile_size, dtype=dtype, device=device)

    for i, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)

        ref_cache_slice = ref_cache[block_idx, block_offset]
        ref_cache_16bit = ref_cache_slice.view(dtype)
        ref_cache_32bit = ref_cache_slice.view(torch.float32)

        kv_c_data = kv_c[i]
        for tile_idx in range(num_tiles):
            tile_start = tile_idx * tile_size
            tile_end = tile_start + tile_size
            tile_data[:] = kv_c_data[tile_start:tile_end]

            # Tile scale = max|x| / 448. abs() and max are exact, so a single
            # vectorized abs().amax() equals an element-wise max loop
            # bit-for-bit without a device sync per element. A plain
            # tile_data.amax() (no abs) would not match.
            tile_scale = tile_data.to(torch.float32).abs().amax() / 448.0

            ref_cache_32bit[scale_offset + tile_idx] = tile_scale

            ops.convert_fp8(
                ref_cache_slice[tile_start:tile_end],
                tile_data,
                tile_scale.item(),
                kv_dtype="fp8",
            )

        ref_cache_16bit[rope_offset : rope_offset + qk_rope_head_dim] = k_pe[i]

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)

    for slot in slot_mapping_lst:
        block_idx, block_offset = divmod(slot, block_size)
        kv_cache_slice = kv_cache[block_idx, block_offset]
        ref_cache_slice = ref_cache[block_idx, block_offset]

        kv_nope = kv_cache_slice[:kv_lora_rank]
        ref_nope = ref_cache_slice[:kv_lora_rank]
        kv_scales = kv_cache_slice.view(torch.float32)[
            scale_offset : scale_offset + num_tiles
        ]
        ref_scales = ref_cache_slice.view(torch.float32)[
            scale_offset : scale_offset + num_tiles
        ]
        kv_rope = kv_cache_slice.view(dtype)[rope_offset:]
        ref_rope = ref_cache_slice.view(dtype)[rope_offset:]

        torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
        torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
        torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``swap_blocks`` on single-tensor MLA caches on one device.

    Two randomly filled caches are created. A few non-overlapping (src, dst)
    block pairs are swapped. Each destination block must then equal a
    pre-swap snapshot of its source block.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim

    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    dst_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )

    _fill_mla_cache(src_cache, kv_cache_dtype)
    _fill_mla_cache(dst_cache, kv_cache_dtype)

    src_cache_clone = src_cache.clone()

    # Destination blocks are drawn from the blocks not used as sources.
    num_mappings = min(2, num_blocks // 2)
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, num_mappings)
    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.swap_blocks(src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_cache_clone[src].cpu(),
            dst_cache[dst].cpu(),
            msg=f"Block {src} from src should have been swapped to block "
            f"{dst} in dst_cache.",
        )


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(
    kv_lora_rank,
    qk_rope_head_dim,
    block_size,
    num_blocks,
    max_seq_len,
    batch_size,
    dtype,
    kv_cache_dtype,
    device,
):
    """Check ``gather_and_maybe_dequant_cache`` against a Python reference.

    Each sequence gathers ``seq_len`` rows from the cache blocks listed in
    its (randomly permuted) block-table row into a dense ``dst`` buffer.
    When ``kv_cache_dtype == "fp8"`` the reference converts each gathered
    block with ``ops.convert_fp8`` using the same ``scale`` given to the op.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    scale_value = scale.item()
    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    # randint(max_seq_len, max_seq_len + 1) gives every sequence exactly
    # max_seq_len tokens.
    seq_len_tensor = torch.randint(
        max_seq_len, max_seq_len + 1, (batch_size,), device=device
    )
    # The op takes the total token count as a plain int.
    total_tokens = int(seq_len_tensor.sum().item())
    cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    # token_to_seq[t] is the batch index that token t belongs to.
    token_to_seq = torch.arange(0, batch_size, dtype=torch.int32, device=device)
    token_to_seq = torch.repeat_interleave(token_to_seq, seq_len_tensor)

    block_table = torch.empty(
        (batch_size, num_blocks), dtype=torch.int32, device=device
    )
    for b in range(batch_size):
        block_table[b, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

    def maybe_dequant(block_data):
        # Reference conversion mirroring the op: fp8 caches are converted to
        # `dtype` with the shared scale; other caches are used as-is.
        if kv_cache_dtype != "fp8":
            return block_data
        dequantized = torch.empty_like(block_data, dtype=dtype)
        ops.convert_fp8(dequantized, block_data, scale_value)
        return dequantized

    expected_batches = []
    for b, s in enumerate(seq_len_tensor.tolist()):
        if s == 0:
            continue
        tot = (s + block_size - 1) // block_size
        blocks = block_table[b, :tot].tolist()

        # Every block but the last is consumed in full; the last one only
        # contributes the rows still needed to reach s tokens.
        gathered_rows = [maybe_dequant(src_cache[blk]) for blk in blocks[:-1]]
        remaining = s - (tot - 1) * block_size
        gathered_rows.append(maybe_dequant(src_cache[blocks[-1], :remaining, :]))
        expected_batches.append(torch.cat(gathered_rows, dim=0))
    expected = torch.cat(expected_batches, dim=0)

    opcheck(
        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
        (
            src_cache,
            dst,
            block_table,
            cu_seq_lens,
            token_to_seq,
            total_tokens,
            kv_cache_dtype,
            scale,
            None,
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.gather_and_maybe_dequant_cache(
        src_cache,
        dst,
        block_table,
        cu_seq_lens,
        token_to_seq,
        total_tokens,
        kv_cache_dtype,
        scale,
        None,
    )
    torch.testing.assert_close(dst, expected)


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", ["auto"])  # "fp8" can also be tested.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_cp_gather_cache_mla(
    kv_lora_rank,
    qk_rope_head_dim,
    block_size,
    num_blocks,
    max_seq_len,
    batch_size,
    dtype,
    kv_cache_dtype,
    device,
):
    """Check ``cp_gather_cache`` against a Python reference gather.

    Sequence lengths are drawn from ``[0, max_seq_len]``, so empty sequences
    are exercised too. The reference copies raw cache rows, and ``dst`` is
    allocated with the cache's own storage dtype.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    src_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
    total_tokens = int(seq_len_tensor.sum().item())
    cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    block_table = torch.empty(
        (batch_size, num_blocks), dtype=torch.int32, device=device
    )
    for b in range(batch_size):
        block_table[b, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size), dtype=src_cache.dtype, device=device)

    expected_batches = []
    for b, s in enumerate(seq_len_tensor.tolist()):
        if s == 0:
            continue
        tot = (s + block_size - 1) // block_size
        blocks = block_table[b, :tot].tolist()

        # Every block but the last is copied in full; the last one only
        # contributes the rows still needed to reach s tokens.
        gathered_rows = [src_cache[blk] for blk in blocks[:-1]]
        remaining = s - (tot - 1) * block_size
        gathered_rows.append(src_cache[blocks[-1], :remaining, :])
        expected_batches.append(torch.cat(gathered_rows, dim=0))
    # torch.cat rejects an empty list, which happens when every sampled
    # sequence length is 0; the expected result is then an empty buffer.
    expected = (
        torch.cat(expected_batches, dim=0)
        if expected_batches
        else torch.empty_like(dst)
    )

    opcheck(
        torch.ops._C_cache_ops.cp_gather_cache,
        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
    torch.testing.assert_close(dst, expected)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check ``concat_and_cache_mla`` on CPU against a Python reference.

    Each token's ``kv_c`` and ``k_pe`` rows are concatenated and written to
    the cache entry addressed by its slot: ``slot // block_size`` selects the
    block and ``slot % block_size`` the offset inside it.
    """
    device = "cpu"
    kv_cache_dtype = "auto"
    set_random_seed(seed)
    torch.set_default_device(device)

    total_slots = num_blocks * block_size
    # Distinct slots, so every cache entry is written by at most one token.
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
    entry_size = kv_lora_rank + qk_rope_head_dim

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(
        num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
    )

    # Reference: scatter the concatenated [kv_c | k_pe] rows with a single
    # indexed assignment rather than a per-token Python loop. Slots are
    # unique, so the assignment order does not matter.
    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
    block_idx = slot_mapping // block_size
    block_offset = slot_mapping % block_size
    ref_temp[block_idx, block_offset] = torch.cat((kv_c, k_pe), dim=-1)

    # NOTE: kv_cache_dtype is fixed to "auto" above, so the fp8 branch is
    # currently not taken; it is kept for when fp8 is parametrized here.
    if kv_cache_dtype == "fp8":
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = ref_temp

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    torch.testing.assert_close(kv_cache, ref_kv_cache)