test_llm_engine.py 7.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from __future__ import annotations
4

5
import random
6
from typing import TYPE_CHECKING
7

8
9
import pytest

10
from vllm import LLM
11
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
12
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
13

14
15
16
if TYPE_CHECKING:
    from tests.conftest import VllmRunner

17
18
# Model identifier passed as the first positional argument to `vllm_runner`
# by `_vllm_model` and `test_engine_metrics`.
MODEL = "facebook/opt-125m"
# Value forwarded as `dtype=` when `_vllm_model` constructs the runner.
DTYPE = "half"
19
20


21
22
23
24
25
26
27
def _vllm_model(
    apc: bool,
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
    *,
    skip_tokenizer_init: bool = False,
) -> VllmRunner:
    """Set up VllmRunner instance.

    Args:
      apc: forwarded as `enable_prefix_caching`.
      vllm_runner: runner class (the `vllm_runner` fixture) to instantiate.
      monkeypatch: used to set `VLLM_USE_V1=1` before the runner is built.
      skip_tokenizer_init: forwarded unchanged to the runner.

    Returns:
      The constructed runner; the fixtures in this file use it as a
      context manager.
    """
    # The env var must be set before the runner is constructed.
    monkeypatch.setenv("VLLM_USE_V1", "1")
    return vllm_runner(
        MODEL,
        dtype=DTYPE,
        max_model_len=128,
        enforce_eager=True,
        enable_prefix_caching=apc,
        gpu_memory_utilization=0.5,
        skip_tokenizer_init=skip_tokenizer_init,
    )


@pytest.fixture(
    # Function scope decouples tests & allows
    # env var adjustment via monkeypatch
    scope="function",
    # Prefix caching
    params=[False, True],
)
def vllm_model(vllm_runner, request, monkeypatch):
    """VllmRunner test fixture parameterized by APC True/False.

    `request.param` is passed to `_vllm_model` as its `apc` argument, so
    each dependent test runs once with prefix caching off and once with it on.
    """
    # Use a distinct local name so the fixture function is not shadowed.
    with _vllm_model(request.param, vllm_runner, monkeypatch) as runner:
        yield runner


@pytest.fixture(scope="function")
def vllm_model_apc(vllm_runner, monkeypatch):
    """Yield a VllmRunner built with prefix caching (APC) enabled."""
    runner_cm = _vllm_model(True, vllm_runner, monkeypatch)
    with runner_cm as runner:
        yield runner


61
62
63
64
65
@pytest.fixture(
    # Function scope decouples tests & allows
    # env var adjustment via monkeypatch
    scope="function",
    # Prefix caching
    params=[False, True],
)
def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
    """VllmRunner test fixture with `skip_tokenizer_init=True`.

    Parameterized by APC True/False: `request.param` is passed to
    `_vllm_model` as its `apc` argument.
    """
    # Use a distinct local name so the sibling `vllm_model` fixture is not
    # shadowed inside this function.
    with _vllm_model(
        request.param,
        vllm_runner,
        monkeypatch,
        skip_tokenizer_init=True,
    ) as runner:
        yield runner


79
def _get_test_sampling_params(
    prompt_list: list[str],
    seed: int | None = 42,
    structured_outputs: bool = False,
) -> tuple[list[SamplingParams], list[int]]:
    """Generate random sampling params for a batch.

    Args:
      prompt_list: prompts in the batch; only its length is used.
      seed: forwarded as `seed` to every `SamplingParams`. It does not seed
        the `random` module, so the drawn `n` values are not fixed by it.
      structured_outputs: if true, each `SamplingParams` carries a
        `StructuredOutputsParams` with regex ``[0-9]+``; otherwise `None`.

    Returns:
      A pair `(sampling_params_list, n_list)` with one entry per prompt,
      where `n_list[i]` is the `n` given to `sampling_params_list[i]`.
    """

    def get_mostly_n_gt1() -> int:
        r"""Mostly n \in [2,20], ~1/3 n=1"""
        x = random.randint(0, 28)
        # x in [0, 9] (10 of 29 outcomes) -> n=1; x in [10, 28] -> n in [2, 20].
        return 1 if x < 10 else x - 8

    n_list = [get_mostly_n_gt1() for _ in prompt_list]
    # High temperature to maximize the chance of unique completions
    sampling_params_list = [
        SamplingParams(
            temperature=0.95,
            top_p=0.95,
            n=n,
            seed=seed,
            # A fresh StructuredOutputsParams per request, as in the original.
            structured_outputs=StructuredOutputsParams(regex="[0-9]+")
            if structured_outputs
            else None,
        )
        for n in n_list
    ]
    return sampling_params_list, n_list


110
111
112
113
114
115
116
117
118
def test_compatibility_with_skip_tokenizer_init(
    vllm_model_skip_tokenizer_init: VllmRunner,
    example_prompts: list[str],
) -> None:
    """Test passes if generation raises `ValueError` for structured-output
    requests on an engine built with `skip_tokenizer_init=True`.

    Args:
      vllm_model_skip_tokenizer_init: VllmRunner instance under test.
      example_prompts: test fixture providing prompts for testing.
    """
    # Case 1: Structured output request should raise an error.
    sampling_params_list, _ = _get_test_sampling_params(
        example_prompts,
        structured_outputs=True,
    )
    llm: LLM = vllm_model_skip_tokenizer_init.llm
    with pytest.raises(ValueError):
        # The call is expected to raise, so its result is never bound.
        llm.generate(example_prompts, sampling_params_list)
122
123


