test_cache.py 21.1 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
import random
2
from typing import List, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3

4
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch

7
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS
8
from vllm import _custom_ops as ops
9
from vllm.utils import seed_everything, is_hip
10
from .utils import torch_version
Woosuk Kwon's avatar
Woosuk Kwon committed
11

Vladimir's avatar
Vladimir committed
12
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# One device when exactly one CUDA device is visible, otherwise two.
# NOTE(review): with zero visible devices this still yields
# ["cuda:0", "cuda:1"]; the tests then presumably fail on device selection.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# KV-cache storage dtypes exercised by the tests below. Only "auto" is enabled
# here, so every `kv_cache_dtype == "fp8"` branch is currently dead; use
# ["auto", "fp8"] to cover the fp8 paths as well.
KV_CACHE_DTYPE = ["auto"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ``ops.copy_blocks`` against a per-block ``copy_`` reference.

    Each of ``num_mappings`` random source blocks is copied into two
    destination blocks in every layer's key and value cache. The kernel output
    must equal the same copies applied to cloned caches. On PyTorch 2.4 the op
    registration is additionally validated with ``opcheck``. PyTorch versions
    other than 2.3/2.4 are skipped, not silently passed.
    """
    if not torch_version.startswith(("2.3", "2.4")):
        pytest.skip(
            f"PyTorch version {torch_version} is not specifically handled.")
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    seed_everything(seed)
    torch.set_default_device(device)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks. Destinations are drawn only from non-source blocks,
    # so the order in which copies are applied cannot change the result.
    assert 2 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: List[Tuple[int, int]] = []
    for i, src in enumerate(src_blocks):
        block_mapping.append((src, dst_blocks[2 * i]))
        block_mapping.append((src, dst_blocks[2 * i + 1]))

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device=device).view(-1, 2)
    if torch_version.startswith("2.4"):
        # opcheck is only exercised on 2.4; it is imported lazily so that the
        # 2.3 path never touches it.
        from tests.kernels.utils import opcheck
        opcheck(torch.ops._C_cache_ops.copy_blocks,
                (key_caches, value_caches, block_mapping_tensor),
                test_utils=DEFAULT_OPCHECK_TEST_UTILS,
                cond=(head_size == HEAD_SIZES[0]))
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Run the reference implementation.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results. Both key and value caches are asserted; the
    # previous 2.3 path discarded the key-cache comparison result.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        torch.testing.assert_close(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.reshape_and_cache`` against a slot-by-slot reference.

    Random keys/values for ``num_tokens`` tokens are scattered into random
    cache slots (slot = block * block_size + offset). The kernel result must
    match a reference that writes the same slots into cloned caches. For fp8
    caches both sides are decoded to float16 and compared with loose
    tolerances. On PyTorch 2.4 the op is additionally validated with
    ``opcheck``. Versions other than 2.3/2.4 are skipped.
    """
    if not torch_version.startswith(("2.3", "2.4")):
        pytest.skip(
            f"PyTorch version {torch_version} is not specifically handled.")
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    seed_everything(seed)
    torch.set_default_device(device)

    # Create a random slot mapping (distinct slots, one per token).
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches. fp8 caches are decoded to float16 first so the
    # reference can be written with ordinary float assignments.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the reshape_and_cache kernel (opcheck only on PyTorch 2.4).
    if torch_version.startswith("2.4"):
        from tests.kernels.utils import opcheck
        opcheck(torch.ops._C_cache_ops.reshape_and_cache,
                (key, value, key_cache, value_cache, slot_mapping,
                 kv_cache_dtype, k_scale, v_scale),
                cond=(head_size == HEAD_SIZES[0]))
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)

    # Run the reference implementation. The key cache is indexed as
    # [block, :, :, offset, :], so each token's key is reshaped to match one
    # such slice; the value cache is indexed as [block, :, :, offset].
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size,
                              rounding_mode="floor").cpu().tolist()
    block_offsets = (slot_mapping % block_size).cpu().tolist()
    for i, (block_idx, block_offset) in enumerate(
            zip(block_indices, block_offsets)):
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(result_key_cache,
                                   cloned_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   cloned_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.reshape_and_cache_flash`` against a slot-by-slot reference.

    Both caches are indexed as [block, offset, :, :]. Each token's key and
    value are written to its random slot (slot = block * block_size + offset),
    both by the kernel and by the reference on cloned caches. fp8 caches are
    decoded to float16 on both sides before a loose-tolerance comparison. On
    PyTorch 2.4 the op is additionally validated with ``opcheck``. Versions
    other than 2.3/2.4 are skipped.
    """
    if not torch_version.startswith(("2.3", "2.4")):
        pytest.skip(
            f"PyTorch version {torch_version} is not specifically handled.")
    seed_everything(seed)
    torch.set_default_device(device)

    # Create a random slot mapping (distinct slots, one per token).
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0].contiguous(
    ), value_caches[0].contiguous()
    del key_caches
    del value_caches

    # Clone the KV caches. fp8 caches are decoded to float16 so the reference
    # can be written with ordinary float assignments. (The former 2.3 path
    # re-cloned here, discarding this decoding.)
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the reshape_and_cache kernel (opcheck only on PyTorch 2.4).
    if torch_version.startswith("2.4"):
        from tests.kernels.utils import opcheck
        opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
                (key, value, key_cache, value_cache, slot_mapping,
                 kv_cache_dtype, k_scale, v_scale),
                cond=(head_size == HEAD_SIZES[0]))
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)

    # Run the reference implementation.
    block_indices = torch.div(slot_mapping, block_size,
                              rounding_mode="floor").cpu().tolist()
    block_offsets = (slot_mapping % block_size).cpu().tolist()
    for i, (block_idx, block_offset) in enumerate(
            zip(block_indices, block_offsets)):
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(result_key_cache,
                                   cloned_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   cloned_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.swap_blocks`` copies blocks between two caches.

    ``direction`` picks the source/destination devices ("cuda" means the
    parametrized ``device``). After swapping, every destination block must
    equal the pre-swap contents of its source block. On PyTorch 2.4 the op is
    additionally validated with ``opcheck``. Versions other than 2.3/2.4 are
    skipped.
    """
    if not torch_version.startswith(("2.3", "2.4")):
        pytest.skip(
            f"PyTorch version {torch_version} is not specifically handled.")
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()

    seed_everything(seed)

    src_device = device if direction[0] == "cuda" else 'cpu'
    dst_device = device if direction[1] == "cuda" else 'cpu'

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    # The mapping tensor is always built on the CPU, whatever the direction.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device="cpu").view(-1, 2)

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)

    # Create the KV caches on the second device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)

    # Snapshot the sources so the comparison does not depend on whether the
    # kernel modifies them.
    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel (opcheck only on PyTorch 2.4).
    if torch_version.startswith("2.4"):
        from tests.kernels.utils import opcheck
        do_opcheck = (head_size == HEAD_SIZES[0])
        opcheck(torch.ops._C_cache_ops.swap_blocks,
                (src_key_caches[0], dst_key_caches[0], block_mapping_tensor),
                cond=do_opcheck)
        opcheck(torch.ops._C_cache_ops.swap_blocks,
                (src_value_caches[0], dst_value_caches[0],
                 block_mapping_tensor),
                cond=do_opcheck)
    ops.swap_blocks(src_key_caches[0], dst_key_caches[0],
                    block_mapping_tensor)
    ops.swap_blocks(src_value_caches[0], dst_value_caches[0],
                    block_mapping_tensor)

    # Compare on the CPU so source and destination devices may differ.
    for src, dst in block_mapping:
        torch.testing.assert_close(src_key_caches_clone[src].cpu(),
                                   dst_key_caches[0][dst].cpu())
        torch.testing.assert_close(src_value_caches_clone[src].cpu(),
                                   dst_value_caches[0][dst].cpu())


@pytest.mark.skipif(is_hip(), reason="FP8 is not supported on ROCm.")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random cache through ``ops.convert_fp8`` and back.

    The tensor of shape (num_blocks, num_heads, head_size, block_size) is
    filled uniformly from [-224, 224]. It is encoded into a uint8 buffer,
    decoded back into ``dtype``, and must match the input within atol=1e-3
    and rtol=0.1.
    """
    seed_everything(seed)

    bound = 224.0
    dims = (num_blocks, num_heads, head_size, block_size)
    original = torch.empty(dims, dtype=dtype,
                           device=device).uniform_(-bound, bound)

    # Encode into a uint8 buffer, then decode back into the original dtype.
    encoded = torch.empty_like(original, dtype=torch.uint8)
    ops.convert_fp8(encoded, original)
    decoded = torch.empty_like(original)
    ops.convert_fp8(decoded, encoded)

    torch.testing.assert_close(original, decoded, atol=0.001, rtol=0.1)