# SPDX-License-Identifier: Apache-2.0

3
4
5
from typing import Any, List, Optional

import pytest
6
import os
7
8

from vllm import CompletionOutput, LLMEngine, SamplingParams
9
from ..utils import models_path_prefix
10

11
MODEL = os.path.join(models_path_prefix, "meta-llama/llama-2-7b-hf")
12
13
MAX_TOKENS = 200

14
15
IS_ASYNC = False

16
17
18

@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
    """Yield a runner for ``MODEL`` that is shared across the test session.

    ``vllm_runner(MODEL)`` is entered as a context manager; the yielded
    object is what the tests receive, and the context is exited when the
    session-scoped fixture is torn down.
    """
    with vllm_runner(MODEL) as runner:
        yield runner
21
22


23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def _test_stopping(llm_engine: LLMEngine,
                   expected_output: str,
                   expected_reason: Any,
                   stop: Optional[List[str]] = None,
                   stop_token_ids: Optional[List[int]] = None,
                   include_in_output: bool = False,
                   use_async_output_proc: bool = False) -> None:
    """Run one request to completion and verify how generation stopped.

    A single request with the fixed prompt ``"A story about vLLM:\\n"`` is
    submitted at temperature 0.0 with at most ``MAX_TOKENS`` tokens, then
    the engine is stepped until no unfinished requests remain.

    Args:
        llm_engine: Engine that receives the request and is stepped.
        expected_output: Exact text the last observed output must equal.
        expected_reason: Value the last observed ``stop_reason`` must equal.
        stop: Stop strings forwarded to ``SamplingParams``.
        stop_token_ids: Stop token ids forwarded to ``SamplingParams``.
        include_in_output: Forwarded as ``include_stop_str_in_output``.
        use_async_output_proc: When true, the engine is stepped once before
            the polling loop and that step's return value is discarded.

    Raises:
        AssertionError: If the text ever stops being an extension of the
            previously seen text, if no output was produced, or if the
            final text / stop reason differ from the expected values.
    """
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=MAX_TOKENS,
        stop=stop,
        stop_token_ids=stop_token_ids,
        include_stop_str_in_output=include_in_output,
    )
    llm_engine.add_request(
        "id",
        "A story about vLLM:\n",
        sampling_params,
        None,
    )

    if use_async_output_proc:
        # NOTE(review): presumably the first step yields no outputs when
        # async output processing is on -- confirm against the engine.
        llm_engine.step()

    latest: Optional[CompletionOutput] = None
    text_so_far = ""
    last_stop_reason = None

    while llm_engine.has_unfinished_requests():
        # Exactly one request is in flight and it has exactly one output;
        # the single-element unpacking fails loudly otherwise.
        [request_output] = llm_engine.step()
        [latest] = request_output.outputs

        # Text may only grow: each step must extend what was seen before.
        assert latest.text.startswith(text_so_far)
        text_so_far, last_stop_reason = latest.text, latest.stop_reason

    assert latest is not None
    assert text_so_far == expected_output
    assert last_stop_reason == expected_reason


def _set_async_mode(llm_engine, is_async):
    llm_engine.scheduler[0].use_async_output_proc = is_async


def _stop_basic(llm_engine, is_async):
    """Check stopping on the single-character stop string ``"."``.

    Runs ``_test_stopping`` twice: once with the stop string excluded from
    the output and once with it included. ``is_async`` is forwarded as
    ``use_async_output_proc``.
    """
    # Stop string excluded: the expected text ends just before the ".".
    _test_stopping(llm_engine,
                   stop=["."],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer organization",
                   expected_reason=".",
                   use_async_output_proc=is_async)

    # Stop string included: the expected text ends with the ".".
    _test_stopping(llm_engine,
                   stop=["."],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organization.",
                   expected_reason=".",
                   use_async_output_proc=is_async)
79
80


81
def _stop_multi_tokens(llm_engine, is_async):
    """Check stopping with two candidate stop strings.

    Both runs pass ``["group of peo", "short"]`` and expect the reported
    stop reason to be ``"group of peo"``; they differ only in whether the
    stop string is kept in the output. ``is_async`` is forwarded as
    ``use_async_output_proc``.
    """
    # Stop string excluded: the expected text ends right before it.
    _test_stopping(
        llm_engine,
        stop=["group of peo", "short"],
        include_in_output=False,
        expected_output="VLLM is a 100% volunteer organization. We are a ",
        expected_reason="group of peo",
        use_async_output_proc=is_async)

    # Stop string included: the expected text ends with it.
    _test_stopping(
        llm_engine,
        stop=["group of peo", "short"],
        include_in_output=True,
        expected_output=
        "VLLM is a 100% volunteer organization. We are a group of peo",
        expected_reason="group of peo",
        use_async_output_proc=is_async)
98
99


100
101
def _stop_partial_token(llm_engine, is_async):
    """Check stopping on ``"gani"``, a fragment inside "organization".

    Runs ``_test_stopping`` with the stop string excluded and then included
    in the output. ``is_async`` is forwarded as ``use_async_output_proc``.
    """
    # Stop string excluded: the expected text is cut before "gani".
    _test_stopping(llm_engine,
                   stop=["gani"],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer or",
                   expected_reason="gani",
                   use_async_output_proc=is_async)

    # Stop string included: the expected text ends with "gani".
    _test_stopping(llm_engine,
                   stop=["gani"],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organi",
                   expected_reason="gani",
                   use_async_output_proc=is_async)
114
115


116
def _stop_token_id(llm_engine, is_async):
    """Check stopping on a stop token id instead of a stop string.

    Both runs pass ``stop_token_ids=[13013]`` and expect the stop reason to
    be the integer ``13013``; they differ only in ``include_in_output``.
    ``is_async`` is forwarded as ``use_async_output_proc``.
    """
    # token id 13013 => " organization"

    # Excluded from the output: the expected text ends before the token.
    _test_stopping(llm_engine,
                   stop_token_ids=[13013],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer",
                   expected_reason=13013,
                   use_async_output_proc=is_async)

    # Included in the output: the expected text ends with " organization".
    _test_stopping(llm_engine,
                   stop_token_ids=[13013],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organization",
                   expected_reason=13013,
                   use_async_output_proc=is_async)
132
133


134
135
136
137
@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
    """Run the basic stop-string checks in async mode, then in sync mode."""
    # Async output processing enabled on the first scheduler.
    _set_async_mode(vllm_model.model.llm_engine, True)
    _stop_basic(vllm_model.model.llm_engine, is_async=True)

    # Same checks with async output processing disabled.
    _set_async_mode(vllm_model.model.llm_engine, False)
    _stop_basic(vllm_model.model.llm_engine, is_async=False)
141
142


143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
    """Run the multi-token stop checks in async mode, then in sync mode."""
    engine = vllm_model.model.llm_engine
    for async_mode in (True, False):
        _set_async_mode(engine, async_mode)
        _stop_multi_tokens(engine, is_async=async_mode)


@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
    """Run the partial-token stop checks in async mode, then in sync mode."""
    engine = vllm_model.model.llm_engine
    for async_mode in (True, False):
        _set_async_mode(engine, async_mode)
        _stop_partial_token(engine, is_async=async_mode)


@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
    """Run the stop-token-id checks in async mode, then in sync mode."""
    engine = vllm_model.model.llm_engine
    for async_mode in (True, False):
        _set_async_mode(engine, async_mode)
        _stop_token_id(engine, is_async=async_mode)