124
125
def test_parallel_sampling(vllm_model, example_prompts) -> None:
    """Test passes if parallel sampling `n>1` yields `n` unique completions.

    Args:
      vllm_model: VllmRunner instance under test.
      example_prompts: test fixture providing prompts for testing.
    """
    sampling_params_list, n_list = _get_test_sampling_params(example_prompts)
    llm: LLM = vllm_model.llm
    outputs = llm.generate(example_prompts, sampling_params_list)

    # `zip` stops at the shorter input, so check that no request is missing
    # before validating the responses pairwise.
    assert len(outputs) == len(n_list), (
        f"{len(outputs)} request outputs; {len(n_list)} expected."
    )

    # Validate each request response
    for out, n in zip(outputs, n_list):
        # Plain dict on purpose: `Counter` in this module is the vLLM metrics
        # class, not `collections.Counter`.
        completion_counts: dict[str, int] = {}
        # Assert correct number of completions
        assert len(out.outputs) == n, f"{len(out.outputs)} completions; {n} expected."
        for idx, comp in enumerate(out.outputs):
            # Assert correct completion indices
            assert comp.index == idx, f"Index {comp.index}; expected {idx}."
            text = comp.text
            completion_counts[text] = completion_counts.get(text, 0) + 1
        # Assert unique completions
        if len(completion_counts) != n:
            repeats = {txt: num for (txt, num) in completion_counts.items() if num > 1}
            raise AssertionError(
                f"{len(completion_counts)} unique completions; expected"
                f" {n}. Repeats: {repeats}"
            )
153
154
155
156
157
158
159
160
161
162
163
164
165


def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
    """Test passes if `LLM.get_metrics()` is consistent with a finished run.

    The engine is built with ngram speculative decoding and stat logging
    enabled, then generates greedily (`temperature=0.0`) with a fixed
    `max_tokens`. The test checks the running-request gauge, the
    generation-token counter and histogram, and the per-position
    accepted-token vector.

    Args:
      vllm_runner: runner class fixture used to build the engine.
      monkeypatch: used to set `VLLM_USE_V1=1`.
      example_prompts: test fixture providing prompts for testing.
    """
    max_tokens = 100
    # Shared by the speculative config and the vector-length assertion below.
    num_speculative_tokens = 5
    # Use spec decoding to test num_accepted_tokens_per_pos
    speculative_config = {
        "method": "ngram",
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 3,
        "num_speculative_tokens": num_speculative_tokens,
    }
    monkeypatch.setenv("VLLM_USE_V1", "1")
    with vllm_runner(
        MODEL,
        speculative_config=speculative_config,
        disable_log_stats=False,
    ) as vllm_model:
        llm: LLM = vllm_model.llm
        sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = llm.generate(example_prompts, sampling_params)

        n_prompts = len(example_prompts)
        assert len(outputs) == n_prompts

        total_tokens = 0
        for out in outputs:
            assert len(out.outputs) == 1
            total_tokens += len(out.outputs[0].token_ids)
        # Every request is expected to produce exactly `max_tokens` tokens.
        assert total_tokens == max_tokens * n_prompts

        metrics = llm.get_metrics()

        def find_metric(name: str) -> list[Metric]:
            """Return every collected metric whose name equals `name`."""
            return [metric for metric in metrics if metric.name == name]

        num_requests_running = find_metric("vllm:num_requests_running")
        assert len(num_requests_running) == 1
        assert isinstance(num_requests_running[0], Gauge)
        assert num_requests_running[0].value == 0.0

        generation_tokens = find_metric("vllm:generation_tokens")
        assert len(generation_tokens) == 1
        assert isinstance(generation_tokens[0], Counter)
        assert generation_tokens[0].value == total_tokens

        request_generation_tokens = find_metric("vllm:request_generation_tokens")
        assert len(request_generation_tokens) == 1
        assert isinstance(request_generation_tokens[0], Histogram)
        assert "+Inf" in request_generation_tokens[0].buckets
        assert request_generation_tokens[0].buckets["+Inf"] == n_prompts
        assert request_generation_tokens[0].count == n_prompts
        assert request_generation_tokens[0].sum == total_tokens

        num_accepted_tokens_per_pos = find_metric(
            "vllm:spec_decode_num_accepted_tokens_per_pos"
        )
        assert len(num_accepted_tokens_per_pos) == 1
        assert isinstance(num_accepted_tokens_per_pos[0], Vector)
        # Same value (5) the original asserted, now tied to the config above.
        assert len(num_accepted_tokens_per_pos[0].values) == num_speculative_tokens
216
217
218


@pytest.mark.parametrize("model", ["meta-llama/Llama-3.2-1B-Instruct"])
def test_skip_tokenizer_initialization(
    model: str, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Check that `skip_tokenizer_init` skips tokenizer/detokenizer setup.

    With the flag set, text prompts must be rejected with a `ValueError`,
    and a token-id prompt must produce output that contains token ids but
    empty text.

    Args:
      model: model name supplied by the parametrization.
      monkeypatch: used to set `VLLM_USE_V1=1`.
    """
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(
        model=model,
        skip_tokenizer_init=True,
        enforce_eager=True,
    )
    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)

    # Text prompts cannot be used when tokenizer initialization is skipped.
    with pytest.raises(ValueError, match="cannot pass text prompts when"):
        llm.generate("abc", sampling_params)

    # Pre-tokenized prompts are still accepted.
    outputs = llm.generate(
        {"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params
    )
    assert len(outputs) > 0
    completions = outputs[0].outputs
    assert len(completions) > 0
    # The completion text is empty, but token ids are still returned.
    assert completions[0].text == ""
    assert completions[0].token_ids