import random
from typing import List, Tuple

import pytest
import torch

from vllm import _custom_ops as ops

# (src, dst) device pairs exercised by test_swap_blocks.
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
# Model dtypes the keys/values are generated in.
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing.
# Don't make them too large, e.g. [1024, 36000] will OOM.
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# Only cuda:0 when exactly one GPU is visible; otherwise cuda:0 and cuda:1.
# NOTE(review): with no visible GPU this still lists two CUDA devices.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ``ops.copy_blocks`` against a reference built from ``copy_``.

    ``num_mappings`` randomly chosen source blocks are each copied to two
    destination blocks, in every layer's key and value cache. Destination
    blocks are distinct from each other and from every source block, so the
    order in which the reference applies the copies does not matter.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks. num_mappings blocks are taken as sources and
    # 2 * num_mappings destinations are drawn from the remaining blocks, so
    # random.sample needs num_blocks >= 3 * num_mappings.
    assert 3 * num_mappings <= num_blocks, (
        f"need at least {3 * num_mappings} blocks for {num_mappings} "
        f"mappings, got {num_blocks}")
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: List[Tuple[int, int]] = []
    for i, src in enumerate(src_blocks):
        block_mapping.append((src, dst_blocks[2 * i]))
        block_mapping.append((src, dst_blocks[2 * i + 1]))

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches before the kernel mutates them in place.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel with an (N, 2) int64 [src, dst] tensor.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device=device).view(-1, 2)
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Run the reference implementation.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.reshape_and_cache`` against a Python scatter reference.

    Each token's key/value is written into the cache slot chosen by
    ``slot_mapping`` (block = slot // block_size, offset = slot % block_size).
    The reference indexes the key cache as ``[block, :, :, offset, :]`` and
    the value cache as ``[block, :, :, offset]``. For ``fp8`` caches both the
    kernel output and the reference are compared in float16 with loose
    tolerances.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Create a random slot mapping (distinct slots, sampled without
    # replacement).
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches; fp8 caches are converted to float16 so the
    # reference can be written with ordinary tensor assignment.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the reshape_and_cache kernel.
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)

    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        assert torch.allclose(result_key_cache,
                              cloned_key_cache,
                              atol=0.001,
                              rtol=0.1)
        assert torch.allclose(result_value_cache,
                              cloned_value_cache,
                              atol=0.001,
                              rtol=0.1)
    else:
        assert torch.allclose(key_cache, cloned_key_cache)
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.reshape_and_cache_flash`` against a Python reference.

    Same scatter as ``test_reshape_and_cache``, but the reference indexes
    both caches as ``[block, offset, :, :]`` and stores ``key[i]`` and
    ``value[i]`` without reshaping. For ``fp8`` caches the comparison is
    done in float16 with loose tolerances.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Create a random slot mapping (distinct slots, sampled without
    # replacement).
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    # Keep contiguous copies of the single layer and drop the lists.
    key_cache = key_caches[0].contiguous()
    value_cache = value_caches[0].contiguous()
    del key_caches
    del value_caches

    # Clone the KV caches; fp8 caches are converted to float16 so the
    # reference can be written with ordinary tensor assignment.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the reshape_and_cache kernel.
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)

    # Run the reference implementation.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    if kv_cache_dtype == "fp8":
        assert torch.allclose(result_key_cache,
                              cloned_key_cache,
                              atol=0.001,
                              rtol=0.1)
        assert torch.allclose(result_value_cache,
                              cloned_value_cache,
                              atol=0.001,
                              rtol=0.1)
    else:
        assert torch.allclose(key_cache, cloned_key_cache)
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.swap_blocks`` between caches on the devices in ``direction``.

    ``direction`` is a ``(src, dst)`` pair of ``'cuda'``/``'cpu'``; ``'cuda'``
    resolves to the parametrized ``device``. After the kernel runs, every
    destination block named in the mapping must equal the corresponding
    source block as it was before the call. The mapping tensor is always
    built on the CPU. fp8 caches are skipped when either side is the CPU.
    """
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip("fp8 KV cache is not tested for swaps involving the CPU")

    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    src_device = device if direction[0] == "cuda" else 'cpu'
    dst_device = device if direction[1] == "cuda" else 'cpu'

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, keep destination blocks disjoint from the sources.
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device="cpu").view(-1, 2)

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)

    # Create the KV caches on the second device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)

    # Snapshot the source caches before the kernel runs; expected values are
    # read from this snapshot.
    src_key_cache_clone = src_key_caches[0].clone()
    src_value_cache_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    ops.swap_blocks(src_key_caches[0], dst_key_caches[0],
                    block_mapping_tensor)
    ops.swap_blocks(src_value_caches[0], dst_value_caches[0],
                    block_mapping_tensor)

    # Compare on the CPU so caches on different devices can be checked.
    for src, dst in block_mapping:
        assert torch.allclose(src_key_cache_clone[src].cpu(),
                              dst_key_caches[0][dst].cpu())
        assert torch.allclose(src_value_cache_clone[src].cpu(),
                              dst_value_caches[0][dst].cpu())


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random cache through ``ops.convert_fp8`` and back.

    The cache is filled uniformly from [-224, 224], converted into a uint8
    tensor, converted back to ``dtype``, and must match the original within
    ``atol=0.001, rtol=0.1``.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # NOTE(review): presumably chosen to stay inside the fp8 e4m3 range.
    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache_fp8, cache)

    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(converted_cache, cache_fp8)

    assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)