test_cache.py 15.1 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
import random
2
from typing import List, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3

4
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch

7
from vllm import _custom_ops as ops
Woosuk Kwon's avatar
Woosuk Kwon committed
8

Vladimir's avatar
Vladimir committed
9
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

# Arbitrary values for testing. test_copy_blocks draws one source block plus
# two distinct destination blocks per mapping, so keep
# 3 * max(NUM_MAPPINGS) <= min(NUM_BLOCKS).
NUM_MAPPINGS = [256]
SEEDS = [0]
# Only "cuda:0" when exactly one GPU is visible, otherwise "cuda:0" and
# "cuda:1".
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# fp8 KV-cache cases are currently disabled: only the "auto" cache dtype is
# exercised. The fp8 branches in the tests below are kept so the fp8 cases
# can be re-enabled by adding "fp8" to this list.
KV_CACHE_DTYPE = ["auto"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Compare ``ops.copy_blocks`` with a block-by-block reference copy.

    Each of ``num_mappings`` randomly chosen source blocks is copied to two
    distinct destination blocks in every layer's key and value cache. The
    same copies are replayed with ``Tensor.copy_`` on clones of the caches,
    and the kernel's result must match the clones.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip("fp8 cases only run for head_size divisible by 16")

    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks. Destinations are sampled only from the blocks left
    # after choosing the sources, so num_mappings sources plus
    # 2 * num_mappings destinations must fit: 3 * num_mappings <= num_blocks.
    # (A weaker bound would let random.sample raise ValueError instead.)
    assert 3 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: List[Tuple[int, int]] = []
    for i, src in enumerate(src_blocks):
        block_mapping.append((src, dst_blocks[2 * i]))
        block_mapping.append((src, dst_blocks[2 * i + 1]))

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches before the kernel modifies them in place.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel with an (N, 2) tensor of (src, dst) pairs.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device=device).view(-1, 2)
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Run the reference implementation.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache`` with a Python scatter reference.

    ``num_tokens`` random key/value vectors are written into distinct,
    randomly chosen cache slots. Slot ``s`` is located in block
    ``s // block_size`` at offset ``s % block_size``. For fp8 caches, both
    the reference and the result are converted to float16 with
    ``ops.convert_fp8`` and compared with a loose tolerance.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip("fp8 cases only run for head_size divisible by 16")

    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Create a random slot mapping; sampling without replacement gives every
    # token its own slot.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches (a single layer).
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches before the kernel writes into them.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the reshape_and_cache kernel.
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)

    # Run the reference implementation. Each token's key is reshaped to the
    # shape of one key_cache[block, :, :, offset, :] slice.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    for i, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        assert torch.allclose(result_key_cache,
                              cloned_key_cache,
                              atol=0.001,
                              rtol=0.1)
        assert torch.allclose(result_value_cache,
                              cloned_value_cache,
                              atol=0.001,
                              rtol=0.1)
    else:
        assert torch.allclose(key_cache, cloned_key_cache)
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache_flash`` with a scatter reference.

    Uses caches from ``kv_cache_factory_flashinfer``. The reference writes
    each token's key/value into ``cache[block, offset, :, :]``, where slot
    ``s`` maps to ``block = s // block_size`` and ``offset = s % block_size``.
    For fp8 caches, both sides are converted to float16 with
    ``ops.convert_fp8`` and compared with a loose tolerance.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    # Guarded like the sibling tests so seeding does not fail without CUDA.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Create a random slot mapping; sampling without replacement gives every
    # token its own slot.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches (a single layer).
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache = key_caches[0].contiguous()
    value_cache = value_caches[0].contiguous()
    # Only layer 0 is used below.
    del key_caches, value_caches

    # Clone the KV caches before the kernel writes into them.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the reshape_and_cache kernel.
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)

    # Run the reference implementation.
    for i, slot in enumerate(slot_mapping_lst):
        block_idx, block_offset = divmod(slot, block_size)
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    if kv_cache_dtype == "fp8":
        assert torch.allclose(result_key_cache,
                              cloned_key_cache,
                              atol=0.001,
                              rtol=0.1)
        assert torch.allclose(result_value_cache,
                              cloned_value_cache,
                              atol=0.001,
                              rtol=0.1)
    else:
        assert torch.allclose(key_cache, cloned_key_cache)
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``ops.swap_blocks`` copies mapped blocks between caches.

    ``direction`` gives the (source, destination) device kinds; "cuda" is
    resolved to the parametrized ``device``. After the call, every mapped
    destination block must equal its source block as it was before the call.
    """
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip("fp8 cases are not run with CPU caches")
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip("fp8 cases only run for head_size divisible by 16")

    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    src_device = device if direction[0] == "cuda" else 'cpu'
    dst_device = device if direction[1] == "cuda" else 'cpu'

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    # The mapping tensor is built on the CPU for every direction.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device="cpu").view(-1, 2)

    # Create the KV caches on the source device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)

    # Create the KV caches on the destination device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)

    # Snapshot the source caches so the check uses their pre-call contents.
    src_key_cache_clone = src_key_caches[0].clone()
    src_value_cache_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    ops.swap_blocks(src_key_caches[0], dst_key_caches[0],
                    block_mapping_tensor)
    ops.swap_blocks(src_value_caches[0], dst_value_caches[0],
                    block_mapping_tensor)

    for src, dst in block_mapping:
        assert torch.allclose(src_key_cache_clone[src].cpu(),
                              dst_key_caches[0][dst].cpu())
        assert torch.allclose(src_value_cache_clone[src].cpu(),
                              dst_value_caches[0][dst].cpu())


# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @torch.inference_mode()
# def test_fp8_e4m3_conversion(
#     num_heads: int,
#     head_size: int,
#     block_size: int,
#     num_blocks: int,
#     dtype: torch.dtype,
#     seed: int,
#     device: str,
# ) -> None:
#     random.seed(seed)
#     torch.random.manual_seed(seed)
#     torch.cuda.manual_seed(seed)

#     low = -224.0
#     high = 224.0
#     shape = (num_blocks, num_heads, head_size, block_size)
#     cache = torch.empty(shape, dtype=dtype, device=device)
#     cache.uniform_(low, high)

#     cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
#     ops.convert_fp8(cache_fp8, cache)

#     converted_cache = torch.empty_like(cache)
#     ops.convert_fp8(converted_cache, cache_fp8)

#     assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)