# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random

import pytest
import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform

COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 256]
BLOCK_SIZES = [8, 16, 32]
CACHE_LAYOUTS = ["NHD", "HND"]

# Parameters for MLA tests.
KV_LORA_RANKS = [512]
QK_ROPE_HEAD_DIMS = [64]
NUM_TOKENS_MLA = [42]
BLOCK_SIZES_MLA = [16]
NUM_BLOCKS_MLA = [8]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# A single device when exactly one GPU is visible; otherwise (including the
# zero-GPU case) parametrize over cuda:0 and cuda:1.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]

# Backends exercised by test_reshape_and_cache_flash.
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ``ops.copy_blocks`` against a block-by-block reference copy.

    Each of ``num_mappings`` source blocks is copied to two distinct
    destination blocks in every layer's key and value cache. The kernel
    result must equal cloning the caches and applying the same copies
    with ``Tensor.copy_``.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks. Destinations are drawn from the blocks NOT used as
    # sources, so we need num_mappings sources + 2 * num_mappings
    # destinations; a weaker bound lets random.sample fail obscurely.
    assert 3 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: list[tuple[int, int]] = []
    for i in range(num_mappings):
        src = src_blocks[i]
        dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
        block_mapping.append((src, dst1))
        block_mapping.append((src, dst2))

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device=device).view(-1, 2)

    # opcheck only for the first head size to keep test time down.
    opcheck(torch.ops._C_cache_ops.copy_blocks,
            (key_caches, value_caches, block_mapping_tensor),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS,
            cond=(head_size == HEAD_SIZES[0]))
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Run the reference implementation.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        torch.testing.assert_close(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.reshape_and_cache`` against a per-token reference write.

    Random keys/values are written into the paged caches at random slots.
    For ``kv_cache_dtype == "fp8"`` both the kernel output and the reference
    are dequantized to float16 and compared with a loose tolerance;
    otherwise the caches are compared directly.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Derive per-tensor scales from the data (amax / 64) rather than using a
    # fixed default scale.
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    # Clone the KV caches (dequantized to float16 when the cache is fp8).
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item())
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item())
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    opcheck(torch.ops._C_cache_ops.reshape_and_cache,
            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
             k_scale, v_scale),
            cond=(head_size == HEAD_SIZES[0]))
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache, k_scale.item())
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache, v_scale.item())

    # Run the reference implementation.
    # The key cache is indexed as [block, :, :, offset, :], so reshape each
    # token's key to the shape of one slot of it.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(result_key_cache,
                                   cloned_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   cloned_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
    kv_cache_layout: str,
    implementation: str,
) -> None:
    """Check the CUDA and Triton ``reshape_and_cache_flash`` kernels.

    Random keys/values are written at random slots. Before comparing, HND
    caches are permuted with dims (0, 2, 1, 3) and made contiguous, so the
    reference write indexes the block offset on axis 1 (NHD) or axis 2
    (HND). fp8 results are compared after dequantizing to float16 with a
    loose tolerance.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    assert implementation in ["cuda", "triton"]
    if implementation == "triton" and kv_cache_layout == "HND":
        pytest.skip("Triton implementation only supports NHD layout.")

    # fp8 conversion requires a contiguous memory buffer. Reduce the number
    # of blocks and tokens to consume less memory.
    num_tokens = num_tokens // 2
    num_blocks = num_blocks // 2

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)
    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
        cache_layout=kv_cache_layout,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    # Drop the list references so only the single-layer caches stay alive.
    del key_caches
    del value_caches

    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    def permute_and_compact(x):
        # Return a contiguous copy/view; HND caches are permuted first.
        y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
        return y.contiguous()

    key_cache_compact = permute_and_compact(key_cache)
    value_cache_compact = permute_and_compact(value_cache)

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache_compact,
                                            dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(),
                        kv_cache_dtype)
        cloned_value_cache = torch.empty_like(value_cache_compact,
                                              dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache_compact,
                        v_scale.item(), kv_cache_dtype)
    else:
        cloned_key_cache = key_cache_compact.clone()
        cloned_value_cache = value_cache_compact.clone()

    # Call the reshape_and_cache_flash kernel.
    if implementation == "cuda":
        opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
                (key, value, key_cache, value_cache, slot_mapping,
                 kv_cache_dtype, k_scale, v_scale),
                cond=(head_size == HEAD_SIZES[0]))
        ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                    slot_mapping, kv_cache_dtype, k_scale,
                                    v_scale)
    elif implementation == "triton":
        from vllm.attention.ops.triton_reshape_and_cache_flash import (
            triton_reshape_and_cache_flash)
        triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
                                       slot_mapping, kv_cache_dtype, k_scale,
                                       v_scale)

    # Re-materialize the compact views from the kernel-updated caches.
    key_cache_compact = permute_and_compact(key_cache)
    value_cache_compact = permute_and_compact(value_cache)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache_compact,
                                            dtype=torch.float16)
        ops.convert_fp8(result_key_cache,
                        key_cache_compact,
                        k_scale.item(),
                        kv_dtype=kv_cache_dtype)
        result_value_cache = torch.empty_like(value_cache_compact,
                                              dtype=torch.float16)
        ops.convert_fp8(result_value_cache,
                        value_cache_compact,
                        v_scale.item(),
                        kv_dtype=kv_cache_dtype)

    # Run the reference implementation.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        if kv_cache_layout == "NHD":
            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
            cloned_value_cache[block_idx, block_offset, :, :] = value[i]
        else:
            cloned_key_cache[block_idx, :, block_offset, :] = key[i]
            cloned_value_cache[block_idx, :, block_offset, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(result_key_cache,
                                   cloned_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   cloned_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(key_cache_compact, cloned_key_cache)
        torch.testing.assert_close(value_cache_compact, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.swap_blocks`` copies blocks between devices correctly.

    ``direction`` selects the (source, destination) device kinds. After the
    swap, every mapped destination block must equal the source block as it
    was before the call.
    """
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()

    current_platform.seed_everything(seed)

    src_device = device if direction[0] == "cuda" else 'cpu'
    dst_device = device if direction[1] == "cuda" else 'cpu'

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    # The mapping tensor is built on the CPU regardless of direction.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device="cpu").view(-1, 2)

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)

    # Create the KV caches on the second device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)

    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    do_opcheck = (head_size == HEAD_SIZES[0])
    opcheck(torch.ops._C_cache_ops.swap_blocks,
            (src_key_caches[0], dst_key_caches[0], block_mapping_tensor),
            cond=do_opcheck)
    opcheck(torch.ops._C_cache_ops.swap_blocks,
            (src_value_caches[0], dst_value_caches[0], block_mapping_tensor),
            cond=do_opcheck)

    ops.swap_blocks(src_key_caches[0], dst_key_caches[0],
                    block_mapping_tensor)
    ops.swap_blocks(src_value_caches[0], dst_value_caches[0],
                    block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(src_key_caches_clone[src].cpu(),
                                   dst_key_caches[0][dst].cpu())
        torch.testing.assert_close(src_value_caches_clone[src].cpu(),
                                   dst_value_caches[0][dst].cpu())


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random cache through ``ops.convert_fp8`` and back.

    Values are drawn uniformly from [-224, 224], quantized into a uint8
    buffer, then dequantized; the result must match the input within a
    loose tolerance.
    """
    current_platform.seed_everything(seed)

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache_fp8, cache)

    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(converted_cache, cache_fp8)

    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)


def _create_mla_cache(
    num_blocks: int,
    block_size: int,
    entry_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> torch.Tensor:
    cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
492
493
494
495
496
    return torch.zeros(num_blocks,
                       block_size,
                       entry_size,
                       dtype=cache_dtype,
                       device=device)
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548


def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
    rand_dtype = torch.float16 if kv_cache_dtype == "fp8" else cache.dtype

    vals = torch.randn(*cache.shape, device=cache.device, dtype=rand_dtype)
    if kv_cache_dtype == "fp8":
        temp = torch.zeros_like(cache)
        ops.convert_fp8(temp, vals, 1.0, kv_dtype=kv_cache_dtype)
        vals = temp
    cache.copy_(vals)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_concat_and_cache_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.concat_and_cache_mla`` against a reference scatter.

    Each token's ``kv_c`` occupies the first ``kv_lora_rank`` entries of its
    slot and ``k_pe`` the rest. fp8 caches are compared after dequantizing
    both sides to float16 with a loose tolerance.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens,
                       qk_rope_head_dim,
                       dtype=dtype,
                       device=device)
    entry_size = kv_lora_rank + qk_rope_head_dim

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                 kv_cache_dtype, device)
    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)

    # Iterate the host-side slot list (the source of slot_mapping) instead of
    # calling .item() on a device tensor once per token.
    for i, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
        ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]

    if kv_cache_dtype == "fp8":
        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache,
                        ref_temp,
                        scale.item(),
                        kv_dtype=kv_cache_dtype)
    else:
        ref_kv_cache = ref_temp

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
                             kv_cache_dtype, scale)

    if kv_cache_dtype == "fp8":
        result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
        ops.convert_fp8(result_temp,
                        kv_cache.contiguous(),
                        scale.item(),
                        kv_dtype=kv_cache_dtype)
        expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
        ops.convert_fp8(expected_temp,
                        ref_kv_cache,
                        scale.item(),
                        kv_dtype=kv_cache_dtype)
        torch.testing.assert_close(result_temp,
                                   expected_temp,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(kv_cache, ref_kv_cache)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    num_layers: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.copy_blocks_mla`` against a reference per-block copy.

    Each source block is copied to two distinct destination blocks in every
    layer's MLA cache; the kernel result must equal the reference copies.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim

    kv_caches = []
    for _ in range(num_layers):
        kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                     kv_cache_dtype, device)
        _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
        kv_caches.append(kv_cache)

    ref_caches = [kv_cache.clone() for kv_cache in kv_caches]

    # Each mapping uses one source block plus two destinations drawn from the
    # non-source blocks, i.e. 3 blocks per mapping; `num_blocks // 2` would
    # make random.sample fail for small caches (e.g. num_blocks=4).
    num_mappings = min(2, num_blocks // 3)
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining, 2 * num_mappings)
    block_mapping: list[tuple[int, int]] = []
    for i in range(num_mappings):
        src = src_blocks[i]
        dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
        block_mapping.append((src, dst1))
        block_mapping.append((src, dst2))
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device=device).view(-1, 2)

    for src, dst in block_mapping:
        for ref_cache in ref_caches:
            ref_cache[dst].copy_(ref_cache[src])

    opcheck(
        torch.ops._C_cache_ops.copy_blocks_mla,
        (kv_caches, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    ops.copy_blocks_mla(kv_caches, block_mapping_tensor)

    for kv_cache, ref_cache in zip(kv_caches, ref_caches):
        torch.testing.assert_close(kv_cache, ref_cache)


@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.swap_blocks`` on MLA-shaped caches on the same device.

    Source and destination blocks are disjoint; after the swap each mapped
    destination block must equal the source block as it was before the call.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    entry_size = kv_lora_rank + qk_rope_head_dim

    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device)
    dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device)

    _fill_mla_cache(src_cache, kv_cache_dtype)
    _fill_mla_cache(dst_cache, kv_cache_dtype)

    src_cache_clone = src_cache.clone()

    num_mappings = min(2, num_blocks // 2)
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, num_mappings)
    block_mapping = list(zip(src_blocks, dst_blocks))
    # The mapping tensor is built on the CPU, as in test_swap_blocks.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device="cpu").view(-1, 2)

    opcheck(
        torch.ops._C_cache_ops.swap_blocks,
        (src_cache, dst_cache, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_cache_clone[src].cpu(),
            dst_cache[dst].cpu(),
            msg=f"Block {src} from src should have been swapped to block "
            f"{dst} in dst_cache.")


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
                                            block_size, num_blocks,
                                            max_seq_len, batch_size, dtype,
                                            kv_cache_dtype, device):
    """Check gather_and_maybe_dequant_cache against a reference gather.

    Each of ``batch_size`` sequences gets a random length in
    ``[0, max_seq_len]`` and its own random permutation of all cache blocks
    as a block table. For each sequence, the reference concatenates the
    leading ``ceil(len / block_size)`` blocks of its table, truncating the
    last block to the remaining tokens. When ``kv_cache_dtype == "fp8"``,
    it converts each block with ``ops.convert_fp8`` using ``scale``.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device)
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    seq_len_tensor = torch.randint(0,
                                   max_seq_len + 1, (batch_size, ),
                                   device=device)

    total_tokens = seq_len_tensor.sum()
    # Exclusive prefix sum of the sequence lengths (batch_size + 1 entries).
    cu_seq_lens = torch.empty((batch_size + 1, ),
                              dtype=torch.int32,
                              device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
    block_table = torch.empty((batch_size, num_blocks),
                              dtype=torch.int32,
                              device=device)
    for b in range(batch_size):
        block_table[b, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

    def maybe_dequant(data: torch.Tensor) -> torch.Tensor:
        """Return ``data`` converted to ``dtype`` when the cache is fp8."""
        if kv_cache_dtype != "fp8":
            return data
        converted = torch.empty_like(data, dtype=dtype)
        ops.convert_fp8(converted, data, scale.item())
        return converted

    expected_batches = []
    for b in range(batch_size):
        s = seq_len_tensor[b]
        if s == 0:
            continue
        tot = tot_blocks_tensor[b]
        blocks = block_table[b, :tot].tolist()

        # All blocks but the last are full; the last holds the remainder.
        gathered_rows = [maybe_dequant(src_cache[blk]) for blk in blocks[:-1]]
        remaining = s - (tot - 1) * block_size
        last_block_data = src_cache[blocks[-1], :remaining, :]
        gathered_rows.append(maybe_dequant(last_block_data))

        expected_batches.append(torch.cat(gathered_rows, dim=0))
    # torch.cat rejects an empty list, which happens when every sampled
    # sequence length is zero.
    expected = (torch.cat(expected_batches, dim=0) if expected_batches else
                torch.empty((0, entry_size), dtype=dtype, device=device))

    opcheck(
        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
        (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
         scale, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
                                       cu_seq_lens, batch_size, kv_cache_dtype,
                                       scale, None)
    torch.testing.assert_close(dst, expected)
Thien Tran's avatar
Thien Tran committed
806
807


808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
                         ["auto"])  # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
                             num_blocks, max_seq_len, batch_size, dtype,
                             kv_cache_dtype, device):
    """Check cp_gather_cache against a reference gather.

    Each of ``batch_size`` sequences gets a random length in
    ``[0, max_seq_len]`` and its own random permutation of all cache blocks
    as a block table. For each sequence, the reference concatenates the
    leading ``ceil(len / block_size)`` blocks of its table, truncating the
    last block to the remaining tokens. No conversion is applied, so ``dst``
    uses the cache's own dtype.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device)
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    seq_len_tensor = torch.randint(0,
                                   max_seq_len + 1, (batch_size, ),
                                   device=device)

    total_tokens = seq_len_tensor.sum()
    # Exclusive prefix sum of the sequence lengths (batch_size + 1 entries).
    cu_seq_lens = torch.empty((batch_size + 1, ),
                              dtype=torch.int32,
                              device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
    block_table = torch.empty((batch_size, num_blocks),
                              dtype=torch.int32,
                              device=device)
    for b in range(batch_size):
        block_table[b, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size),
                      dtype=src_cache.dtype,
                      device=device)

    expected_batches = []
    for b in range(batch_size):
        s = seq_len_tensor[b]
        if s == 0:
            continue
        tot = tot_blocks_tensor[b]
        blocks = block_table[b, :tot].tolist()

        # All blocks but the last are full; the last holds the remainder.
        gathered_rows = [src_cache[blk] for blk in blocks[:-1]]
        remaining = s - (tot - 1) * block_size
        gathered_rows.append(src_cache[blocks[-1], :remaining, :])

        expected_batches.append(torch.cat(gathered_rows, dim=0))
    # torch.cat rejects an empty list, which happens when every sampled
    # sequence length is zero.
    expected = (torch.cat(expected_batches, dim=0) if expected_batches else
                torch.empty((0, entry_size),
                            dtype=src_cache.dtype,
                            device=device))

    opcheck(
        torch.ops._C_cache_ops.cp_gather_cache,
        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
    torch.testing.assert_close(dst, expected)


Thien Tran's avatar
Thien Tran committed
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check the CPU concat_and_cache_mla op against a reference scatter.

    ``num_tokens`` distinct random slots are chosen out of
    ``num_blocks * block_size``. For each token, the reference writes
    ``kv_c`` into the first ``kv_lora_rank`` entries of its slot and
    ``k_pe`` into the remaining ``qk_rope_head_dim`` entries. This test only
    exercises the unquantized ("auto") cache dtype, so the reference needs
    no fp8 conversion.
    """
    device = "cpu"
    kv_cache_dtype = "auto"
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens,
                       qk_rope_head_dim,
                       dtype=dtype,
                       device=device)
    entry_size = kv_lora_rank + qk_rope_head_dim

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                 kv_cache_dtype, device)

    # Slots come from random.sample, so they are distinct. The vectorized
    # scatter below therefore has no write conflicts and matches a
    # per-token loop.
    block_idx = slot_mapping // block_size
    block_offset = slot_mapping % block_size
    ref_kv_cache = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
    ref_kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c
    ref_kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
                             kv_cache_dtype, scale)
    torch.testing.assert_close(kv_cache, ref_kv_cache)