test_prefix_caching.py 4.89 KB
Newer Older
1
2
3
4
"""Compare the with and without prefix caching.

Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
5
6
from typing import List

7
8
import pytest

9
from tests.kernels.utils import override_backend_env_variable
10
from tests.utils import check_deprecated_block_manager_usage
11
from vllm.block import PhysicalTokenBlock
12
from vllm.core.block_manager_v1 import CachedBlockAllocator
13
14
from vllm.utils import Device

15
16
17
18
19
20
from ..models.utils import check_outputs_equal

# Model identifiers fed to the "model" parametrization of test_mixed_requests.
MODELS = ["facebook/opt-125m"]

21

22
23
24
25
26
27
@pytest.fixture(scope="module", autouse=True)
def check_deprecated_block_manager():
    """Run the deprecated-block-manager usage check for this test file.

    The fixture is module-scoped and autouse, so pytest invokes it once for
    this module without any test having to request it explicitly.
    """
    this_test_file = 'tests/prefix_caching/test_prefix_caching.py'
    check_deprecated_block_manager_usage(this_test_file)


28
29
30
31
32
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
def test_block_allocator(
    block_size: int,
    num_blocks: int,
):
    """Check block sharing by hash, ref counting and the hit-rate metric."""
    shared_hash = 1
    allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

    # Two allocations under the same hash must hand back the same
    # PhysicalTokenBlock, which is then referenced twice.
    block_a = allocator.allocate(shared_hash, 0)
    block_b = allocator.allocate(shared_hash, 0)
    assert block_a == block_b
    assert block_b.ref_count == 2

    # Metric so far: 1 hit out of 2 queries.
    assert allocator.get_prefix_cache_hit_rate() == 0.5

    # Freeing through the first handle must be visible as a decremented
    # ref_count through the second handle.
    allocator.free(block_a)
    assert block_b.ref_count == 1

    # Drop the last reference as well.
    allocator.free(block_b)

    # Even after the ref_count reached 0, asking for the same hash again
    # must return the very same block with its hash intact.
    block_a = allocator.allocate(shared_hash, 0)
    assert block_a == block_b
    assert block_a.block_hash == shared_hash

    # One more allocation brings the tally to 3 hits out of 4 queries,
    # which keeps the expected rate easy to check.
    allocator.allocate(shared_hash, 0)
    assert allocator.get_prefix_cache_hit_rate() == 0.75

65
66
67
68

@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int):
    """Check which freed block is reused when a new hash is allocated.

    Every block is allocated (hash == loop index) and then freed in the same
    order. The test then expects that:

    * a previously unseen hash reuses the first block that was freed,
    * re-requesting a hash that is still cached returns its original block
      and takes that block out of the pool of reusable blocks,
    * the next unseen hash therefore reuses the block after it.

    Args:
        num_blocks: Number of blocks managed by the allocator under test.
    """
    block_size = 16
    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

    # Fill the allocator, using the index i as the block_hash.
    blocks: List[PhysicalTokenBlock] = [
        block_allocator.allocate(i, 0) for i in range(num_blocks)
    ]

    # Free all blocks, in allocation order.
    for block in blocks:
        block_allocator.free(block)

    # Hashes 0 .. num_blocks - 1 are in use, so num_blocks is the first
    # hash guaranteed to be new. Deriving it from num_blocks (rather than
    # block_size) keeps the test valid if num_blocks is parametrized with a
    # value larger than block_size.
    new_block_hash = num_blocks

    # Allocate a new block and confirm that it's the first block freed,
    # i.e. the least recently used block.
    new_block = block_allocator.allocate(new_block_hash, 0)
    assert new_block == blocks[0]
    assert new_block.block_hash == new_block_hash

    # Reallocate the second entry in blocks to remove it from the free list.
    realloc_block_hash = 1
    realloc_block = block_allocator.allocate(realloc_block_hash, 0)
    assert realloc_block == blocks[realloc_block_hash]
    assert realloc_block.block_hash == realloc_block_hash

    # Allocate another new block and confirm that it's not the
    # realloc_block, since the realloc_block shouldn't be in the free list.
    new_block_hash = num_blocks + 1
    new_block = block_allocator.allocate(new_block_hash, 0)
    assert realloc_block != new_block
    assert new_block.block_hash == new_block_hash
    assert new_block.block_number == 2
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
def test_mixed_requests(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    backend: str,
    dtype: str,
    max_tokens: int,
    cached_position: int,
    use_v2_block_manager: bool,
    monkeypatch,
) -> None:
    """
    Test the case when some sequences have the prefix cache hit
    and the others don't. The cached position determines where
    the sequence is at among the batch of prefills.

    The prompt at ``cached_position`` is generated once on its own before
    the full batch is run, and the full-batch vLLM outputs are then compared
    against the HF greedy outputs for the same prompts.
    """
    override_backend_env_variable(monkeypatch, backend)

    # Reference outputs from the HF runner.
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    cached_prompt = example_prompts[cached_position]
    with vllm_runner(
            model,
            dtype=dtype,
            enable_prefix_caching=True,
            use_v2_block_manager=use_v2_block_manager,
    ) as vllm_model:
        # Warm-up: run only the selected prompt so the cache is populated
        # for it. The outputs of this pass are not needed, so they are
        # deliberately not stored.
        vllm_model.generate_greedy([cached_prompt], max_tokens)

        # Run all the prompts; only the selected one was seen before.
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )