import random

import torch

from cacheflow import cache_ops


def test_copy_blocks(
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    # Generate random block mappings: each source block is copied to one
    # destination block, with sources and destinations kept disjoint.
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, num_mappings)
    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}

    # Create the KV cache.
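    # The key cache splits head_size into chunks of x elements, where x is the
    # number of cache elements that fit in 16 bytes, so the kernel can use
    # 16-byte vectorized accesses.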
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = []
    for _ in range(num_layers):
        key_cache = torch.randn(
            size=key_cache_shape, dtype=dtype, device='cuda')
        key_caches.append(key_cache)
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = []
    for _ in range(num_layers):
        value_cache = torch.randn(
            size=value_cache_shape, dtype=dtype, device='cuda')
        value_caches.append(value_cache)
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Reference implementation.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
                cloned_key_cache[dst] = cloned_key_cache[src]
            for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
                cloned_value_cache[dst] = cloned_value_cache[src]

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


def test_reshape_and_cache(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
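    # Assign each token a distinct cache slot; slot // block_size gives the
    # block index and slot % block_size the offset within the block.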
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

    qkv = torch.randn(
        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)

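    # x: elements per 16-byte chunk (see test_copy_blocks).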
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
    cloned_key_cache = key_cache.clone()

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=value_cache_shape, dtype=dtype, device='cuda')
    cloned_value_cache = value_cache.clone()

    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)

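    # Reference implementation: scatter each token's key/value into its slot.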
    reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
    for i in range(num_tokens):
        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
        block_offset = slot_mapping[i] % block_size
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

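    # Compare the results.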
    assert torch.allclose(key_cache, cloned_key_cache)
    assert torch.allclose(value_cache, cloned_value_cache)


def test_gather_cached_kv(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

    qkv = torch.randn(
        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)

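    # Clone key/value to hold the reference outputs computed below.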
    qkv_clone = qkv.clone()
    _, cloned_key, cloned_value = qkv_clone.unbind(dim=1)

    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=value_cache_shape, dtype=dtype, device='cuda')

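    # Gather the cached keys/values for each slot back into key and value.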
    cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)

    # Reference implementation.
    reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x)
    for i in range(num_tokens):
        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
        block_offset = slot_mapping[i] % block_size
        reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
        cloned_value[i] = value_cache[block_idx, :, :, block_offset]

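    # Compare the results.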
    assert torch.allclose(key, cloned_key)
    assert torch.allclose(value, cloned_value)


@torch.inference_mode()
def test_cache() -> None:
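    # Exercise all three cache kernels for each tested dtype.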
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        test_copy_blocks(
            num_mappings=23, num_layers=7, num_heads=17, head_size=16,
            block_size=8, num_blocks=1024, dtype=dtype)
        test_reshape_and_cache(
            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
            dtype=dtype)
        test_gather_cached_kv(
            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
            dtype=dtype)


if __name__ == '__main__':
    test_cache()