import random

import torch

from cacheflow import cache_ops


def test_copy_blocks(
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    """Check ``cache_ops.copy_blocks`` against a pure-PyTorch reference.

    ``num_mappings`` random source blocks are each mapped to one distinct
    destination block, and sources and destinations are disjoint. The kernel
    runs on freshly generated per-layer key/value caches. The same copies are
    then replayed with plain tensor indexing on clones taken before the kernel
    call, and the two results must match.

    Args:
        num_mappings: Number of src -> dst block copies. Destinations are
            drawn from the blocks not chosen as sources, so this must satisfy
            ``2 * num_mappings <= num_blocks`` (otherwise ``random.sample``
            raises ``ValueError``).
        num_layers: Number of key caches and value caches (one of each per
            layer).
        num_heads: Number of heads in each cache.
        head_size: Per-head dimension. The key cache splits it into
            ``head_size // x`` chunks of ``x`` elements, so it presumably
            must be divisible by ``x`` (TODO confirm kernel requirement).
        block_size: Number of token slots per block.
        num_blocks: Total number of blocks in each cache.
        dtype: Element dtype of the caches.
    """
    # Generate random block mappings with disjoint source/destination sets.
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, num_mappings)
    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}

    # Create the KV caches. The key cache stores the head dimension in
    # chunks of `x` elements, where `x` elements of `dtype` span 16 bytes.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = [
        torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
        for _ in range(num_layers)
    ]
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = [
        torch.randn(size=value_cache_shape, dtype=dtype, device='cuda')
        for _ in range(num_layers)
    ]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel. The assertions below expect it to modify
    # `key_caches` / `value_caches` in place.
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Reference implementation: apply the same copies to the clones. Only the
    # clones are touched here, so the originals need not be iterated.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst] = cloned_key_cache[src]
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst] = cloned_value_cache[src]

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


def test_reshape_and_cache(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    """Check ``cache_ops.reshape_and_cache`` against a pure-PyTorch reference.

    Each of ``num_tokens`` tokens is assigned a distinct random cache slot
    (``slot = block_idx * block_size + block_offset``). The kernel writes each
    token's key and value into the caches. The same writes are then replayed
    with plain indexing on clones taken before the kernel call, and the two
    results must match.

    Args:
        num_tokens: Number of tokens to write. Must not exceed
            ``block_size * num_blocks``, the number of available slots,
            or ``random.sample`` raises ``ValueError``.
        num_heads: Number of heads.
        head_size: Per-head dimension. The key cache splits it into
            ``head_size // x`` chunks of ``x`` elements, so it presumably
            must be divisible by ``x`` (TODO confirm kernel requirement).
        block_size: Number of token slots per block.
        num_blocks: Total number of blocks in each cache.
        dtype: Element dtype of the keys, values and caches.
    """
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

    # Keys and values are slices of one packed (num_tokens, 3, heads, dim)
    # tensor; the first slice is unused. Unbinding dim 1 yields
    # non-contiguous views.
    qkv = torch.randn(
        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)

    # The key cache stores the head dimension in chunks of `x` elements,
    # where `x` elements of `dtype` span 16 bytes.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
    cloned_key_cache = key_cache.clone()

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=value_cache_shape, dtype=dtype, device='cuda')
    cloned_value_cache = value_cache.clone()

    # The assertions below expect the kernel to modify the caches in place.
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)

    # Reference implementation. The key reshape does not depend on the token,
    # so do it once. Fetch the slots to the host once instead of indexing the
    # device tensor per token.
    reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
    for slot, token_key, token_value in zip(
            slot_mapping.tolist(), reshaped_key, value):
        # Slots are non-negative, so divmod matches floor-div / modulo.
        block_idx, block_offset = divmod(slot, block_size)
        cloned_key_cache[block_idx, :, :, block_offset, :] = token_key
        cloned_value_cache[block_idx, :, :, block_offset] = token_value

    assert torch.allclose(key_cache, cloned_key_cache)
    assert torch.allclose(value_cache, cloned_value_cache)


@torch.inference_mode()
def test_cache() -> None:
    """Run both cache-kernel checks with fixed, small configurations.

    The whole run executes under ``torch.inference_mode()``.
    """
    copy_blocks_config = dict(
        num_mappings=23,
        num_layers=7,
        num_heads=17,
        head_size=16,
        block_size=8,
        num_blocks=1024,
        dtype=torch.half,
    )
    test_copy_blocks(**copy_blocks_config)

    reshape_and_cache_config = dict(
        num_tokens=3,
        num_heads=2,
        head_size=16,
        block_size=8,
        num_blocks=2,
        dtype=torch.half,
    )
    test_reshape_and_cache(**reshape_and_cache_config)


if __name__ == '__main__':
    # Allow running the kernel checks directly as a script.
    test_cache()