import random
from typing import Tuple

import pytest
import torch

from vllm import _custom_ops as ops
from vllm._C import cache_ops
from vllm.utils import is_hip

# (source, destination) device kinds exercised by test_swap_blocks.
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# ["cuda:0"] when exactly one GPU is visible, otherwise ["cuda:0", "cuda:1"]
# so that a non-default device index is also exercised.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
# The "fp8" KV-cache dtype is only parametrized on non-HIP platforms.
KV_CACHE_DTYPE = ["auto", "fp8"] if not is_hip() else ["auto"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ``ops.copy_blocks`` against a block-by-block reference copy.

    ``num_mappings`` random source blocks are each mapped to two destination
    blocks. All source and destination blocks are distinct, so the expected
    result does not depend on the order in which copies are applied. The
    kernel runs on every layer's key and value cache, and each cache is
    compared with a cloned cache updated by plain tensor copies.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks. Destinations are drawn from the blocks that were
    # NOT chosen as sources, so num_mappings sources plus 2 * num_mappings
    # destinations must fit: 3 * num_mappings <= num_blocks. (A weaker guard
    # lets random.sample fail with "Sample larger than population" instead.)
    assert 3 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping = {
        src: [dst_blocks[2 * i], dst_blocks[2 * i + 1]]
        for i, src in enumerate(src_blocks)
    }

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Run the reference implementation.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst].copy_(cloned_key_cache[src])
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.reshape_and_cache`` against a per-token reference scatter.

    ``num_tokens`` random key/value vectors are written into distinct,
    randomly chosen cache slots (slot = block_index * block_size + offset).
    For the "fp8" cache dtype, the kernel output and the reference are both
    decoded to float16 with ``ops.convert_fp8`` and compared with a loose
    tolerance. Otherwise the caches are compared with default tolerances.
    """
    # NOTE(review): KV_CACHE_DTYPE only contains "fp8" on non-HIP platforms,
    # and this guard skips "fp8" on non-HIP platforms, so the fp8 branches
    # below are currently never exercised.
    if not is_hip() and kv_cache_dtype == "fp8":
        pytest.skip()  # This test is not tuned for e5m2 cuda precision
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Create a random slot mapping. random.sample draws without replacement,
    # so no two tokens target the same slot.
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches (a single layer).
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches. fp8 caches are decoded to float16 so the reference
    # can be written with ordinary floating-point assignments.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(key_cache, cloned_key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(value_cache, cloned_value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    kv_scale = 1.0

    # Call the reshape_and_cache kernel.
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, kv_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(key_cache, result_key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(value_cache, result_value_cache)

    # Run the reference implementation. The key cache is indexed 5-D as
    # [block, :, :, offset, :], so each key is reshaped to the shape of one
    # (block, offset) slice of it. The value cache is indexed
    # [block, :, :, offset].
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets = block_offsets.cpu().tolist()
    for i, (block_idx, block_offset) in enumerate(
            zip(block_indices, block_offsets)):
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        assert torch.allclose(result_key_cache,
                              cloned_key_cache,
                              atol=0.001,
                              rtol=0.1)
        assert torch.allclose(result_value_cache,
                              cloned_value_cache,
                              atol=0.001,
                              rtol=0.1)
    else:
        assert torch.allclose(key_cache, cloned_key_cache)
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``cache_ops.reshape_and_cache_flash`` against a reference scatter.

    The flash-layout caches are indexed as [block, offset, head, dim]. Each
    of ``num_tokens`` random key/value vectors is written to a distinct
    random slot, and the result is compared with a cloned cache updated by
    plain indexing. The "fp8" cache dtype is skipped.
    """
    if kv_cache_dtype == "fp8":
        pytest.skip()
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Create a random slot mapping (distinct slots). It must live on the same
    # device as key/value and the caches. A bare 'cuda' resolves to the
    # current GPU, which differs from `device` when testing e.g. "cuda:1".
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device)

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches (a single layer) on the device under test.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches.
    cloned_key_cache = key_cache.clone()
    cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                      slot_mapping, kv_cache_dtype)

    # Run the reference implementation.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode='floor')
    block_indices = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets = block_offsets.cpu().tolist()
    for i, (block_idx, block_offset) in enumerate(
            zip(block_indices, block_offsets)):
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    assert torch.allclose(key_cache, cloned_key_cache)
    assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.swap_blocks`` copies mapped blocks between caches.

    ``direction`` picks the source and destination device kinds. "cuda" is
    replaced by the parametrized ``device``. A random one-to-one block
    mapping is applied to the key and value caches, and every destination
    block must equal a pre-call clone of its source block.
    """
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    if not is_hip() and kv_cache_dtype == "fp8":
        pytest.skip()  # This test is not tuned for e5m2 cuda precision
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    src_device = device if direction[0] == "cuda" else 'cpu'
    dst_device = device if direction[1] == "cuda" else 'cpu'

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = dict(zip(src_blocks, dst_blocks))

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)

    # Create the KV caches on the second device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)

    # Snapshot the sources before the call. When both sides are on the same
    # device, the comparison must not depend on the source being untouched.
    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    ops.swap_blocks(src_key_caches[0], dst_key_caches[0], block_mapping)
    ops.swap_blocks(src_value_caches[0], dst_value_caches[0], block_mapping)

    # Compare on the CPU so that tensors from different devices can be checked.
    for src, dst in block_mapping.items():
        assert torch.allclose(src_key_caches_clone[src].cpu(),
                              dst_key_caches[0][dst].cpu())
        assert torch.allclose(src_value_caches_clone[src].cpu(),
                              dst_value_caches[0][dst].cpu())


@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random tensor through ``ops.convert_fp8`` and back.

    The tensor is encoded into a uint8 buffer, decoded back to ``dtype``,
    and the result must match the original within atol=0.001 / rtol=0.1.
    The test only runs on HIP.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # NOTE(review): +/-224 is presumably chosen to stay inside the
    # representable range of the e4m3 format used on HIP; confirm.
    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    # Encode into a uint8 buffer of the same shape...
    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache, cache_fp8)

    # ...then decode back to the original dtype.
    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(cache_fp8, converted_cache)

    assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)