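"""Tests for the vLLM cache kernels in vllm.cache_ops (copy_blocks,
reshape_and_cache, and gather_cached_kv) against pure-PyTorch references."""
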
import random

import torch

from vllm import cache_ops


@torch.inference_mode()
def run_copy_blocks(
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    # Generate random block mappings.
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, num_mappings)
    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}

    # Create the KV cache.
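    # x is the number of elements that fit in 16 bytes; the key cache splits
    # the head dimension into head_size // x groups of x elements.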
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = [
        torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
        for _ in range(num_layers)
    ]
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]

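    # The value cache keeps the head dimension whole, with block_size
    # as the innermost dimension.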
    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = [
        torch.randn(size=value_cache_shape, dtype=dtype, device='cuda')
        for _ in range(num_layers)
    ]
    cloned_value_caches = [
        value_cache.clone() for value_cache in value_caches
    ]

    # Call the copy blocks kernel.
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Reference implementation.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst] = cloned_key_cache[src]
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst] = cloned_value_cache[src]

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


@torch.inference_mode()
def run_reshape_and_cache(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
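    # Pick a distinct random cache slot for each token.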
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

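    # Create random query, key, and value tensors; only key and value are
    # written to the cache.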
    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device='cuda')
    _, key, value = qkv.unbind(dim=1)

    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
    cloned_key_cache = key_cache.clone()

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(size=value_cache_shape,
                              dtype=dtype,
                              device='cuda')
    cloned_value_cache = value_cache.clone()

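    # Call the reshape_and_cache kernel.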
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                slot_mapping)

    # Reference implementation.
    reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
    for i in range(num_tokens):
        block_idx = torch.div(slot_mapping[i],
                              block_size,
                              rounding_mode='floor')
        block_offset = slot_mapping[i] % block_size
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    assert torch.allclose(key_cache, cloned_key_cache)
    assert torch.allclose(value_cache, cloned_value_cache)


@torch.inference_mode()
def run_gather_cached_kv(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
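    # Pick a distinct random cache slot for each token.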
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

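    # Create random query, key, and value tensors; the kernel overwrites key
    # and value with the gathered cache entries.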
    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device='cuda')
    _, key, value = qkv.unbind(dim=1)

    qkv_clone = qkv.clone()
    _, cloned_key, cloned_value = qkv_clone.unbind(dim=1)

    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(size=value_cache_shape,
                              dtype=dtype,
                              device='cuda')

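    # Call the gather_cached_kv kernel.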
    cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
                               slot_mapping)

    # Reference implementation.
    reshaped_key = cloned_key.reshape(num_tokens, num_heads,
                                      head_size // x, x)
    for i in range(num_tokens):
        block_idx = torch.div(slot_mapping[i],
                              block_size,
                              rounding_mode='floor')
        block_offset = slot_mapping[i] % block_size
        reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
        cloned_value[i] = value_cache[block_idx, :, :, block_offset]

    assert torch.allclose(key, cloned_key)
    assert torch.allclose(value, cloned_value)


def test_copy_blocks() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        run_copy_blocks(num_mappings=23,
                        num_layers=7,
                        num_heads=17,
                        head_size=16,
                        block_size=8,
                        num_blocks=1024,
                        dtype=dtype)


def test_reshape_and_cache() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        run_reshape_and_cache(num_tokens=3,
                              num_heads=2,
                              head_size=16,
                              block_size=8,
                              num_blocks=2,
                              dtype=dtype)


def test_gather_cached_kv() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        run_gather_cached_kv(num_tokens=3,
                             num_heads=2,
                             head_size=16,
                             block_size=8,
                             num_blocks=2,
                             dtype=dtype)