test_prefix_caching.py 7.67 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
"""Compare the with and without prefix caching.

Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
6

7
8
from __future__ import annotations

9
10
import pytest

11
12
from tests.conftest import VllmRunner
from tests.core.utils import SchedulerProxy, create_dummy_prompt
13
from vllm import SamplingParams, TokensPrompt
14
15
from vllm.core.scheduler import Scheduler
from vllm.engine.llm_engine import LLMEngine
16
from vllm.platforms import current_platform
17
from vllm.utils import STR_BACKEND_ENV_VAR
18

19
20
from ..models.utils import check_outputs_equal

21
22

@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
    """
    This module relies on V0 internals, so set VLLM_USE_V1=0.

    Being ``autouse``, this applies to every test in the module. The variable
    is set inside ``monkeypatch.context()`` so the change is undone when the
    context exits after each test.
    """
    with monkeypatch.context() as m:
        m.setenv('VLLM_USE_V1', '0')
        yield


# Models exercised by the parametrized tests in this module.
MODELS = [
    "distilbert/distilgpt2",
]

# Token-ID prompts submitted one after another by
# test_unstable_prompt_sequence. All five share a leading run of 588 `0`
# tokens; the first four also share the following 1332 `1` tokens and then
# diverge (prompts 1 and 3 continue with thirty `2`s, prompts 2 and 4 with
# three `4`s) before ending in distinct tails. The fifth prompt diverges right
# after the 588-token prefix.
# NOTE(review): the name suggests this ordering once triggered a
# prefix-caching failure; the exact trigger is not documented here.
UNSTABLE_PROMPT_SEQUENCE = [
    ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([3] * 1),
    ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([5] * 50),
    ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([6] * 95),
    ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([7] * 174),
    ([0] * 588) + ([8] * 1539),
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
@pytest.mark.parametrize("block_size", [16])
def test_mixed_requests(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    backend: str,
    dtype: str,
    max_tokens: int,
    cached_position: int,
    enable_chunked_prefill: bool,
    block_size: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    Test the case when some sequences have the prefix cache hit
    and the others don't. The cached position determines where the
    cached sequence sits within the batch of prefills.

    The prompt at ``cached_position`` is generated once on its own to
    populate the prefix cache. The whole batch is then generated, and each
    request's ``num_cached_tokens`` is checked: full blocks of the cached
    prompt for that request, zero for every other. Finally the greedy
    outputs are compared against the HuggingFace reference.
    """
    if backend == "FLASHINFER" and current_platform.is_rocm():
        pytest.skip("Flashinfer does not support ROCm/HIP.")
    if backend == "XFORMERS" and current_platform.is_rocm():
        pytest.skip("Xformers does not support ROCm/HIP.")

    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, backend)

        with hf_runner(model, dtype=dtype) as hf_model:
            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

        cached_prompt = example_prompts[cached_position]
        with vllm_runner(
                model,
                dtype=dtype,
                enable_prefix_caching=True,
                enable_chunked_prefill=enable_chunked_prefill,
                block_size=block_size,
        ) as vllm_model:
            # Run the first prompt so the cache is populated. Its output is
            # not used; only the effect on the prefix cache matters.
            vllm_model.generate_greedy([cached_prompt], max_tokens)

            # Run all the prompts
            greedy_params = SamplingParams(temperature=0.0,
                                           max_tokens=max_tokens)
            req_outputs = vllm_model.model.generate(example_prompts,
                                                    greedy_params)

            # Verify number of cached tokens. For the pre-run prompt the
            # expectation is its length rounded down to whole blocks.
            for i, req_output in enumerate(req_outputs):
                if i == cached_position:
                    expected_num_cached_tokens = (
                        len(req_output.prompt_token_ids) //
                        block_size) * block_size
                else:
                    expected_num_cached_tokens = 0
                assert (req_output.num_cached_tokens ==
                        expected_num_cached_tokens)

            vllm_outputs = [(
                output.prompt_token_ids + list(output.outputs[0].token_ids),
                output.prompt + output.outputs[0].text,
            ) for output in req_outputs]

        check_outputs_equal(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
def test_unstable_prompt_sequence(
    vllm_runner,
    backend: str,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    Replay UNSTABLE_PROMPT_SEQUENCE one prompt at a time with chunked
    prefill and prefix caching enabled.

    There are no output assertions: the test passes as long as every
    ``generate`` call completes without raising.
    """
    if backend == "FLASHINFER" and current_platform.is_rocm():
        pytest.skip("Flashinfer does not support ROCm/HIP.")
    if backend == "XFORMERS" and current_platform.is_rocm():
        pytest.skip("Xformers does not support ROCm/HIP.")

    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, backend)

        with vllm_runner(
                "Qwen/Qwen2.5-0.5B-Instruct",
                enable_chunked_prefill=True,
                enable_prefix_caching=True,
                max_model_len=4096,
        ) as vllm_model:
            # Each prompt is its own request, submitted in order. The prompts
            # share prefixes, so later requests can reuse blocks cached by
            # earlier ones.
            for prompt in UNSTABLE_PROMPT_SEQUENCE:
                vllm_model.generate(TokensPrompt(prompt_token_ids=prompt),
                                    SamplingParams(max_tokens=1))


@pytest.mark.parametrize("model", MODELS)
def test_fully_cached_prefill_needs_uncached_token(model):
    """
    A prefill whose prompt is entirely cached must not be scheduled while no
    uncached token of it fits in the batch budget.

    seqA is run to completion to populate the prefix cache. seqB (uncached,
    two blocks) and seqC (identical to seqA, hence fully cached) are then
    queued together with a 16-token batch budget. Only seqB is scheduled
    until it finishes, after which seqC is prefilled in a single step.
    """
    block_size = 16
    max_num_batched_tokens = 16
    num_output_tokens = 5
    # Make a vllm engine. The context manager releases it when the test ends,
    # including when an assertion fails.
    with VllmRunner(
            model_name=model,
            gpu_memory_utilization=0.7,
            enable_chunked_prefill=True,
            enforce_eager=True,
            enable_prefix_caching=True,
            block_size=block_size,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_batched_tokens,
    ) as runner:
        engine: LLMEngine = runner.model.llm_engine

        # Swap in a proxy so the test can inspect each scheduling result.
        scheduler: Scheduler = SchedulerProxy(
            engine.scheduler[0])  # type: ignore
        engine.scheduler[0] = scheduler

        # SeqA
        seqA_tokens = list(range(2 * block_size))
        seqA, seq_groupA = create_dummy_prompt(
            request_id="0",
            prompt_tokens=seqA_tokens,
            max_tokens=num_output_tokens,
            block_size=block_size,
        )

        scheduler.add_seq_group(seq_groupA)

        assert seqA.data.get_num_computed_tokens() == 0

        # Prefill seqA
        while not seqA.is_finished():
            engine.step()

        # seqB
        seqB_tokens = [t + 1 for t in seqA_tokens]  # shift by 1
        seqB, seq_groupB = create_dummy_prompt(
            request_id="1",
            prompt_tokens=seqB_tokens,
            max_tokens=num_output_tokens,
            block_size=block_size,
        )

        # seqC is the same as seqA
        _, seq_groupC = create_dummy_prompt(
            request_id="2",
            prompt_tokens=seqA_tokens,
            max_tokens=num_output_tokens,
            block_size=block_size,
        )

        scheduler.add_seq_group(seq_groupB)
        scheduler.add_seq_group(seq_groupC)

        # Even though seqC is fully cached, it should not be prefilled since
        # we require at least 1 uncached token.
        engine.step()

        _, sched_out, _ = scheduler.last_schedule_ret()
        assert len(sched_out.scheduled_seq_groups) == 1
        assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
                seq_groupB.request_id)
        assert (sched_out.scheduled_seq_groups[0].token_chunk_size ==
                max_num_batched_tokens)

        # When seqB is finished, seqC could be prefilled.
        while not seqB.is_finished():
            engine.step()
            _, sched_out, _ = scheduler.last_schedule_ret()
            assert len(sched_out.scheduled_seq_groups) == 1
            assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
                    seq_groupB.request_id)

        engine.step()
        _, sched_out, _ = scheduler.last_schedule_ret()
        assert len(sched_out.scheduled_seq_groups) == 1
        assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
                seq_groupC.request_id)
        assert sched_out.scheduled_seq_groups[0].token_chunk_size == len(
            seqA_tokens)