"docs/vscode:/vscode.git/clone" did not exist on "657061fdced8a33a60c1b09f5da2525de9da8f03"
test_stop_strings.py 5.25 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
from typing import Any, List, Optional

import pytest

from vllm import CompletionOutput, LLMEngine, SamplingParams

# Model identifier handed to ``vllm_runner`` by the ``vllm_model`` fixture.
MODEL = "meta-llama/llama-2-7b-hf"
# Value passed as ``max_tokens`` in the ``SamplingParams`` built by
# ``_test_stopping``.
MAX_TOKENS = 200

12
13
# NOTE(review): not referenced in the visible part of this file -- the tests
# pass ``is_async`` explicitly. Confirm whether this flag is still needed.
IS_ASYNC = False

14
15
16

@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
    """Yield a ``vllm_runner(MODEL)`` instance shared across the session.

    The runner is opened as a context manager, so its exit handler runs once
    pytest finalizes this session-scoped fixture.
    """
    with vllm_runner(MODEL) as runner:
        yield runner
19
20


21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def _test_stopping(llm_engine: LLMEngine,
                   expected_output: str,
                   expected_reason: Any,
                   stop: Optional[List[str]] = None,
                   stop_token_ids: Optional[List[int]] = None,
                   include_in_output: bool = False,
                   use_async_output_proc: bool = False) -> None:
    """Run one request to completion and check its text and stop reason.

    A single request with id ``"id"`` and a fixed prompt is added to
    ``llm_engine`` (temperature 0.0, at most ``MAX_TOKENS`` tokens), then the
    engine is stepped until it reports no unfinished requests.

    Args:
        llm_engine: Engine to drive.
        expected_output: Text the final output must equal.
        expected_reason: Value the final ``stop_reason`` must equal.
        stop: Stop strings forwarded to ``SamplingParams``.
        stop_token_ids: Stop token ids forwarded to ``SamplingParams``.
        include_in_output: Forwarded as ``include_stop_str_in_output``.
        use_async_output_proc: When true, one extra ``step()`` is issued (and
            its result ignored) before the checking loop starts.

    Raises:
        AssertionError: If the text ever stops being an extension of the text
            seen on the previous step, if no output was produced, or if the
            final text / stop reason differ from the expected values.
    """
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=MAX_TOKENS,
        stop=stop,
        stop_token_ids=stop_token_ids,
        include_stop_str_in_output=include_in_output,
    )
    llm_engine.add_request("id", "A story about vLLM:\n", sampling_params,
                           None)

    latest: Optional[CompletionOutput] = None
    seen_text = ""
    seen_reason = None

    # NOTE(review): presumably the first step yields no request output when
    # async output processing is on, hence the discarded step -- confirm.
    if use_async_output_proc:
        llm_engine.step()

    while llm_engine.has_unfinished_requests():
        # Exactly one request is in flight, with exactly one completion.
        [step_result] = llm_engine.step()
        [latest] = step_result.outputs

        # Text already emitted must never be retracted on a later step.
        assert latest.text.startswith(seen_text)
        seen_text, seen_reason = latest.text, latest.stop_reason

    assert latest is not None
    assert seen_text == expected_output
    assert seen_reason == expected_reason


def _set_async_mode(llm_engine, is_async):
    llm_engine.scheduler[0].use_async_output_proc = is_async


def _stop_basic(llm_engine, is_async):
    """Check stopping on the single stop string ``"."``.

    Runs ``_test_stopping`` twice: once expecting the stop string to be
    stripped from the output, once expecting it to be kept.

    Args:
        llm_engine: Engine passed through to ``_test_stopping``.
        is_async: Forwarded as ``use_async_output_proc``.
    """
    # Stop string excluded from the returned text.
    _test_stopping(llm_engine,
                   stop=["."],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer organization",
                   expected_reason=".",
                   use_async_output_proc=is_async)

    # Stop string kept at the end of the returned text.
    _test_stopping(llm_engine,
                   stop=["."],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organization.",
                   expected_reason=".",
                   use_async_output_proc=is_async)
