import random
from typing import Tuple

import pytest
import torch

from vllm._C import cache_ops

# (source, destination) device kinds exercised by test_swap_blocks; "cuda"
# is resolved to the parametrized CUDA device inside the test.
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing. Don't make them too large: e.g. [1024, 36000]
# will OOM. Each value must be at least 3 * max(NUM_MAPPINGS), because
# test_copy_blocks draws one source block and two distinct destination blocks
# per mapping, all without replacement.
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# One device when exactly one GPU is visible; otherwise cuda:0 and cuda:1.
# NOTE(review): this also yields two entries when no GPU is visible.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
# ROCm (HIP) builds only exercise the "auto" cache dtype. torch.version.hip is
# a version string on ROCm builds of PyTorch and None otherwise.
KV_CACHE_DTYPE = (["auto", "fp8_e5m2"]
                  if torch.version.hip is None else ["auto"])
36


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ``cache_ops.copy_blocks`` against a pure-PyTorch reference.

    Each randomly chosen source block is copied into two distinct destination
    blocks of every key cache and value cache returned by ``kv_cache_factory``
    (presumably one cache tensor per layer; see the fixture). The kernel's
    result is compared with the same copies performed via ``Tensor.copy_`` on
    clones of the original caches.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks. Sources and destinations are drawn without
    # replacement from disjoint pools, so num_mappings source blocks plus
    # 2 * num_mappings destination blocks must fit in num_blocks.
    assert 3 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping = {
        src: [dst_blocks[2 * i], dst_blocks[2 * i + 1]]
        for i, src in enumerate(src_blocks)
    }

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches so the reference starts from the pre-kernel state.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Run the reference implementation. Sources and destinations are
    # disjoint, so the order of these copies does not affect the result.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst].copy_(cloned_key_cache[src])
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check ``cache_ops.reshape_and_cache`` against a pure-PyTorch reference.

    Random key/value vectors are written into a single-layer cache at
    distinct random slots. Slot ``s`` lives in block ``s // block_size`` at
    offset ``s % block_size``. The reference writes each token's key into
    ``key_cache[block, :, :, offset, :]`` (after reshaping to that slice's
    shape) and its value into ``value_cache[block, :, :, offset]``.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Create a random slot mapping: each token gets a distinct cache slot.
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches (a single layer).
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size, dtype,
                                                None, seed, device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches so the reference starts from the pre-kernel state.
    cloned_key_cache = key_cache.clone()
    cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                slot_mapping, "auto")

    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size,
                              rounding_mode="floor").cpu().tolist()
    block_offsets = (slot_mapping % block_size).cpu().tolist()
    for i, (block_idx, block_offset) in enumerate(zip(block_indices,
                                                      block_offsets)):
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    assert torch.allclose(key_cache, cloned_key_cache)
    assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check that ``cache_ops.swap_blocks`` moves block contents as mapped.

    ``direction`` is a ``(src, dst)`` pair of ``"cuda"``/``"cpu"``; ``"cuda"``
    resolves to the parametrized ``device``. After the call, each mapped
    destination block must equal the pre-call contents of its source block.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    src_device = device if direction[0] == "cuda" else 'cpu'
    dst_device = device if direction[1] == "cuda" else 'cpu'

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap, so sources and
    # destinations come from disjoint pools of num_mappings blocks each.
    if src_device == dst_device:
        assert 2 * num_mappings <= num_blocks
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = dict(zip(src_blocks, dst_blocks))

    # Create the KV caches on the source device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
        src_device)

    # Create the KV caches on the destination device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
        dst_device)

    # Snapshot the sources so the comparison uses their pre-call contents.
    src_key_cache_clone = src_key_caches[0].clone()
    src_value_cache_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    cache_ops.swap_blocks(src_key_caches[0], dst_key_caches[0], block_mapping)
    cache_ops.swap_blocks(src_value_caches[0], dst_value_caches[0],
                          block_mapping)

    # Compare on CPU because source and destination may be on different
    # devices.
    for src, dst in block_mapping.items():
        assert torch.allclose(src_key_cache_clone[src].cpu(),
                              dst_key_caches[0][dst].cpu())
        assert torch.allclose(src_value_cache_clone[src].cpu(),
                              dst_value_caches[0][dst].cpu())