test_cache.py 5.7 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
import random

3
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
import torch

6
from vllm._C import cache_ops
Woosuk Kwon's avatar
Woosuk Kwon committed
7

8
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
# test_copy_blocks draws every source block plus two distinct destination
# blocks from one pool, so each NUM_BLOCKS entry must be at least
# 3 * max(NUM_MAPPINGS).
NUM_BLOCKS = [1024, 3600]  # Arbitrary values for testing
NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# Device ordinals to test: only device 0 when exactly one GPU is visible,
# otherwise devices 0 and 1 (this includes the case of no visible GPU).
DEVICES = list(range(1 if torch.cuda.device_count() == 1 else 2))
# Cache storage dtypes handed to kv_cache_factory by test_copy_blocks.
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
    kv_cache_dtype: str,
) -> None:
    """Check ``cache_ops.copy_blocks`` against a pure-PyTorch reference.

    ``num_mappings`` randomly chosen source blocks are each mapped to two
    distinct destination blocks. The kernel is applied in place to every
    layer's key and value cache, and the result must match the same block
    copies performed on clones of the caches.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"

    # Sources and destinations are all drawn without replacement from the
    # same pool, so it must hold num_mappings sources plus
    # 2 * num_mappings destinations.
    assert 3 * num_mappings <= num_blocks, (
        f"need at least {3 * num_mappings} blocks, got {num_blocks}")
    src_blocks = random.sample(range(num_blocks), num_mappings)
    src_set = set(src_blocks)
    # Ascending order keeps the subsequent seeded sample reproducible.
    remaining_blocks = [b for b in range(num_blocks) if b not in src_set]
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    # Each source block maps to a pair of destination blocks.
    block_mapping = {
        src: dst_blocks[2 * i:2 * i + 2]
        for i, src in enumerate(src_blocks)
    }

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, gpu_id)

    # Clone the KV caches so the reference result is built independently of
    # the kernel, which is checked through key_caches/value_caches below.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel; its return value is not used, and the
    # caches passed in are compared afterwards.
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Run the reference implementation. Destinations are disjoint from
    # sources, so the copy order does not affect the result.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst].copy_(cloned_key_cache[src])
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


91
92
93
94
95
96
97
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
    """Check ``cache_ops.reshape_and_cache`` against a pure-PyTorch reference.

    Random keys and values for ``num_tokens`` tokens are written by the
    kernel into a single-layer KV cache at random, distinct slots. The
    result must match a per-token reference scatter into clones of the
    caches.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"

    # Create a random slot mapping. Slots are sampled without replacement,
    # so no two tokens target the same cache position.
    num_slots = block_size * num_blocks
    slot_indices = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_indices, dtype=torch.long, device=gpu_id)

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=gpu_id)
    _, key, value = qkv.unbind(dim=1)

    # Create a single-layer KV cache. NOTE(review): unlike test_copy_blocks,
    # `dtype` sits in the cache-dtype position and None in the model-dtype
    # position; presumably kv_cache_factory accepts a torch dtype there --
    # confirm against the fixture.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size, dtype,
                                                None, seed, gpu_id)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches.
    cloned_key_cache = key_cache.clone()
    cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache kernel with the "auto" KV-cache dtype.
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                slot_mapping, "auto")

    # Run the reference implementation. Each token's key is reshaped to the
    # per-slot slice of the key cache (dims other than block and offset).
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size,
                              rounding_mode="floor").cpu().tolist()
    block_offsets = (slot_mapping % block_size).cpu().tolist()
    token_slots = zip(block_indices, block_offsets)
    for i, (block_idx, block_offset) in enumerate(token_slots):
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    assert torch.allclose(key_cache, cloned_key_cache)
    assert torch.allclose(value_cache, cloned_value_cache)