"vllm/vscode:/vscode.git/clone" did not exist on "8452946c06a3b8a76233d2b390d886a5a8c78182"
test_cache.py 13.6 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
import random
2
from typing import Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3

4
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch

7
from vllm import _custom_ops as ops
Woosuk Kwon's avatar
Woosuk Kwon committed
8

Vladimir's avatar
Vladimir committed
9
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

# Arbitrary values for testing. test_copy_blocks samples num_mappings source
# blocks plus 2 * num_mappings distinct destination blocks, so every pairing
# must satisfy 3 * NUM_MAPPINGS <= NUM_BLOCKS.
NUM_MAPPINGS = [256]
SEEDS = [0]
# "cuda:0", plus "cuda:1" unless exactly one device is visible (note that with
# zero visible devices this still yields two entries).
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# Only the "auto" KV-cache dtype is exercised; fp8 KV-cache coverage is
# disabled here. Add "fp8" to this list to re-enable the fp8 code paths.
KV_CACHE_DTYPE = ["auto"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Compare ``ops.copy_blocks`` with a pure-PyTorch reference.

    ``num_mappings`` random source blocks are each copied to two distinct
    destination blocks in every layer's key and value cache. The caches the
    kernel produced must match clones of the original caches to which the
    same copies were applied with ``Tensor.copy_``.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks. Sources and destinations are drawn without
    # replacement from disjoint pools, so 3 * num_mappings blocks are needed;
    # a looser bound would let random.sample raise ValueError below instead
    # of failing on this precondition.
    assert 3 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping = []
    for i, src in enumerate(src_blocks):
        block_mapping.append((src, dst_blocks[2 * i]))
        block_mapping.append((src, dst_blocks[2 * i + 1]))

    # Make `device` the default only for the duration of this test; a bare
    # torch.set_default_device() call would persist into every later test.
    with torch.device(device):
        # Create the KV caches.
        key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                    num_layers, num_heads,
                                                    head_size, kv_cache_dtype,
                                                    dtype, seed, device)

        # Clone the KV caches.
        cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
        cloned_value_caches = [
            value_cache.clone() for value_cache in value_caches
        ]

        # Call the copy blocks kernel.
        block_mapping_tensor = torch.tensor(block_mapping,
                                            dtype=torch.int64,
                                            device=device).view(-1, 2)
        ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

        # Run the reference implementation. No destination is also a source
        # and every destination is distinct, so the order in which the copies
        # are applied does not affect the result.
        for src, dst in block_mapping:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst].copy_(cloned_key_cache[src])
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst].copy_(cloned_value_cache[src])

        # Compare the results.
        for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
            assert torch.allclose(key_cache, cloned_key_cache)
        for value_cache, cloned_value_cache in zip(value_caches,
                                                   cloned_value_caches):
            assert torch.allclose(value_cache, cloned_value_cache)


106
107
108
109
110
111
112
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache`` with a pure-PyTorch reference.

    ``num_tokens`` random key/value rows are written into distinct random
    cache slots. Slot ``s`` lives in block ``s // block_size`` at offset
    ``s % block_size``. For the "fp8" cache dtype both sides are converted to
    float16 with ``ops.convert_fp8`` and compared with a loose tolerance;
    otherwise the caches are compared directly.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Make `device` the default only for the duration of this test (the
    # slot mapping and qkv tensors below rely on it); a bare
    # torch.set_default_device() call would persist into every later test.
    with torch.device(device):
        # Create a random slot mapping.
        num_slots = block_size * num_blocks
        slot_mapping = random.sample(range(num_slots), num_tokens)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)

        qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
        _, key, value = qkv.unbind(dim=1)

        # Create the KV caches.
        key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                    num_heads, head_size,
                                                    kv_cache_dtype, dtype,
                                                    seed, device)
        key_cache, value_cache = key_caches[0], value_caches[0]

        # Clone the KV caches; fp8 caches are cloned as float16 conversions.
        if kv_cache_dtype == "fp8":
            cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
            ops.convert_fp8(cloned_key_cache, key_cache)
            cloned_value_cache = torch.empty_like(value_cache,
                                                  dtype=torch.float16)
            ops.convert_fp8(cloned_value_cache, value_cache)
        else:
            cloned_key_cache = key_cache.clone()
            cloned_value_cache = value_cache.clone()

        # Using default kv_scale
        kv_scale = 1.0

        # Call the reshape_and_cache kernel.
        ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                              kv_cache_dtype, kv_scale)

        if kv_cache_dtype == "fp8":
            result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
            ops.convert_fp8(result_key_cache, key_cache)
            result_value_cache = torch.empty_like(value_cache,
                                                  dtype=torch.float16)
            ops.convert_fp8(result_value_cache, value_cache)

        # Run the reference implementation. The key cache is indexed as
        # [block, :, :, offset, :] and the value cache as
        # [block, :, :, offset], so each key row is reshaped to one
        # [:, :, offset, :] slice before it is written.
        reshaped_key = key.reshape(num_tokens,
                                   *key_cache[0, :, :, 0, :].shape)
        block_indices = torch.div(slot_mapping,
                                  block_size,
                                  rounding_mode="floor").cpu().tolist()
        block_offsets = (slot_mapping % block_size).cpu().tolist()
        for i, (block_idx, block_offset) in enumerate(
                zip(block_indices, block_offsets)):
            cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
            cloned_value_cache[block_idx, :, :, block_offset] = value[i]

        if kv_cache_dtype == "fp8":
            assert torch.allclose(result_key_cache,
                                  cloned_key_cache,
                                  atol=0.001,
                                  rtol=0.1)
            assert torch.allclose(result_value_cache,
                                  cloned_value_cache,
                                  atol=0.001,
                                  rtol=0.1)
        else:
            assert torch.allclose(key_cache, cloned_key_cache)
            assert torch.allclose(value_cache, cloned_value_cache)
