# test_prefix_caching.py
"""Compare the outputs with and without prefix caching.

Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
5
6
from typing import List

7
8
import pytest

9
from tests.kernels.utils import override_backend_env_variable
10
from vllm.block import PhysicalTokenBlock
11
from vllm.core.block_manager_v1 import CachedBlockAllocator
12
13
from vllm.utils import Device

14
15
16
17
18
19
from ..models.utils import check_outputs_equal

MODELS = [
    "facebook/opt-125m",
]

20
21
22
23
24
25

@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
def test_block_allocator(
    block_size: int,
    num_blocks: int,
):
    """Check block sharing, ref counting and hit-rate metrics by hash.

    Repeatedly allocates from a ``CachedBlockAllocator`` using one block
    hash and verifies that:

    * allocations with the same hash compare equal and bump ``ref_count``;
    * ``free`` decrements ``ref_count``;
    * the same block is handed back after every reference was freed;
    * ``get_prefix_cache_hit_rate`` reports 1/2 and then 3/4 as the
      repeated allocations accumulate.

    Args:
        block_size: Block size passed to the allocator.
        num_blocks: Number of blocks passed to the allocator.
    """
    block_hash = 1
    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

    # Allocate two PhysicalTokenBlocks with the same hash and check
    # that they are the same PhysicalTokenBlock
    first_block = block_allocator.allocate(block_hash, 0)
    second_block = block_allocator.allocate(block_hash, 0)
    assert first_block == second_block
    assert second_block.ref_count == 2

    # Check metric: 1 hit of 2 queries.
    # 0.5 is exactly representable as a float, so `==` is safe here.
    assert block_allocator.get_prefix_cache_hit_rate() == 0.5

    # Free the first_block and confirm that the ref_count is correctly
    # decremented on the second block
    block_allocator.free(first_block)
    assert second_block.ref_count == 1

    # Free the second block
    block_allocator.free(second_block)

    # Reallocate the first block and confirm that, even after the block
    # had its ref_count go to 0, we still get the same block back
    first_block = block_allocator.allocate(block_hash, 0)
    assert first_block == second_block
    assert first_block.block_hash == block_hash

    # Allocate one more time to get 3/4 hit rate for easy checking
    # (0.75 is also exactly representable as a float).
    block_allocator.allocate(block_hash, 0)
    assert block_allocator.get_prefix_cache_hit_rate() == 0.75

58
59
60
61

@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):
    """Check which block a full-then-freed allocator hands out next.

    Allocates ``num_blocks`` blocks with hashes ``0 .. num_blocks - 1``,
    frees them all in allocation order, and then verifies that:

    * a request for an unseen hash returns the first block that was freed
      and that block now carries the new hash;
    * a request for a previously used hash returns the block originally
      allocated for that hash;
    * a further request for an unseen hash does not return that
      re-acquired block.

    Args:
        num_blocks: Number of blocks passed to the allocator.
    """
    block_size = 16
    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

    # Use the index i as the block_hash, so hashes 0 .. num_blocks - 1 are
    # taken once this completes.
    blocks: List[PhysicalTokenBlock] = [
        block_allocator.allocate(i, 0) for i in range(num_blocks)
    ]

    # Free all blocks
    for block in blocks:
        block_allocator.free(block)

    # Allocate a new block and confirm that it's the first block freed.
    # I.e. the Least Recently Used block.
    # The hash must not collide with any hash used above; num_blocks is the
    # smallest value outside range(num_blocks). (Deriving it from block_size
    # only worked because both happen to be 16.)
    new_block_hash = num_blocks
    new_block = block_allocator.allocate(new_block_hash, 0)
    assert new_block == blocks[0]
    assert new_block.block_hash == new_block_hash

    # Reallocate the second in blocks to remove it from the free list
    realloc_block_hash = 1
    realloc_block = block_allocator.allocate(realloc_block_hash, 0)
    assert realloc_block == blocks[realloc_block_hash]
    assert realloc_block.block_hash == realloc_block_hash

    # Allocate a new block and confirm that it's not the realloc_block,
    # since the realloc_block shouldn't be in the free list
    new_block_hash = num_blocks + 1
    new_block = block_allocator.allocate(new_block_hash, 0)
    assert realloc_block != new_block
    assert new_block.block_hash == new_block_hash
    # NOTE(review): presumably block numbers follow allocation order, so
    # number 2 is the third block allocated above - confirm against the
    # allocator implementation.
    assert new_block.block_number == 2
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
def test_mixed_requests(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    backend: str,
    dtype: str,
    max_tokens: int,
    cached_position: int,
    use_v2_block_manager: bool,
    monkeypatch,
) -> None:
    """
    Test the case when some sequences have the prefix cache hit
    and the others don't. The cached position determines where
    the sequence is at among the batch of prefills.

    The prompt at ``cached_position`` is generated once on its own before
    the full batch is generated with the same runner; the greedy outputs of
    the full batch are then compared with the outputs of ``hf_runner`` for
    the same prompts.
    """
    override_backend_env_variable(monkeypatch, backend)

    # Reference outputs for every prompt.
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    cached_prompt = example_prompts[cached_position]
    with vllm_runner(
            model,
            dtype=dtype,
            enable_prefix_caching=True,
            use_v2_block_manager=use_v2_block_manager,
    ) as vllm_model:
        # Run the first prompt so the cache is populated. Only the side
        # effect matters here, so the returned outputs are discarded.
        vllm_model.generate_greedy([cached_prompt], max_tokens)

        # Run all the prompts
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )