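"""Tests for the vLLM cache kernels in vllm._C.cache_ops.

copy_blocks and reshape_and_cache are each checked against a pure-PyTorch
reference implementation over a grid of shapes, dtypes, seeds, and devices.
"""
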
import random

import pytest
import torch

from vllm._C import cache_ops

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [83]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
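# Test a second GPU when more than one device is available.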
DEVICES = list(range(1 if torch.cuda.device_count() == 1 else 2))


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
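    """Check the copy_blocks kernel against a pure-PyTorch reference.

    Each source block is copied to two destination blocks; the kernel's
    in-place result is compared with clones updated via tensor copies.
    """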
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"
    # Generate random block mappings where each source block is mapped to two
    # destination blocks.
    assert 2 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    copy_src = []
    copy_dst = []
    for i in range(num_mappings):
        copy_src.append(src_blocks[i])
        copy_dst.append(dst_blocks[2 * i])
        copy_src.append(src_blocks[i])
        copy_dst.append(dst_blocks[2 * i + 1])
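    # Illustrative example: src_blocks=[3] and dst_blocks=[7, 9] yield
    # copy_src=[3, 3] and copy_dst=[7, 9].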

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, dtype, seed, gpu_id)

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    cache_ops.copy_blocks(key_caches, value_caches, copy_src, copy_dst)
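    # The kernel updates key_caches and value_caches in place; the clones
    # above preserve the pre-copy state for the reference run below.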

    # Run the reference implementation.
    for src, dst in zip(copy_src, copy_dst):
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
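    """Check the reshape_and_cache kernel against a pure-PyTorch reference.

    Random per-token keys and values are scattered into the paged caches at
    the slots given by slot_mapping, then compared with a manual scatter.
    """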
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"
    # Create a random slot mapping.
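    # Each slot index encodes a (block, offset) pair:
    # slot = block_idx * block_size + block_offset.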
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=gpu_id)
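    # Only the key and value projections are written to the cache; the query
    # is unused here.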
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size, dtype,
                                                seed, gpu_id)
    key_cache, value_cache = key_caches[0], value_caches[0]
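    # As implied by the indexing below, key_cache has shape
    # [num_blocks, num_heads, head_size // x, block_size, x] and value_cache
    # has shape [num_blocks, num_heads, head_size, block_size].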

    # Clone the KV caches.
    cloned_key_cache = key_cache.clone()
    cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel.
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                slot_mapping)

    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices[i]
        block_offset = block_offsets[i]
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

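    # Compare the results.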
    assert torch.allclose(key_cache, cloned_key_cache)
    assert torch.allclose(value_cache, cloned_value_cache)