import random
from typing import Tuple

import pytest
import torch

from vllm._C import cache_ops
from vllm.utils import is_hip

COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]

# Reduce the size for the ROCm test to avoid HIP OOM.
NUM_BLOCKS = [1024, 36000] if not is_hip() else [
    1024, 10000
]  # Arbitrary values for testing
NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
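# NOTE: The reference implementations below assume the cache layout produced
# by the kv_cache_factory fixture: a 5-D key cache indexed as
# key_cache[block, head, head_dim // x, offset, x] and a 4-D value cache
# indexed as value_cache[block, head, head_dim, offset].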


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
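    # Seed all RNGs so the sampled block mappings are reproducible.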
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    # Generate random block mappings where each source block is mapped to two
    # destination blocks.
    assert 3 * num_mappings <= num_blocks  # sources and destinations disjoint
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping = {}
    for i in range(num_mappings):
        src = src_blocks[i]
        dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
        block_mapping[src] = [dst1, dst2]

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
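    # copy_blocks updates the caches in place: each source block is copied to
    # every block in its destination list, across all layers.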
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Run the reference implementation.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst].copy_(cloned_key_cache[src])
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    # Create a random slot mapping.
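    # There are block_size slots per block; sampling without replacement
    # guarantees each token writes to a distinct cache slot.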
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
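    # Only the key and value projections are cached; the query is unused.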
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size, dtype,
                                                None, seed, device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches.
    cloned_key_cache = key_cache.clone()
    cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
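    # "auto" stores the cache in the model dtype (no fp8 quantization here).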
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                slot_mapping, "auto")

    # Run the reference implementation.
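    # key_cache[0, :, :, 0, :] is the per-token key slice of one block, so
    # each key is reshaped to that layout. A flat slot id encodes its location
    # as slot = block_idx * block_size + block_offset.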
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices[i]
        block_offset = block_offsets[i]
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    assert torch.allclose(key_cache, cloned_key_cache)
    assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    src_device = device if direction[0] == "cuda" else direction[0]
    dst_device = device if direction[1] == "cuda" else direction[1]

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # On the same device, the source and destination blocks must not overlap.
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = dict(zip(src_blocks, dst_blocks))

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
        src_device)

    # Create the KV caches on the second device.
    dist_key_caches, dist_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
        dst_device)

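    # Snapshot the source caches so the pre-swap contents can be compared.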
    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
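    # swap_blocks copies block src of the first cache into block dst of the
    # second for every pair in block_mapping, possibly across devices.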
    cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
    cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
                          block_mapping)

    for src, dst in block_mapping.items():
        assert torch.allclose(src_key_caches_clone[src].cpu(),
                              dist_key_caches[0][dst].cpu())
        assert torch.allclose(src_value_caches_clone[src].cpu(),
                              dist_value_caches[0][dst].cpu())