test_cache.py 15.9 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
import random
2
from typing import List, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3

4
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch

7
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
8
from vllm import _custom_ops as ops
9
from vllm.utils import seed_everything
Woosuk Kwon's avatar
Woosuk Kwon committed
10

Vladimir's avatar
Vladimir committed
11
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]

# Test on at most two GPUs. With no visible GPU this list is empty, so pytest
# reports the parametrized tests as skipped instead of running them against
# devices that do not exist (the former `1 if count == 1 else 2` expression
# produced two devices in that case).
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(min(torch.cuda.device_count(), 2))
]

# fp8 KV-cache coverage is currently disabled; restore the commented list
# below to exercise the "fp8" code paths in the tests again.
# KV_CACHE_DTYPE = ["auto", "fp8"]
KV_CACHE_DTYPE = ["auto"]
32
33
34
35
36
37
38
39
40
41


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Compare ``copy_blocks`` with the same copies done via ``Tensor.copy_``.

    Each of ``num_mappings`` source blocks is duplicated into two distinct
    destination blocks. The kernel mutates the factory-built caches in place;
    the identical copies are replayed on clones, and every key and value
    cache tensor must match its replayed counterpart.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    seed_everything(seed)
    torch.set_default_device(device)

    # Sources and destinations come from disjoint pools, so no copy writes a
    # block that another copy reads from.
    assert 2 * num_mappings <= num_blocks
    sources = random.sample(range(num_blocks), num_mappings)
    unused_blocks = list(set(range(num_blocks)) - set(sources))
    targets = random.sample(unused_blocks, 2 * num_mappings)
    # Source i fans out to targets[2 * i] and targets[2 * i + 1].
    block_mapping: List[Tuple[int, int]] = [
        (src, dst) for i, src in enumerate(sources)
        for dst in targets[2 * i:2 * i + 2]
    ]

    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Snapshot the caches before the kernel mutates them.
    expected_keys = [cache.clone() for cache in key_caches]
    expected_values = [cache.clone() for cache in value_caches]

    mapping = torch.tensor(block_mapping, dtype=torch.int64,
                           device=device).view(-1, 2)
    # opcheck is only exercised for the first entry of HEAD_SIZES.
    opcheck(torch.ops._C_cache_ops.copy_blocks,
            (key_caches, value_caches, mapping),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS,
            cond=(head_size == HEAD_SIZES[0]))
    ops.copy_blocks(key_caches, value_caches, mapping)

    # Reference: replay every (src, dst) copy on each cloned key/value cache.
    for src, dst in block_mapping:
        for cache in (*expected_keys, *expected_values):
            cache[dst].copy_(cache[src])

    for actual, expected in zip(key_caches, expected_keys):
        torch.testing.assert_close(actual, expected)
    for actual, expected in zip(value_caches, expected_values):
        torch.testing.assert_close(actual, expected)
110
111


112
113
114
115
116
117
118
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``reshape_and_cache`` with a per-token scatter reference.

    Random key/value tokens are written into distinct random slots of a
    single-layer cache. For ``kv_cache_dtype == "fp8"`` both the kernel
    output and the reference are float16 tensors obtained through
    ``ops.convert_fp8`` and are compared with loose tolerances. Otherwise
    the caches are compared with ``assert_close``'s default tolerances.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    seed_everything(seed)
    torch.set_default_device(device)

    # One distinct slot per token; slot s is block s // block_size at offset
    # s % block_size.
    slots = random.sample(range(block_size * num_blocks), num_tokens)
    slot_mapping = torch.tensor(slots, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    key, value = qkv.unbind(dim=1)[1:]

    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    use_fp8 = kv_cache_dtype == "fp8"

    def to_half(cache: torch.Tensor) -> torch.Tensor:
        # Convert an fp8 cache into a fresh float16 tensor via ops.convert_fp8.
        half = torch.empty_like(cache, dtype=torch.float16)
        ops.convert_fp8(half, cache)
        return half

    # The reference caches start as a copy of the initial cache contents.
    snapshot = to_half if use_fp8 else torch.Tensor.clone
    expected_key_cache = snapshot(key_cache)
    expected_value_cache = snapshot(value_cache)

    # Default (unit) kv scales.
    k_scale = v_scale = 1.0

    opcheck(torch.ops._C_cache_ops.reshape_and_cache,
            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
             k_scale, v_scale),
            cond=(head_size == HEAD_SIZES[0]))
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, k_scale, v_scale)

    if use_fp8:
        result_key_cache = to_half(key_cache)
        result_value_cache = to_half(value_cache)

    # Reference scatter. The key cache is 5-D (indexed [block, :, :, offset,
    # :]), so each token's key is reshaped to one (block, offset) slice.
    key_slice_shape = key_cache[0, :, :, 0, :].shape
    reshaped_key = key.reshape(num_tokens, *key_slice_shape)
    for token, slot in enumerate(slots):
        block, offset = divmod(slot, block_size)
        expected_key_cache[block, :, :, offset, :] = reshaped_key[token]
        expected_value_cache[block, :, :, offset] = value[token]

    if use_fp8:
        for actual, expected in ((result_key_cache, expected_key_cache),
                                 (result_value_cache, expected_value_cache)):
            torch.testing.assert_close(actual,
                                       expected,
                                       atol=0.001,
                                       rtol=0.1)
    else:
        torch.testing.assert_close(key_cache, expected_key_cache)
        torch.testing.assert_close(value_cache, expected_value_cache)
