# SPDX-License-Identifier: Apache-2.0

import random
from collections import Counter
from typing import Optional

import pytest

from vllm import LLM, SamplingParams

# Model name and dtype passed by `_vllm_model` to every runner it builds.
MODEL = "facebook/opt-125m"
DTYPE = "half"


def _vllm_model(apc: bool, vllm_runner, monkeypatch):
    """Set up VllmRunner instance.

    Forces the V1 engine via the ``VLLM_USE_V1`` environment variable, then
    builds a small, eager-mode runner for ``MODEL``.

    Args:
      apc: value forwarded as ``enable_prefix_caching`` (APC on/off).
      vllm_runner: runner factory fixture; called with the model name and
        engine keyword arguments.
      monkeypatch: pytest ``monkeypatch`` fixture, used so the env var change
        is undone after the test.

    Returns:
      Whatever ``vllm_runner(...)`` returns; the fixtures in this file use it
      as a context manager.
    """
    monkeypatch.setenv("VLLM_USE_V1", "1")
    return vllm_runner(
        MODEL,
        dtype=DTYPE,
        max_model_len=128,
        enforce_eager=True,
        enable_prefix_caching=apc,
        gpu_memory_utilization=0.5,
    )


@pytest.fixture(
    # Function scope decouples tests & allows
    # env var adjustment via monkeypatch
    scope="function",
    # Prefix caching: run every test once without and once with APC
    params=[False, True])
def vllm_model(vllm_runner, request, monkeypatch):
    """Yield a VllmRunner, parameterized over APC disabled/enabled."""
    apc_enabled = request.param
    runner_ctx = _vllm_model(apc_enabled, vllm_runner, monkeypatch)
    with runner_ctx as runner:
        yield runner


@pytest.fixture(scope="function")
def vllm_model_apc(vllm_runner, monkeypatch):
    """Yield a VllmRunner that always has APC (prefix caching) enabled."""
    with _vllm_model(apc=True,
                     vllm_runner=vllm_runner,
                     monkeypatch=monkeypatch) as runner:
        yield runner


def _get_test_sampling_params(
    prompt_list: list[str],
    seed: Optional[int] = 42,
) -> tuple[list[SamplingParams], list[int]]:
    """Generate random sampling params for a batch.

    Args:
      prompt_list: prompts in the batch. Only its length is used; one
        ``SamplingParams`` is built per prompt.
      seed: seed forwarded to every ``SamplingParams``. The choice of ``n``
        itself is drawn from the global ``random`` module and is NOT
        controlled by this argument.

    Returns:
      ``(sampling_params_list, n_list)``, where ``n_list[i]`` is the ``n``
      used for ``sampling_params_list[i]``.
    """

    def get_mostly_n_gt1() -> int:
        r"""Mostly n \in [2,20], ~1/3 n=1"""
        # x is uniform over the 29 integers 0..28. The 10 values below 10 map
        # to n=1 (10/29, roughly 1/3); the rest map to n = x - 8 in [2, 20].
        x = random.randint(0, 28)
        if x < 10:
            return 1
        else:
            return x - 8

    n_list = [get_mostly_n_gt1() for _ in prompt_list]
    # High temperature to maximize the chance of unique completions
    return [
        SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
        for n in n_list
    ], n_list


def test_parallel_sampling(vllm_model, example_prompts) -> None:
    """Test passes if parallel sampling `n>1` yields `n` unique completions.

    Args:
      vllm_model: VllmRunner instance under test.
      example_prompts: test fixture providing prompts for testing.
    """
    sampling_params_list, n_list = _get_test_sampling_params(example_prompts)
    model: LLM = vllm_model.model
    outputs = model.generate(example_prompts, sampling_params_list)

    # zip() below would silently drop requests if fewer outputs came back
    # than were submitted, so check the counts match first.
    assert len(outputs) == len(n_list), (
        f"{len(outputs)} request outputs; {len(n_list)} expected.")

    # Validate each request response
    for out, n in zip(outputs, n_list):
        # Assert correct number of completions
        assert len(out.outputs) == n, (
            f"{len(out.outputs)} completions; {n} expected.")
        for idx, comp in enumerate(out.outputs):
            # Assert correct completion indices
            assert comp.index == idx, (f"Index {comp.index}; expected {idx}.")
        # Assert unique completions
        completion_counts = Counter(comp.text for comp in out.outputs)
        if len(completion_counts) != n:
            repeats = {
                txt: num
                for txt, num in completion_counts.items() if num > 1
            }
            raise AssertionError(
                f"{len(completion_counts)} unique completions; expected"
                f" {n}. Repeats: {repeats}")