# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Optional

import pytest

from vllm import LLM, SamplingParams

# Model identifier handed to ``LLM(...)`` by the test in this file.
# NOTE(review): the hard-coded expected texts in the helpers below were
# presumably recorded from this exact model -- re-check them if it changes.
MODEL = "meta-llama/llama-2-7b-hf"

# Passed as ``max_tokens`` to ``SamplingParams`` for every generation here.
MAX_TOKENS = 200


def _test_stopping(
    llm: LLM,
    expected_output: str,
    expected_reason: Any,
    stop: Optional[list[str]] = None,
    stop_token_ids: Optional[list[int]] = None,
    include_in_output: bool = False,
) -> None:
    """Generate from a fixed prompt and check the text and stop reason.

    Args:
        llm: Engine used to run the generation.
        expected_output: Exact text the first completion must produce.
        expected_reason: Expected ``stop_reason`` of the first completion;
            callers in this file pass the matched stop string or the
            matched stop token id.
        stop: Stop strings forwarded to ``SamplingParams(stop=...)``.
        stop_token_ids: Stop token ids forwarded to
            ``SamplingParams(stop_token_ids=...)``.
        include_in_output: Forwarded as ``include_stop_str_in_output``.

    Raises:
        AssertionError: If the generated text or the stop reason differs
            from the expected values.
    """
    output = llm.generate(
        "A story about vLLM:\n",
        SamplingParams(
            # temperature=0.0 selects greedy sampling in vLLM, which the
            # exact-text assertions below depend on.
            temperature=0.0,
            max_tokens=MAX_TOKENS,
            stop=stop,
            stop_token_ids=stop_token_ids,
            include_stop_str_in_output=include_in_output,
        ),
    )[0].outputs[0]  # first (only) prompt, first completion

    assert output is not None
    assert output.text == expected_output
    assert output.stop_reason == expected_reason


def _stop_basic(llm) -> None:
    """Check stopping on the single-character stop string ``"."``.

    Runs the same generation twice, once with the stop string excluded
    from the returned text and once with it included.

    Args:
        llm: Engine passed through to ``_test_stopping``.
    """
    # Stop string is stripped from the output.
    _test_stopping(
        llm,
        stop=["."],
        include_in_output=False,
        expected_output="VLLM is a 100% volunteer organization",
        expected_reason=".",
    )

    # Stop string is kept at the end of the output.
    _test_stopping(
        llm,
        stop=["."],
        include_in_output=True,
        expected_output="VLLM is a 100% volunteer organization.",
        expected_reason=".",
    )


def _stop_multi_tokens(llm) -> None:
    """Check stopping with several multi-character stop strings.

    Two stop strings are supplied; the expected stop reason in both runs
    is ``"group of peo"``. The generation is checked with the stop string
    excluded from and then included in the returned text.

    Args:
        llm: Engine passed through to ``_test_stopping``.
    """
    # Stop string is stripped; the text ends right before the match.
    _test_stopping(
        llm,
        stop=["group of peo", "short"],
        include_in_output=False,
        expected_output="VLLM is a 100% volunteer organization. We are a ",
        expected_reason="group of peo",
    )

    # Stop string is kept; the text ends exactly at the end of the match.
    _test_stopping(
        llm,
        stop=["group of peo", "short"],
        include_in_output=True,
        expected_output="VLLM is a 100% volunteer organization. We are a group of peo",
        expected_reason="group of peo",
    )


def _stop_partial_token(llm) -> None:
    """Check stopping on ``"gani"``, a substring in the middle of a word.

    With the stop string excluded the expected text ends in ``"or"``;
    with it included the expected text ends in ``"organi"``.

    Args:
        llm: Engine passed through to ``_test_stopping``.
    """
    # Stop string is stripped, leaving only the prefix of the word.
    _test_stopping(
        llm,
        stop=["gani"],
        include_in_output=False,
        expected_output="VLLM is a 100% volunteer or",
        expected_reason="gani",
    )

    # Stop string is kept; nothing after the match is returned.
    _test_stopping(
        llm,
        stop=["gani"],
        include_in_output=True,
        expected_output="VLLM is a 100% volunteer organi",
        expected_reason="gani",
    )


def _stop_token_id(llm) -> None:
    """Check stopping on a stop token id instead of a stop string.

    The expected stop reason is the integer token id itself. The call to
    this helper in ``test_stop_strings`` is currently commented out with
    a FIXME.

    Args:
        llm: Engine passed through to ``_test_stopping``.
    """
    # token id 13013 => " organization"

    # Stop token is expected to be left out of the returned text.
    _test_stopping(
        llm,
        stop_token_ids=[13013],
        include_in_output=False,
        expected_output="VLLM is a 100% volunteer",
        expected_reason=13013,
    )

    # Stop token is expected to be included in the returned text.
    _test_stopping(
        llm,
        stop_token_ids=[13013],
        include_in_output=True,
        expected_output="VLLM is a 100% volunteer organization",
        expected_reason=13013,
    )


@pytest.mark.skip_global_cleanup
def test_stop_strings():
    """Run the stop-string scenarios against one shared ``LLM`` instance.

    The engine is built once from ``MODEL`` with ``enforce_eager=True``
    and reused by every scenario helper.
    """
    llm = LLM(MODEL, enforce_eager=True)

    _stop_basic(llm)
    _stop_multi_tokens(llm)
    _stop_partial_token(llm)
    # FIXME: this does not respect include_in_output=False
    # _stop_token_id(llm)