test_cache.py 5.23 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
import random

3
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
import torch

Woosuk Kwon's avatar
Woosuk Kwon committed
6
from vllm import cache_ops
Woosuk Kwon's avatar
Woosuk Kwon committed
7

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Parameter grids shared by the cache-kernel tests in this module.
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
SEEDS = [0]

# The sizes below are arbitrary values chosen for testing.
NUM_TOKENS = [7, 83, 2048]
NUM_LAYERS = [5]
NUM_HEADS = [8]
NUM_BLOCKS = [1024]
NUM_MAPPINGS = [32, 256]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check ``cache_ops.copy_blocks`` against a Python reference copy.

    ``num_mappings`` randomly chosen source blocks are each mapped to two
    destination blocks. The destinations are drawn only from non-source
    blocks and are pairwise distinct, so the order of the reference copies
    does not affect the result. The kernel output for every layer's key and
    value cache must match the reference.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks. The 2 * num_mappings destinations are sampled from
    # the num_blocks - num_mappings non-source blocks, so 3 * num_mappings
    # blocks are needed in total; a weaker bound lets random.sample raise.
    assert 3 * num_mappings <= num_blocks

    src_blocks = random.sample(range(num_blocks), num_mappings)
    src_set = set(src_blocks)
    # Keep candidates in ascending order so the sample depends only on the
    # seed, not on set iteration order.
    remaining_blocks = [blk for blk in range(num_blocks) if blk not in src_set]
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping = {
        src: dst_blocks[2 * i:2 * i + 2]
        for i, src in enumerate(src_blocks)
    }

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, dtype, seed)

    # Clone the KV caches so the reference can be computed independently.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Run the reference implementation: copy each source block into each of
    # its destinations in every layer.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst] = cloned_key_cache[src]
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst] = cloned_value_cache[src]

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


84
85
86
87
88
89
90
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check ``cache_ops.reshape_and_cache`` against a per-token reference.

    Every token is assigned a distinct, randomly chosen cache slot. The
    kernel writes each token's key and value into that slot of a single-layer
    cache, and the result must match the reference writes done in Python.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Draw num_tokens distinct slots from the whole cache.
    total_slots = num_blocks * block_size
    slots = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slots, dtype=torch.int, device="cuda")

    # key and value are taken as views from one packed qkv tensor via unbind.
    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device="cuda")
    _, key, value = qkv.unbind(dim=1)

    # Create a single-layer KV cache and keep pristine copies for the
    # reference.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size, dtype,
                                                seed)
    key_cache = key_caches[0]
    value_cache = value_caches[0]
    expected_key_cache = key_cache.clone()
    expected_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                slot_mapping)

    # Reference: slot -> (block, offset within block). Each token's key is
    # reshaped to match one [block, :, :, offset, :] slice of the key cache.
    per_slot_key_shape = key_cache[0, :, :, 0, :].shape
    reshaped_key = key.reshape(num_tokens, *per_slot_key_shape)
    for token, slot in enumerate(slots):
        block_idx, block_offset = divmod(slot, block_size)
        expected_key_cache[block_idx, :, :, block_offset, :] = (
            reshaped_key[token])
        expected_value_cache[block_idx, :, :, block_offset] = value[token]

    assert torch.allclose(key_cache, expected_key_cache)
    assert torch.allclose(value_cache, expected_value_cache)