# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional

import pytest

from vllm import CompletionOutput, LLMEngine, SamplingParams

# Model loaded once per session by the ``vllm_model`` fixture below.
MODEL = "meta-llama/llama-2-7b-hf"
# Generation cap for every request, passed as ``SamplingParams.max_tokens``.
MAX_TOKENS = 200

# NOTE(review): not referenced anywhere in this module as shown; the tests
# toggle async output processing explicitly via ``_set_async_mode``.
# Confirm whether this constant is still needed.
IS_ASYNC = False


@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
    """Yield a ``vllm_runner`` instance for ``MODEL``, shared for the session.

    The runner's context manager is entered on first use. It is exited when
    pytest resumes this generator at session teardown, so every test in the
    session reuses the same loaded engine.
    """
    with vllm_runner(MODEL) as runner:
        yield runner


def _test_stopping(llm_engine: LLMEngine,
                   expected_output: str,
                   expected_reason: Any,
                   stop: Optional[list[str]] = None,
                   stop_token_ids: Optional[list[int]] = None,
                   include_in_output: bool = False,
                   use_async_output_proc: bool = False) -> None:
    """Run one greedy request and check where and why generation stopped.

    Submits the prompt ``"A story about vLLM:\\n"`` with ``temperature=0.0``
    and ``max_tokens=MAX_TOKENS``. Steps the engine until no requests remain
    unfinished, then asserts on the final completion text and its
    ``stop_reason``.

    Args:
        llm_engine: Engine that is driven step by step.
        expected_output: Exact expected final ``CompletionOutput.text``.
        expected_reason: Expected ``CompletionOutput.stop_reason``. The
            callers in this module pass either the matched stop string or
            the matched stop token id.
        stop: Forwarded to ``SamplingParams.stop``.
        stop_token_ids: Forwarded to ``SamplingParams.stop_token_ids``.
        include_in_output: Forwarded to
            ``SamplingParams.include_stop_str_in_output``.
        use_async_output_proc: When True, one extra ``step()`` is performed
            before the polling loop and its result is discarded.

    Raises:
        AssertionError: If a step's text does not extend the previous step's
            text, or if the final text or stop reason differ from the
            expected values.
    """
    llm_engine.add_request(
        "id", "A story about vLLM:\n",
        SamplingParams(
            temperature=0.0,
            max_tokens=MAX_TOKENS,
            stop=stop,
            stop_token_ids=stop_token_ids,
            include_stop_str_in_output=include_in_output,
        ),
        # NOTE(review): passed positionally; presumably arrival_time.
        # Confirm against the LLMEngine.add_request signature.
        None)

    output: Optional[CompletionOutput] = None
    output_text = ""
    stop_reason = None

    if use_async_output_proc:
        # NOTE(review): presumably, with async output processing the first
        # step produces no RequestOutput yet, because outputs lag by one step.
        # It is drained here so the single-element unpack below does not
        # fail. Confirm against LLMEngine.step.
        llm_engine.step()

    while llm_engine.has_unfinished_requests():
        # Only one request is added above, so each step must yield exactly
        # one RequestOutput holding exactly one completion. The tuple unpacks
        # enforce this.
        (request_output, ) = llm_engine.step()
        (output, ) = request_output.outputs

        # Ensure we don't backtrack: streamed text may only grow.
        assert output.text.startswith(output_text)
        output_text = output.text
        stop_reason = output.stop_reason

    assert output is not None
    assert output_text == expected_output
    assert stop_reason == expected_reason


def _set_async_mode(llm_engine, is_async):
    """Set ``use_async_output_proc`` on the engine's first scheduler.

    Only ``llm_engine.scheduler[0]`` is updated. Any other schedulers in the
    list are left unchanged.
    """
    primary_scheduler = llm_engine.scheduler[0]
    primary_scheduler.use_async_output_proc = is_async


def _stop_basic(llm_engine, is_async):
    """Stop on the single-character stop string ``"."``.

    Runs the request twice: first with the stop string stripped from the
    output, then with it kept (``include_in_output=True``). In both runs
    ``stop_reason`` is the matched string.
    """
    _test_stopping(llm_engine,
                   stop=["."],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer organization",
                   expected_reason=".",
                   use_async_output_proc=is_async)

    _test_stopping(llm_engine,
                   stop=["."],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organization.",
                   expected_reason=".",
                   use_async_output_proc=is_async)


def _stop_multi_tokens(llm_engine, is_async):
    """Stop on a stop string that spans several generated tokens.

    Two stop strings are supplied. The expected ``stop_reason`` is
    ``"group of peo"``, so generation must halt on that match and not on
    ``"short"``. Runs with the stop string excluded and then included.
    """
    _test_stopping(
        llm_engine,
        stop=["group of peo", "short"],
        include_in_output=False,
        expected_output="VLLM is a 100% volunteer organization. We are a ",
        expected_reason="group of peo",
        use_async_output_proc=is_async)

    _test_stopping(
        llm_engine,
        stop=["group of peo", "short"],
        include_in_output=True,
        expected_output=
        "VLLM is a 100% volunteer organization. We are a group of peo",
        expected_reason="group of peo",
        use_async_output_proc=is_async)


def _stop_partial_token(llm_engine, is_async):
    """Stop on a stop string that ends in the middle of a token.

    ``"gani"`` lies inside the word "organization" (a single token, per the
    token-id note in ``_stop_token_id``). The output must therefore be cut
    mid-token. Runs with the stop string excluded and then included.
    """
    _test_stopping(llm_engine,
                   stop=["gani"],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer or",
                   expected_reason="gani",
                   use_async_output_proc=is_async)

    _test_stopping(llm_engine,
                   stop=["gani"],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organi",
                   expected_reason="gani",
                   use_async_output_proc=is_async)


def _stop_token_id(llm_engine, is_async):
    """Stop on a stop token id instead of a stop string.

    ``stop_reason`` is the integer token id. With ``include_in_output=True``
    the decoded text of the stop token itself is kept in the output.
    """
    # token id 13013 => " organization"

    _test_stopping(llm_engine,
                   stop_token_ids=[13013],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer",
                   expected_reason=13013,
                   use_async_output_proc=is_async)

    _test_stopping(llm_engine,
                   stop_token_ids=[13013],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organization",
                   expected_reason=13013,
                   use_async_output_proc=is_async)


@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
    """Run the basic stop-string checks with async output on, then off."""
    llm_engine = vllm_model.model.llm_engine
    # Order matters only for readability of failures: async first, then sync.
    for is_async in (True, False):
        _set_async_mode(llm_engine, is_async)
        _stop_basic(llm_engine, is_async=is_async)


@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
    """Run the multi-token stop-string checks with async output on, then off."""
    llm_engine = vllm_model.model.llm_engine
    for is_async in (True, False):
        _set_async_mode(llm_engine, is_async)
        _stop_multi_tokens(llm_engine, is_async=is_async)


@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
    """Run the mid-token stop-string checks with async output on, then off."""
    for async_enabled in (True, False):
        engine = vllm_model.model.llm_engine
        _set_async_mode(engine, async_enabled)
        _stop_partial_token(vllm_model.model.llm_engine,
                            is_async=async_enabled)


@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
    """Run the stop-token-id checks with async output on, then off."""
    for async_enabled in (True, False):
        engine = vllm_model.model.llm_engine
        _set_async_mode(engine, async_enabled)
        _stop_token_id(vllm_model.model.llm_engine, is_async=async_enabled)