Vladimir's avatar
Vladimir committed
204
205


206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``reshape_and_cache_flash`` with a per-token scatter reference.

    The caches come from ``kv_cache_factory_flashinfer`` and the reference
    indexes them as ``[block, offset, :, :]``. For ``kv_cache_dtype ==
    "fp8"`` both sides are converted to float16 via ``ops.convert_fp8`` and
    compared with loose tolerances. Otherwise the caches are compared with
    ``assert_close``'s default tolerances.
    """
    seed_everything(seed)
    torch.set_default_device(device)

    # One distinct slot per token; slot s is block s // block_size at offset
    # s % block_size.
    slots = random.sample(range(block_size * num_blocks), num_tokens)
    slot_mapping = torch.tensor(slots, dtype=torch.long, device=device)

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    key, value = qkv.unbind(dim=1)[1:]

    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype,
        dtype, device=device)
    # Keep contiguous versions of the single layer and drop the list handles.
    key_cache = key_caches[0].contiguous()
    value_cache = value_caches[0].contiguous()
    del key_caches, value_caches

    use_fp8 = kv_cache_dtype == "fp8"

    def to_half(cache: torch.Tensor) -> torch.Tensor:
        # Convert an fp8 cache into a fresh float16 tensor via ops.convert_fp8.
        half = torch.empty_like(cache, dtype=torch.float16)
        ops.convert_fp8(half, cache)
        return half

    # The reference caches start as a copy of the initial cache contents.
    snapshot = to_half if use_fp8 else torch.Tensor.clone
    expected_key_cache = snapshot(key_cache)
    expected_value_cache = snapshot(value_cache)

    # Default (unit) kv scales.
    k_scale = v_scale = 1.0

    opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
             k_scale, v_scale),
            cond=(head_size == HEAD_SIZES[0]))
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale, v_scale)

    if use_fp8:
        result_key_cache = to_half(key_cache)
        result_value_cache = to_half(value_cache)

    # Reference scatter: token i lands at [block, offset, :, :].
    for token, slot in enumerate(slots):
        block, offset = divmod(slot, block_size)
        expected_key_cache[block, offset, :, :] = key[token]
        expected_value_cache[block, offset, :, :] = value[token]

    if use_fp8:
        for actual, expected in ((result_key_cache, expected_key_cache),
                                 (result_value_cache, expected_value_cache)):
            torch.testing.assert_close(actual,
                                       expected,
                                       atol=0.001,
                                       rtol=0.1)
    else:
        torch.testing.assert_close(key_cache, expected_key_cache)
        torch.testing.assert_close(value_cache, expected_value_cache)
312
313


Vladimir's avatar
Vladimir committed
314
315
316
317
318
319
320
321
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``swap_blocks`` moves blocks from one cache into another.

    ``direction`` gives the (source, destination) device kinds; "cuda" maps
    to ``device`` and anything else to the CPU. After the kernel runs, every
    mapped destination block must equal its source block as it was before
    the call.
    """
    # fp8 is skipped for CPU transfers and for head sizes not divisible by 16.
    if kv_cache_dtype == "fp8" and ("cpu" in direction or head_size % 16):
        pytest.skip()

    seed_everything(seed)

    def resolve(kind: str) -> str:
        return device if kind == "cuda" else 'cpu'

    src_device, dst_device = resolve(direction[0]), resolve(direction[1])

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # On a single device the destination blocks must not overlap the sources.
    if src_device == dst_device:
        dst_pool = list(set(range(num_blocks)) - set(src_blocks))
    else:
        dst_pool = range(num_blocks)
    dst_blocks = random.sample(dst_pool, num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    # The mapping tensor lives on the CPU regardless of the copy direction.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device="cpu").view(-1, 2)

    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)

    cache_pairs = ((src_key_caches[0], dst_key_caches[0]),
                   (src_value_caches[0], dst_value_caches[0]))
    # Remember the source contents before any kernel call.
    key_snapshot, value_snapshot = (src.clone() for src, _ in cache_pairs)

    for src_cache, dst_cache in cache_pairs:
        opcheck(torch.ops._C_cache_ops.swap_blocks,
                (src_cache, dst_cache, block_mapping_tensor),
                cond=(head_size == HEAD_SIZES[0]))
    for src_cache, dst_cache in cache_pairs:
        ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(key_snapshot[src].cpu(),
                                   dst_key_caches[0][dst].cpu())
        torch.testing.assert_close(value_snapshot[src].cpu(),
                                   dst_value_caches[0][dst].cpu())
393
394


# NOTE(review): this fp8 round-trip test is commented out together with the
# "fp8" entry of KV_CACHE_DTYPE; presumably fp8 conversion is not available
# in this build. Confirm before re-enabling it.
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @torch.inference_mode()
# def test_fp8_e4m3_conversion(
#     num_heads: int,
#     head_size: int,
#     block_size: int,
#     num_blocks: int,
#     dtype: torch.dtype,
#     seed: int,
#     device: str,
# ) -> None:
#     seed_everything(seed)

#     low = -224.0
#     high = 224.0
#     shape = (num_blocks, num_heads, head_size, block_size)
#     cache = torch.empty(shape, dtype=dtype, device=device)
#     cache.uniform_(low, high)

#     cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
#     ops.convert_fp8(cache_fp8, cache)

#     converted_cache = torch.empty_like(cache)
#     ops.convert_fp8(converted_cache, cache_fp8)

#     torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)