test_cache.py 16.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

Woosuk Kwon's avatar
Woosuk Kwon committed
3
import random
4
from typing import List, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
5

6
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch

9
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
from vllm import _custom_ops as ops
11
from vllm.platforms import current_platform
Woosuk Kwon's avatar
Woosuk Kwon committed
12

Vladimir's avatar
Vladimir committed
13
# (source, destination) device kinds exercised by the swap_blocks test.
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
DTYPES = [torch.half, torch.bfloat16, torch.float]

# Arbitrary values for testing.
NUM_TOKENS = [42]
NUM_LAYERS = [1]
NUM_HEADS = [8]
HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing.
# Keep these moderate: e.g. [1024, 36000] will run out of memory.
NUM_BLOCKS = [1024, 10000]

# Arbitrary values for testing.
NUM_MAPPINGS = [256]
SEEDS = [0]

# One device when exactly one GPU is visible, otherwise two.
CUDA_DEVICES = [
    f"cuda:{index}"
    for index in range(1 if torch.cuda.device_count() == 1 else 2)
]

# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]
33
34
35
36
37
38
39
40
41
42


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ``ops.copy_blocks`` against an indexing-based reference copy.

    Every sampled source block is mapped to two distinct destination blocks.
    The kernel is run on the caches returned by ``kv_cache_factory`` while
    the same mapping is replayed with ``Tensor.copy_`` on clones of those
    caches; afterwards both sets of caches must be equal.

    The fp8 configuration is skipped when ``head_size`` is not a multiple
    of 16.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks. Destinations are drawn from the blocks that are not
    # sources, so num_mappings sources plus 2 * num_mappings destinations must
    # all fit in num_blocks.
    assert 3 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: List[Tuple[int, int]] = []
    for i, src in enumerate(src_blocks):
        block_mapping.append((src, dst_blocks[2 * i]))
        block_mapping.append((src, dst_blocks[2 * i + 1]))

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches; the clones receive the reference copy.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device=device).view(-1, 2)

    # opcheck is only enabled for the first head size.
    opcheck(torch.ops._C_cache_ops.copy_blocks,
            (key_caches, value_caches, block_mapping_tensor),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS,
            cond=(head_size == HEAD_SIZES[0]))
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Run the reference implementation.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        torch.testing.assert_close(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        torch.testing.assert_close(value_cache, cloned_value_cache)
111
112


113
114
115
116
117
118
119
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache`` with a per-token reference write.

    Random keys and values are scattered into a single-layer KV cache at
    randomly sampled slots. The reference writes each token into a copy of
    the cache by plain indexing. For fp8 caches both sides are converted to
    float16 with ``ops.convert_fp8`` and compared with loose tolerances;
    otherwise the caches are compared directly.
    """
    is_fp8 = kv_cache_dtype == "fp8"
    if is_fp8 and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Pick a distinct random slot for every token.
    total_slots = block_size * num_blocks
    chosen_slots = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(chosen_slots, dtype=torch.long)

    # Only the key and value slices of the packed qkv tensor are used.
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Build a one-layer KV cache and take that layer.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache = key_caches[0]
    value_cache = value_caches[0]

    # Snapshot the caches before the kernel runs; fp8 caches are converted
    # to float16 so the reference can be written with ordinary assignments.
    if is_fp8:
        expected_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(expected_key_cache, key_cache)
        expected_value_cache = torch.empty_like(value_cache,
                                                dtype=torch.float16)
        ops.convert_fp8(expected_value_cache, value_cache)
    else:
        expected_key_cache = key_cache.clone()
        expected_value_cache = value_cache.clone()

    # Using default kv_scale (the same tensor serves as both scales).
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Run the kernel under test; opcheck only for the first head size.
    opcheck(torch.ops._C_cache_ops.reshape_and_cache,
            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
             k_scale, v_scale),
            cond=(head_size == HEAD_SIZES[0]))
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, k_scale, v_scale)

    # Reference implementation: write token by token. The key is reshaped to
    # match the per-slot slice of the 5-D key cache.
    per_slot_key_shape = key_cache[0, :, :, 0, :].shape
    reshaped_key = key.reshape(num_tokens, *per_slot_key_shape)
    block_ids = torch.div(slot_mapping, block_size,
                          rounding_mode="floor").cpu().tolist()
    offsets = (slot_mapping % block_size).cpu().tolist()
    for token, (block_idx, block_offset) in enumerate(zip(block_ids,
                                                          offsets)):
        expected_key_cache[block_idx, :, :, block_offset, :] = (
            reshaped_key[token])
        expected_value_cache[block_idx, :, :, block_offset] = value[token]

    if not is_fp8:
        torch.testing.assert_close(key_cache, expected_key_cache)
        torch.testing.assert_close(value_cache, expected_value_cache)
        return

    # fp8: convert the kernel output to float16 and compare loosely.
    actual_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
    ops.convert_fp8(actual_key_cache, key_cache)
    actual_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
    ops.convert_fp8(actual_value_cache, value_cache)
    torch.testing.assert_close(actual_key_cache,
                               expected_key_cache,
                               atol=0.001,
                               rtol=0.1)
    torch.testing.assert_close(actual_value_cache,
                               expected_value_cache,
                               atol=0.001,
                               rtol=0.1)
Vladimir's avatar
Vladimir committed
205
206


207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache_flash`` with a per-token reference.

    Random keys and values are written into a single-layer cache obtained
    from ``kv_cache_factory_flashinfer`` at randomly sampled slots. The
    reference writes each token to ``cache[block_idx, block_offset]`` on a
    copy of the cache. For fp8 caches the scales are derived from the
    key/value maxima, and both the reference and the kernel output are
    converted to float16 with ``ops.convert_fp8`` using identical arguments
    before a loose comparison; otherwise the caches are compared directly.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0].contiguous(
    ), value_caches[0].contiguous()
    # Drop the factory's lists; only the contiguous per-layer caches are used.
    del key_caches
    del value_caches

    k_scale = (key.amax() / 256.0).to(torch.float32)
    v_scale = (value.amax() / 256.0).to(torch.float32)

    # Clone the KV caches. The scale is passed to convert_fp8 as a Python
    # float and the dtype by keyword, exactly as for the kernel output below,
    # so that reference and result are dequantised the same way.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache,
                        key_cache,
                        k_scale.item(),
                        kv_dtype=kv_cache_dtype)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache,
                        value_cache,
                        v_scale.item(),
                        kv_dtype=kv_cache_dtype)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel (opcheck only for the first head
    # size).
    opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
             k_scale, v_scale),
            cond=(head_size == HEAD_SIZES[0]))
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache,
                        key_cache,
                        k_scale.item(),
                        kv_dtype=kv_cache_dtype)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache,
                        value_cache,
                        v_scale.item(),
                        kv_dtype=kv_cache_dtype)

    # Run the reference implementation.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(result_key_cache,
                                   cloned_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   cloned_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)
320
321


Vladimir's avatar
Vladimir committed
322
323
324
325
326
327
328
329
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``ops.swap_blocks`` copies source blocks to destinations.

    Two single-layer caches are created, one on the source device and one
    on the destination device, as selected by ``direction``. After swapping,
    each destination block must equal the pre-swap content of its source
    block, for both the key cache and the value cache.

    fp8 caches are skipped when either side is the CPU or when ``head_size``
    is not a multiple of 16.
    """
    if kv_cache_dtype == "fp8" and ("cpu" in direction or head_size % 16):
        pytest.skip()

    current_platform.seed_everything(seed)

    # Resolve each side of the direction to a concrete device string.
    src_kind, dst_kind = direction[0], direction[1]
    src_device = device if src_kind == "cuda" else 'cpu'
    dst_device = device if dst_kind == "cuda" else 'cpu'

    src_blocks = random.sample(range(num_blocks), num_mappings)
    if src_device != dst_device:
        dst_candidates = range(num_blocks)
    else:
        # For the same device, mapping must not overlap.
        dst_candidates = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(dst_candidates, num_mappings)

    # The mapping tensor is always built on the CPU.
    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device="cpu").view(-1, 2)

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)

    # Create the KV caches on the second device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)

    src_key_cache = src_key_caches[0]
    src_value_cache = src_value_caches[0]
    dst_key_cache = dst_key_caches[0]
    dst_value_cache = dst_value_caches[0]

    # Keep the pre-swap source contents for the final comparison.
    expected_key_cache = src_key_cache.clone()
    expected_value_cache = src_value_cache.clone()

    # Call the swap_blocks kernel (opcheck only for the first head size).
    do_opcheck = (head_size == HEAD_SIZES[0])
    opcheck(torch.ops._C_cache_ops.swap_blocks,
            (src_key_cache, dst_key_cache, block_mapping_tensor),
            cond=do_opcheck)
    opcheck(torch.ops._C_cache_ops.swap_blocks,
            (src_value_cache, dst_value_cache, block_mapping_tensor),
            cond=do_opcheck)

    ops.swap_blocks(src_key_cache, dst_key_cache, block_mapping_tensor)
    ops.swap_blocks(src_value_cache, dst_value_cache, block_mapping_tensor)

    # Compare on the CPU so both sides live on the same device.
    for src, dst in block_mapping:
        torch.testing.assert_close(expected_key_cache[src].cpu(),
                                   dst_key_cache[dst].cpu())
        torch.testing.assert_close(expected_value_cache[src].cpu(),
                                   dst_value_cache[dst].cpu())
401
402
403
404
405
406
407
408
409
410


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random tensor through ``ops.convert_fp8``.

    A tensor of shape ``(num_blocks, num_heads, head_size, block_size)`` is
    filled uniformly from [-224, 224], converted into a uint8 buffer,
    converted back to the original dtype, and compared with the input using
    loose tolerances.
    """
    current_platform.seed_everything(seed)

    # NOTE(review): the +/-224 bounds presumably keep samples inside the
    # range the fp8 format can represent -- confirm against the kernel.
    original = torch.empty((num_blocks, num_heads, head_size, block_size),
                           dtype=dtype,
                           device=device).uniform_(-224.0, 224.0)

    # Convert into a uint8 buffer of the same shape.
    quantized = torch.empty_like(original, dtype=torch.uint8)
    ops.convert_fp8(quantized, original)

    # Convert back into the original dtype.
    restored = torch.empty_like(original)
    ops.convert_fp8(restored, quantized)

    torch.testing.assert_close(original, restored, atol=0.001, rtol=0.1)