test_cache.py 16.3 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
import random
2
from typing import List, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3

4
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch

7
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
8
from vllm import _custom_ops as ops
Woosuk Kwon's avatar
Woosuk Kwon committed
9

Vladimir's avatar
Vladimir committed
10
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
11
DTYPES = [torch.half, torch.bfloat16, torch.float]
Simon Mo's avatar
Simon Mo committed
12
NUM_TOKENS = [42]  # Arbitrary values for testing
13
NUM_LAYERS = [1]  # Arbitrary values for testing
14
NUM_HEADS = [8]  # Arbitrary values for testing
Joe's avatar
Joe committed
15
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
16
BLOCK_SIZES = [8, 16, 32]
17
18
19
20
21

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

22
NUM_MAPPINGS = [256]  # Arbitrary values for testing
23
SEEDS = [0]
24
25
26
# Devices the tests are parametrized over: only "cuda:0" when exactly one
# GPU is reported, otherwise "cuda:0" and "cuda:1".
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
27
28

# NOTE: fp8 is currently disabled below; only the "auto" KV-cache dtype is tested.
zhuwenwen's avatar
zhuwenwen committed
29
30
# KV_CACHE_DTYPE = ["auto", "fp8"] 
KV_CACHE_DTYPE = ["auto"] 
31
32
33
34
35
36
37
38
39
40


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Compare ``ops.copy_blocks`` with an index-and-copy reference.

    Every sampled source block is mapped to two distinct destination blocks.
    The kernel runs on the caches produced by ``kv_cache_factory`` while the
    reference copies run on clones taken beforehand; both must end up equal.
    """
    # fp8 is only exercised for head sizes that are a multiple of 16.
    if kv_cache_dtype == "fp8" and head_size % 16 != 0:
        pytest.skip()

    # Seed every RNG involved so the sampled mapping is reproducible.
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Each source needs two destinations, none of which may be a source.
    assert 2 * num_mappings <= num_blocks
    sources = random.sample(range(num_blocks), num_mappings)
    free_blocks = list(set(range(num_blocks)) - set(sources))
    destinations = random.sample(free_blocks, 2 * num_mappings)
    block_mapping: List[Tuple[int, int]] = [
        (source, destinations[2 * idx + which])
        for idx, source in enumerate(sources)
        for which in (0, 1)
    ]

    # Build the caches under test and snapshot them for the reference path.
    key_caches, value_caches = kv_cache_factory(
        num_blocks, block_size, num_layers, num_heads, head_size,
        kv_cache_dtype, dtype, seed, device)
    expected_key_caches = [cache.clone() for cache in key_caches]
    expected_value_caches = [cache.clone() for cache in value_caches]

    # Kernel path: the mapping is passed as an (N, 2) int64 tensor.
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device=device).view(-1, 2)
    opcheck(
        torch.ops._C_cache_ops.copy_blocks,
        (key_caches, value_caches, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Reference path: copy block-by-block on the snapshots.
    for src_block, dst_block in block_mapping:
        for expected in (*expected_key_caches, *expected_value_caches):
            expected[dst_block].copy_(expected[src_block])

    # Kernel output must match the reference, keys first and then values.
    for actual, expected in zip(key_caches, expected_key_caches):
        torch.testing.assert_close(actual, expected)
    for actual, expected in zip(value_caches, expected_value_caches):
        torch.testing.assert_close(actual, expected)
112
113


114
115
116
117
118
119
120
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache`` with per-token indexed writes.

    Random keys/values are written into randomly sampled, distinct cache
    slots by the kernel, and into cloned caches by plain tensor indexing;
    the two results must agree (with loose tolerances for fp8).
    """
    is_fp8 = kv_cache_dtype == "fp8"
    # fp8 is only exercised for head sizes that are a multiple of 16.
    if is_fp8 and head_size % 16 != 0:
        pytest.skip()

    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Sample one distinct slot per token out of all slots in the cache.
    total_slots = block_size * num_blocks
    sampled_slots = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(sampled_slots, dtype=torch.long)

    # Only the key and value slices of the random qkv tensor are used.
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    key, value = qkv.unbind(dim=1)[1:]

    # A single-layer cache pair is enough for this test.
    key_caches, value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype,
        dtype, seed, device)
    key_cache = key_caches[0]
    value_cache = value_caches[0]

    # Snapshot the caches for the reference path (as float16 for fp8).
    if is_fp8:
        expected_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(expected_key_cache, key_cache)
        expected_value_cache = torch.empty_like(value_cache,
                                                dtype=torch.float16)
        ops.convert_fp8(expected_value_cache, value_cache)
    else:
        expected_key_cache = key_cache.clone()
        expected_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Kernel path.
    opcheck(
        torch.ops._C_cache_ops.reshape_and_cache,
        (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
         k_scale, v_scale),
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, k_scale, v_scale)

    # Reference path: split each slot into (block, offset) and write there.
    key_slot_shape = key_cache[0, :, :, 0, :].shape
    reshaped_key = key.reshape(num_tokens, *key_slot_shape)
    block_ids = torch.div(slot_mapping, block_size,
                          rounding_mode="floor").cpu().tolist()
    offsets = (slot_mapping % block_size).cpu().tolist()
    for token, (block_id, offset) in enumerate(zip(block_ids, offsets)):
        expected_key_cache[block_id, :, :, offset, :] = reshaped_key[token]
        expected_value_cache[block_id, :, :, offset] = value[token]

    if not is_fp8:
        torch.testing.assert_close(key_cache, expected_key_cache)
        torch.testing.assert_close(value_cache, expected_value_cache)
    else:
        # Dequantise the kernel output before comparing with tolerances.
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache,
                                              dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)
        torch.testing.assert_close(result_key_cache,
                                   expected_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   expected_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
Vladimir's avatar
Vladimir committed
209
210


211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache_flash`` with per-token indexed writes.

    Same idea as the non-flash test, but the caches come from
    ``kv_cache_factory_flashinfer`` and are indexed as
    ``[block, offset, :, :]`` by the reference path.
    """
    is_fp8 = kv_cache_dtype == "fp8"

    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Sample one distinct slot per token out of all slots in the cache.
    total_slots = block_size * num_blocks
    sampled_slots = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(sampled_slots,
                                dtype=torch.long,
                                device=device)

    # Only the key and value slices of the random qkv tensor are used.
    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    key, value = qkv.unbind(dim=1)[1:]

    # A single-layer cache pair; keep contiguous copies of its tensors and
    # drop the factory's containers.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache = key_caches[0].contiguous()
    value_cache = value_caches[0].contiguous()
    del key_caches
    del value_caches

    # Snapshot the caches for the reference path (as float16 for fp8).
    if is_fp8:
        expected_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(expected_key_cache, key_cache)
        expected_value_cache = torch.empty_like(value_cache,
                                                dtype=torch.float16)
        ops.convert_fp8(expected_value_cache, value_cache)
    else:
        expected_key_cache = key_cache.clone()
        expected_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Kernel path.
    opcheck(
        torch.ops._C_cache_ops.reshape_and_cache_flash,
        (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
         k_scale, v_scale),
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale, v_scale)

    # Reference path: split each slot into (block, offset) and write there.
    block_ids = torch.div(slot_mapping, block_size,
                          rounding_mode="floor").cpu().tolist()
    offsets = (slot_mapping % block_size).cpu().tolist()
    for token, (block_id, offset) in enumerate(zip(block_ids, offsets)):
        expected_key_cache[block_id, offset, :, :] = key[token]
        expected_value_cache[block_id, offset, :, :] = value[token]

    if not is_fp8:
        torch.testing.assert_close(key_cache, expected_key_cache)
        torch.testing.assert_close(value_cache, expected_value_cache)
    else:
        # Dequantise the kernel output before comparing with tolerances.
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache,
                                              dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)
        torch.testing.assert_close(result_key_cache,
                                   expected_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   expected_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
319
320


Vladimir's avatar
Vladimir committed
321
322
323
324
325
326
327
328
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``ops.swap_blocks`` moves source blocks to their targets.

    ``direction`` selects which side (source, destination) lives on the
    CUDA ``device`` and which on the CPU. After the swap, every destination
    block must equal the pre-swap contents of its source block.
    """
    # fp8 runs are skipped when a CPU side is involved or when the head size
    # is not a multiple of 16.
    if kv_cache_dtype == "fp8" and ("cpu" in direction or head_size % 16):
        pytest.skip()

    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    src_kind, dst_kind = direction[0], direction[1]
    if src_kind == "cuda":
        src_device = device
    else:
        src_device = 'cpu'
    if dst_kind == "cuda":
        dst_device = device
    else:
        dst_device = 'cpu'

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device != dst_device:
        dst_blocks = random.sample(range(num_blocks), num_mappings)
    else:
        unused_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(unused_blocks, num_mappings)

    # The mapping tensor is always built on the CPU as an (N, 2) int64 tensor.
    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu").view(-1, 2)

    # One single-layer cache pair per side of the transfer.
    factory_args = (num_blocks, block_size, 1, num_heads, head_size,
                    kv_cache_dtype, dtype, seed)
    src_key_caches, src_value_caches = kv_cache_factory(
        *factory_args, src_device)
    dst_key_caches, dst_value_caches = kv_cache_factory(
        *factory_args, dst_device)
    src_key, src_value = src_key_caches[0], src_value_caches[0]
    dst_key, dst_value = dst_key_caches[0], dst_value_caches[0]

    # Snapshot the sources before the kernel touches anything.
    src_key_before = src_key.clone()
    src_value_before = src_value.clone()

    # Call the swap_blocks kernel.
    run_opcheck = (head_size == HEAD_SIZES[0])
    opcheck(torch.ops._C_cache_ops.swap_blocks,
            (src_key, dst_key, block_mapping_tensor),
            cond=run_opcheck)
    opcheck(torch.ops._C_cache_ops.swap_blocks,
            (src_value, dst_value, block_mapping_tensor),
            cond=run_opcheck)

    ops.swap_blocks(src_key, dst_key, block_mapping_tensor)
    ops.swap_blocks(src_value, dst_value, block_mapping_tensor)

    # Compare on the CPU so both sides live on the same device.
    for src_block, dst_block in block_mapping:
        torch.testing.assert_close(src_key_before[src_block].cpu(),
                                   dst_key[dst_block].cpu())
        torch.testing.assert_close(src_value_before[src_block].cpu(),
                                   dst_value[dst_block].cpu())
402
403


zhuwenwen's avatar
zhuwenwen committed
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @torch.inference_mode()
# def test_fp8_e4m3_conversion(
#     num_heads: int,
#     head_size: int,
#     block_size: int,
#     num_blocks: int,
#     dtype: torch.dtype,
#     seed: int,
#     device: str,
# ) -> None:
#     random.seed(seed)
#     torch.random.manual_seed(seed)
#     torch.cuda.manual_seed(seed)

#     low = -224.0
#     high = 224.0
#     shape = (num_blocks, num_heads, head_size, block_size)
#     cache = torch.empty(shape, dtype=dtype, device=device)
#     cache.uniform_(low, high)

#     cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
#     ops.convert_fp8(cache_fp8, cache)

#     converted_cache = torch.empty_like(cache)
#     ops.convert_fp8(converted_cache, cache_fp8)

437
#     torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)