"docs/vscode:/vscode.git/clone" did not exist on "bc61f10948c019c33356270557d08d9fdfa8b5a2"
test_cache.py 5.94 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
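"""Tests for the vLLM cache kernels in cache_ops: copy_blocks,
reshape_and_cache, and gather_cached_kv."""
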
import random

import torch

from vllm import cache_ops


@torch.inference_mode()
def run_copy_blocks(
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    # Generate random block mappings.
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, num_mappings)
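    # Each source block maps to a list of destination blocks; the kernel
    # supports one-to-many copies, hence the list values.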
    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}

    # Create the KV cache.
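    # x is the number of elements that fit in 16 bytes; the key cache splits
    # the head dimension into chunks of x so the kernel can use 16-byte
    # vectorized accesses.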
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = []
    for _ in range(num_layers):
        key_cache = torch.randn(
            size=key_cache_shape, dtype=dtype, device='cuda')
        key_caches.append(key_cache)
    cloned_key_caches = []
    for key_cache in key_caches:
        cloned_key_caches.append(key_cache.clone())

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = []
    for _ in range(num_layers):
        value_cache = torch.randn(
            size=value_cache_shape, dtype=dtype, device='cuda')
        value_caches.append(value_cache)
    cloned_value_caches = []
    for value_cache in value_caches:
        cloned_value_caches.append(value_cache.clone())

    # Call the copy blocks kernel.
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Reference implementation.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst] = cloned_key_cache[src]
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst] = cloned_value_cache[src]

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


@torch.inference_mode()
def run_reshape_and_cache(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
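    # A slot is one token's position in the paged KV cache:
    # slot = block_idx * block_size + block_offset.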
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

    qkv = torch.randn(
        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)

    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
    cloned_key_cache = key_cache.clone()

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=value_cache_shape, dtype=dtype, device='cuda')
    cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)

    # Reference implementation.
    reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
    for i in range(num_tokens):
        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
        block_offset = slot_mapping[i] % block_size
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    assert torch.allclose(key_cache, cloned_key_cache)
    assert torch.allclose(value_cache, cloned_value_cache)


@torch.inference_mode()
def run_gather_cached_kv(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
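    # gather_cached_kv is the inverse of reshape_and_cache: it reads each
    # token's key/value out of the paged cache into the flat key/value tensors.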
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

    qkv = torch.randn(
        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)

    qkv_clone = qkv.clone()
    _, cloned_key, cloned_value = qkv_clone.unbind(dim=1)

    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=value_cache_shape, dtype=dtype, device='cuda')

    # Call the gather_cached_kv kernel.
    cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)

    # Reference implementation.
    reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x)
    for i in range(num_tokens):
        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
        block_offset = slot_mapping[i] % block_size
        reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
        cloned_value[i] = value_cache[block_idx, :, :, block_offset]

    assert torch.allclose(key, cloned_key)
    assert torch.allclose(value, cloned_value)


def test_copy_blocks() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        run_copy_blocks(
            num_mappings=23, num_layers=7, num_heads=17, head_size=16,
            block_size=8, num_blocks=1024, dtype=dtype)


def test_reshape_and_cache() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        run_reshape_and_cache(
            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
            dtype=dtype)


def test_gather_cached_kv() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        run_gather_cached_kv(
            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
            dtype=dtype)
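

# Not part of the original suite: a minimal convenience entry point so the
# file can be run directly (python test_cache.py) instead of through pytest.
# It assumes a CUDA device is available, since the tests allocate on 'cuda'.
if __name__ == '__main__':
    test_copy_blocks()
    test_reshape_and_cache()
    test_gather_cached_kv()
    print('All cache kernel tests passed.')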