test_cache.py 10.9 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
import random
2
from typing import Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3

4
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch

7
from vllm import _custom_ops as ops
8
from vllm.utils import is_hip
Woosuk Kwon's avatar
Woosuk Kwon committed
9

Vladimir's avatar
Vladimir committed
10
# (source device type, destination device type) pairs used by
# test_swap_blocks.
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# A single entry when exactly one CUDA device is reported, otherwise the
# first two device indices (note: a reported count of 0 also yields two).
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
KV_CACHE_DTYPE = ["auto", "fp8"]
28
29
30
31
32
33
34
35
36
37


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Compare ``ops.copy_blocks`` with a block-by-block reference copy.

    Random source blocks are each mapped to two distinct destination blocks.
    The kernel runs on the caches returned by ``kv_cache_factory`` while the
    same copies are replayed with ``Tensor.copy_`` on clones taken before the
    kernel call; both results must match for every layer.

    ``kv_cache_factory`` is injected by pytest (presumably a fixture defined
    outside this file -- confirm in conftest).
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks.
    assert 2 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    # Destinations are drawn only from blocks that are not sources, so no
    # block is both read and written.
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping = {
        src: [dst_blocks[2 * i], dst_blocks[2 * i + 1]]
        for i, src in enumerate(src_blocks)
    }

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Run the reference implementation.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst].copy_(cloned_key_cache[src])
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


101
102
103
104
105
106
107
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Compare ``ops.reshape_and_cache`` with a per-token reference write.

    Random keys and values are written into randomly chosen distinct slots
    of a single-layer KV cache. The reference writes the same tokens into a
    copy of the cache taken before the kernel call. For ``"fp8"`` caches
    both sides are converted to float16 with ``ops.convert_fp8`` and
    compared with loose tolerances; otherwise the caches are compared
    directly.

    ``kv_cache_factory`` is injected by pytest (presumably a fixture defined
    outside this file -- confirm in conftest).
    """
    if not is_hip() and kv_cache_dtype == "fp8":
        pytest.skip()  # This test is not tuned for e5m2 cuda precision
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        # The reference copy is held in float16, converted from the fp8 cache.
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(key_cache, cloned_key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(value_cache, cloned_value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    kv_scale = 1.0

    # Call the reshape_and_cache kernel.
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, kv_scale)

    if kv_cache_dtype == "fp8":
        # Convert the kernel output to float16 so it can be compared with
        # the float16 reference copy.
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(key_cache, result_key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(value_cache, result_value_cache)

    # Run the reference implementation.
    # Reshape each token's key to the shape of one slot of the key cache.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    # A slot index decomposes into (block index, offset within the block).
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets = block_offsets.cpu().tolist()
    for i, (block_idx, block_offset) in enumerate(
            zip(block_indices, block_offsets)):
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        assert torch.allclose(result_key_cache,
                              cloned_key_cache,
                              atol=0.001,
                              rtol=0.1)
        assert torch.allclose(result_value_cache,
                              cloned_value_cache,
                              atol=0.001,
                              rtol=0.1)
    else:
        assert torch.allclose(key_cache, cloned_key_cache)
        assert torch.allclose(value_cache, cloned_value_cache)
Vladimir's avatar
Vladimir committed
192
193
194
195
196
197
198
199
200
201


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``ops.swap_blocks`` copies mapped blocks between caches.

    ``direction`` is a ``(source, destination)`` pair of device types; a
    ``"cuda"`` entry resolves to ``device`` and anything else to ``'cpu'``.
    After the kernel runs, every destination block must match the snapshot
    of its source block taken before the call.

    ``kv_cache_factory`` is injected by pytest (presumably a fixture defined
    outside this file -- confirm in conftest).
    """
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    if not is_hip() and kv_cache_dtype == "fp8":
        pytest.skip()  # This test is not tuned for e5m2 cuda precision
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    src_device = device if direction[0] == "cuda" else 'cpu'
    dst_device = device if direction[1] == "cuda" else 'cpu'

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = dict(zip(src_blocks, dst_blocks))

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)

    # Create the KV caches on the second device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)

    # Snapshot the source caches before the kernel runs.
    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    ops.swap_blocks(src_key_caches[0], dst_key_caches[0], block_mapping)
    ops.swap_blocks(src_value_caches[0], dst_value_caches[0], block_mapping)

    # Both sides are moved to the CPU so the comparison works for every
    # direction.
    for src, dst in block_mapping.items():
        assert torch.allclose(src_key_caches_clone[src].cpu(),
                              dst_key_caches[0][dst].cpu())
        assert torch.allclose(src_value_caches_clone[src].cpu(),
                              dst_value_caches[0][dst].cpu())
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292


@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random cache through ``ops.convert_fp8``.

    A tensor of shape ``(num_blocks, num_heads, head_size, block_size)`` is
    filled uniformly from [-224, 224], converted into a ``uint8`` buffer and
    converted back into a tensor of the original dtype. The result must
    match the input within ``atol=0.001`` / ``rtol=0.1``.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    # The fp8 values are held in a uint8 tensor of the same shape.
    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache, cache_fp8)

    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(cache_fp8, converted_cache)

    assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)