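"""Tests for the vLLM KV-cache kernels exposed by `vllm._C.cache_ops`:
`copy_blocks`, `reshape_and_cache`, and `swap_blocks`."""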
import random
from typing import Tuple

import pytest
import torch

from vllm._C import cache_ops

COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024, 3600]  # Arbitrary values for testing
NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# Exercise up to two GPUs when more than one is available.
DEVICES = list(range(1 if torch.cuda.device_count() == 1 else 2))
# "auto" stores the cache in the model dtype; "fp8_e5m2" uses 8-bit floats.
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
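# The tests below rely on a `kv_cache_factory` fixture supplied by the test
# suite's conftest. For orientation, a minimal sketch of what such a factory
# could look like is given here; the fixture name and argument order are taken
# from the calls below, but the body is an illustrative assumption (it ignores
# the fp8_e5m2 case). The cache shapes match how the tests index the caches.
#
#     @pytest.fixture()
#     def kv_cache_factory():
#         def factory(num_blocks, block_size, num_layers, num_heads,
#                     head_size, cache_dtype, model_dtype, seed, device):
#             torch.random.manual_seed(seed)
#             dtype = (cache_dtype
#                      if isinstance(cache_dtype, torch.dtype) else model_dtype)
#             x = 16 // dtype.itemsize  # packing factor for vectorized access
#             key_caches = [
#                 torch.randn(num_blocks, num_heads, head_size // x,
#                             block_size, x, dtype=dtype, device=device)
#                 for _ in range(num_layers)
#             ]
#             value_caches = [
#                 torch.randn(num_blocks, num_heads, head_size, block_size,
#                             dtype=dtype, device=device)
#                 for _ in range(num_layers)
#             ]
#             return key_caches, value_caches
#
#         return factory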


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
    kv_cache_dtype: str,
) -> None:
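    """copy_blocks must duplicate each source block into both of its mapped
    destination blocks, in every layer's key and value cache."""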
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"
    # Generate random block mappings where each source block is mapped to two
    # destination blocks.
    # Each mapping consumes one src block and two distinct dst blocks.
    assert 3 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping = {}
    for i in range(num_mappings):
        src = src_blocks[i]
        dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
        block_mapping[src] = [dst1, dst2]

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, gpu_id)

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Run the reference implementation.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst].copy_(cloned_key_cache[src])
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
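    """reshape_and_cache must scatter each token's key and value into the
    cache position given by its slot mapping entry."""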
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"
    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=gpu_id)
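    # Only the key and value projections are cached; the query is discarded.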
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size, dtype,
                                                None, seed, gpu_id)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches.
    cloned_key_cache = key_cache.clone()
    cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                slot_mapping, "auto")

    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    # Each flat slot index decomposes into (block index, offset within block).
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices[i]
        block_offset = block_offsets[i]
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    assert torch.allclose(key_cache, cloned_key_cache)
    assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
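    """swap_blocks must copy each mapped block from the source cache to the
    destination cache, including across CPU/GPU boundaries."""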
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    src_device = f"{direction[0]}:{device}" if direction[
        0] == "cuda" else direction[0]
    dst_device = f"{direction[1]}:{device}" if direction[
        1] == "cuda" else direction[1]

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For same-device swaps, src and dst blocks must not overlap.
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = dict(zip(src_blocks, dst_blocks))

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
        src_device)

    # Create the KV caches on the second device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
        dst_device)

    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    cache_ops.swap_blocks(src_key_caches[0], dst_key_caches[0], block_mapping)
    cache_ops.swap_blocks(src_value_caches[0], dst_value_caches[0],
                          block_mapping)

    for src, dst in block_mapping.items():
        assert torch.allclose(src_key_caches_clone[src].cpu(),
                              dst_key_caches[0][dst].cpu())
        assert torch.allclose(src_value_caches_clone[src].cpu(),
                              dst_value_caches[0][dst].cpu())
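

# Typical invocation (file path assumed; adjust to this file's location):
#     pytest tests/kernels/test_cache.py -v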