77
78


79
def _stop_multi_tokens(llm_engine, is_async):
    """Check stopping with two multi-character stop strings.

    Both calls pass ``["group of peo", "short"]`` and expect
    ``"group of peo"`` to be the reported stop reason; they differ only in
    whether the stop string is kept in the output.

    Args:
        llm_engine: Engine passed through to ``_test_stopping``.
        is_async: Forwarded as ``use_async_output_proc``.
    """
    # Stop string excluded: the text ends right before "group of peo".
    _test_stopping(
        llm_engine,
        stop=["group of peo", "short"],
        include_in_output=False,
        expected_output="VLLM is a 100% volunteer organization. We are a ",
        expected_reason="group of peo",
        use_async_output_proc=is_async)

    # Stop string included: the text ends with "group of peo".
    _test_stopping(
        llm_engine,
        stop=["group of peo", "short"],
        include_in_output=True,
        expected_output=
        "VLLM is a 100% volunteer organization. We are a group of peo",
        expected_reason="group of peo",
        use_async_output_proc=is_async)
96
97


98
99
def _stop_partial_token(llm_engine, is_async):
    """Check stopping on ``"gani"``, which matches in the middle of a word.

    The expected outputs end mid-word ("... volunteer or" without the stop
    string, "... volunteer organi" with it).

    Args:
        llm_engine: Engine passed through to ``_test_stopping``.
        is_async: Forwarded as ``use_async_output_proc``.
    """
    # Stop string excluded from the returned text.
    _test_stopping(llm_engine,
                   stop=["gani"],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer or",
                   expected_reason="gani",
                   use_async_output_proc=is_async)

    # Stop string kept at the end of the returned text.
    _test_stopping(llm_engine,
                   stop=["gani"],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organi",
                   expected_reason="gani",
                   use_async_output_proc=is_async)
112
113


114
def _stop_token_id(llm_engine, is_async):
    """Check stopping on a stop token id instead of a stop string.

    Both calls pass ``stop_token_ids=[13013]`` and expect the integer 13013
    as the stop reason; they differ only in ``include_in_output``.

    Args:
        llm_engine: Engine passed through to ``_test_stopping``.
        is_async: Forwarded as ``use_async_output_proc``.
    """
    # token id 13013 => " organization"

    # Expected text ends before " organization".
    _test_stopping(llm_engine,
                   stop_token_ids=[13013],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer",
                   expected_reason=13013,
                   use_async_output_proc=is_async)

    # Expected text ends with " organization".
    _test_stopping(llm_engine,
                   stop_token_ids=[13013],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organization",
                   expected_reason=13013,
                   use_async_output_proc=is_async)
130
131


132
133
134
135
@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
    """Run the ``_stop_basic`` checks with async output processing on, then off."""
    engine = vllm_model.model.llm_engine

    # Async mode first, then sync, on the same shared engine.
    for is_async in (True, False):
        _set_async_mode(engine, is_async)
        _stop_basic(engine, is_async=is_async)
139
140


141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
    """Run the ``_stop_multi_tokens`` checks with async output processing on, then off."""
    engine = vllm_model.model.llm_engine

    # Async mode first, then sync, on the same shared engine.
    for is_async in (True, False):
        _set_async_mode(engine, is_async)
        _stop_multi_tokens(engine, is_async=is_async)


@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
    """Run the ``_stop_partial_token`` checks with async output processing on, then off."""
    engine = vllm_model.model.llm_engine

    # Async mode first, then sync, on the same shared engine.
    for is_async in (True, False):
        _set_async_mode(engine, is_async)
        _stop_partial_token(engine, is_async=is_async)


@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
    """Run the ``_stop_token_id`` checks with async output processing on, then off."""
    engine = vllm_model.model.llm_engine

    # Async mode first, then sync, on the same shared engine.
    for is_async in (True, False):
        _set_async_mode(engine, is_async)
        _stop_token_id(engine, is_async=is_async)