Vladimir's avatar
Vladimir committed
195
196


197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache_flash`` with a pure-PyTorch reference.

    ``num_tokens`` random key/value rows are written into distinct random
    cache slots. Slot ``s`` lives in block ``s // block_size`` at offset
    ``s % block_size``. The kernel's caches must match clones updated
    slot-by-slot. The "fp8" cache dtype is skipped.
    """
    if kv_cache_dtype == "fp8":
        pytest.skip()
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Make `device` the default only for the duration of this test; a bare
    # torch.set_default_device() call would persist into every later test.
    with torch.device(device):
        # Create a random slot mapping.
        num_slots = block_size * num_blocks
        slot_mapping = random.sample(range(num_slots), num_tokens)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=device)

        qkv = torch.randn(num_tokens,
                          3,
                          num_heads,
                          head_size,
                          dtype=dtype,
                          device=device)
        _, key, value = qkv.unbind(dim=1)

        # Create the KV caches.
        # NOTE(review): unlike the kv_cache_factory calls in this file, no
        # seed is passed here; confirm the fixture's default seeding is
        # intended.
        key_caches, value_caches = kv_cache_factory_flashinfer(
            num_blocks,
            block_size,
            1,
            num_heads,
            head_size,
            kv_cache_dtype,
            dtype,
            device=device,
        )
        key_cache, value_cache = key_caches[0], value_caches[0]

        # Clone the KV caches.
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

        # Call the reshape_and_cache_flash kernel.
        ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                    slot_mapping, kv_cache_dtype)

        # Run the reference implementation. Both caches are indexed as
        # [block, offset, :, :] here.
        block_indices = torch.div(slot_mapping,
                                  block_size,
                                  rounding_mode='floor').cpu().tolist()
        block_offsets = (slot_mapping % block_size).cpu().tolist()
        for i, (block_idx, block_offset) in enumerate(
                zip(block_indices, block_offsets)):
            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
            cloned_value_cache[block_idx, block_offset, :, :] = value[i]

        assert torch.allclose(key_cache, cloned_key_cache)
        assert torch.allclose(value_cache, cloned_value_cache)


Vladimir's avatar
Vladimir committed
275
276
277
278
279
280
281
282
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``ops.swap_blocks`` moves whole blocks between caches.

    ``direction`` names the (source, destination) side. A "cuda" side uses
    ``device``; any other side uses the CPU. After swapping
    ``num_mappings`` random blocks in the key and value caches, every
    destination block must equal a snapshot of its source block taken
    before the kernel ran. fp8 caches are skipped whenever the CPU is
    involved.
    """
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()

    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    src_side, dst_side = direction
    src_device = 'cpu' if src_side != "cuda" else device
    dst_device = 'cpu' if dst_side != "cuda" else device

    src_blocks = random.sample(range(num_blocks), num_mappings)
    if src_device != dst_device:
        # Separate caches: any destination block is allowed.
        dst_pool = range(num_blocks)
    else:
        # Same device: destinations must not overlap the sources.
        dst_pool = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(dst_pool, num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    mapping = torch.tensor(block_mapping, dtype=torch.int64, device="cpu")
    block_mapping_tensor = mapping.view(-1, 2)

    # One single-layer cache set per side.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)
    src_key_cache, src_value_cache = src_key_caches[0], src_value_caches[0]
    dst_key_cache, dst_value_cache = dst_key_caches[0], dst_value_caches[0]

    # Snapshot the sources before the kernel touches anything.
    expected_keys = src_key_cache.clone()
    expected_values = src_value_cache.clone()

    ops.swap_blocks(src_key_cache, dst_key_cache, block_mapping_tensor)
    ops.swap_blocks(src_value_cache, dst_value_cache, block_mapping_tensor)

    for src, dst in block_mapping:
        assert torch.allclose(expected_keys[src].cpu(),
                              dst_key_cache[dst].cpu())
        assert torch.allclose(expected_values[src].cpu(),
                              dst_value_cache[dst].cpu())
346
347


zhuwenwen's avatar
zhuwenwen committed
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random cache through ``ops.convert_fp8`` and back.

    Values are drawn uniformly from [-224, 224]. They are converted into a
    uint8 buffer, converted back to ``dtype``, and must stay within
    ``atol=0.001, rtol=0.1`` of the input.

    NOTE(review): this test is not parametrized by KV_CACHE_DTYPE, so it
    runs even when fp8 KV-cache coverage is disabled; confirm that is
    intended.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    bound = 224.0
    original = torch.empty((num_blocks, num_heads, head_size, block_size),
                           dtype=dtype,
                           device=device)
    original.uniform_(-bound, bound)

    quantized = torch.empty_like(original, dtype=torch.uint8)
    ops.convert_fp8(quantized, original)

    round_tripped = torch.empty_like(original)
    ops.convert_fp8(round_tripped, quantized)

    assert torch.allclose(original, round_tripped, atol=0.001, rtol=0.1)