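# test_stop_strings.py
#
# End-to-end checks that stop strings and stop token ids end generation
# with the expected text and stop_reason, both with and without async
# output processing enabled on the engine.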
from typing import Any, List, Optional

import pytest

from vllm import CompletionOutput, LLMEngine, SamplingParams

# Model used by every test in this module. The expected outputs below are
# the greedy (temperature=0.0) completions this model produces.
MODEL = "meta-llama/llama-2-7b-hf"
# Upper bound on generated tokens. Every case expects a stop string or stop
# token to end generation before this limit is reached.
MAX_TOKENS = 200

# NOTE(review): not referenced in the visible file; each test toggles async
# output processing explicitly. Kept in case other modules import it.
IS_ASYNC = False

@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
    """Yield one runner for ``MODEL`` that is shared across the whole session.

    ``vllm_runner`` is presumably a fixture provided by a conftest (it is not
    defined in this file). It is entered as a context manager, so the runner
    is closed when pytest tears down this session-scoped fixture.
    """
    with vllm_runner(MODEL) as runner:
        yield runner


def _test_stopping(llm_engine: LLMEngine,
                   expected_output: str,
                   expected_reason: Any,
                   stop: Optional[List[str]] = None,
                   stop_token_ids: Optional[List[int]] = None,
                   include_in_output: bool = False,
                   use_async_output_proc: bool = False) -> None:
    """Generate greedily from a fixed prompt and check how generation stopped.

    Adds a single request (id ``"id"``) with the prompt
    ``"A story about vLLM:"`` followed by a newline. The request uses
    ``temperature=0.0`` and ``max_tokens=MAX_TOKENS``. The engine is then
    stepped until it reports no unfinished requests.

    Args:
        llm_engine: Engine to drive. It must have no other requests in
            flight, because every step is unpacked as exactly one request
            output holding exactly one completion.
        expected_output: Text the final completion must equal.
        expected_reason: Value the final ``stop_reason`` must equal. In this
            file that is the matched stop string or the stop token id.
        stop: Stop strings forwarded to ``SamplingParams``.
        stop_token_ids: Stop token ids forwarded to ``SamplingParams``.
        include_in_output: Forwarded as ``include_stop_str_in_output``.
        use_async_output_proc: Set this to match the engine's async output
            mode. When true, one extra step is taken before polling starts.

    Raises:
        AssertionError: If the text is ever rewritten rather than extended
            between steps, or if the final text or stop reason differ from
            the expected values.
        ValueError: If a step does not yield exactly one output.
    """
    # NOTE(review): the trailing positional None is presumably arrival_time;
    # confirm against LLMEngine.add_request.
    llm_engine.add_request(
        "id", "A story about vLLM:\n",
        SamplingParams(
            temperature=0.0,
            max_tokens=MAX_TOKENS,
            stop=stop,
            stop_token_ids=stop_token_ids,
            include_stop_str_in_output=include_in_output,
        ), None)

    output: Optional[CompletionOutput] = None
    output_text = ""
    stop_reason = None

    # With async output processing on, the first step's result is discarded.
    # Presumably its output is not yet available to unpack; TODO confirm.
    if use_async_output_proc:
        llm_engine.step()

    while llm_engine.has_unfinished_requests():
        (request_output, ) = llm_engine.step()
        (output, ) = request_output.outputs

        # Streaming must only append: earlier text is never revised, even
        # while a partially matched stop string is being held back.
        assert output.text.startswith(output_text)
        output_text = output.text
        stop_reason = output.stop_reason

    assert output is not None
    assert output_text == expected_output
    assert stop_reason == expected_reason


def _set_async_mode(llm_engine, is_async):
    llm_engine.scheduler[0].use_async_output_proc = is_async


def _stop_basic(llm_engine, is_async):
    """Stop on the single-character stop string ``"."``.

    Runs the case once with ``include_in_output`` off and once with it on.
    When off, the ``"."`` is trimmed from the text; when on, it is kept.
    ``stop_reason`` reports the matched string in both cases.
    """
    _test_stopping(llm_engine,
                   stop=["."],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer organization",
                   expected_reason=".",
                   use_async_output_proc=is_async)

    _test_stopping(llm_engine,
                   stop=["."],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organization.",
                   expected_reason=".",
                   use_async_output_proc=is_async)


def _stop_multi_tokens(llm_engine, is_async):
    """Stop on a multi-token stop string picked from several candidates.

    ``"group of peo"`` appears in the greedy completion before ``"short"``
    does, so it is the reported ``stop_reason``. The text ends just before
    the match when ``include_in_output`` is off, and just after it when on.
    """
    _test_stopping(
        llm_engine,
        stop=["group of peo", "short"],
        include_in_output=False,
        expected_output="VLLM is a 100% volunteer organization. We are a ",
        expected_reason="group of peo",
        use_async_output_proc=is_async)

    _test_stopping(
        llm_engine,
        stop=["group of peo", "short"],
        include_in_output=True,
        expected_output=
        "VLLM is a 100% volunteer organization. We are a group of peo",
        expected_reason="group of peo",
        use_async_output_proc=is_async)


def _stop_partial_token(llm_engine, is_async):
    """Stop on a string that matches inside a word (``"gani"``).

    The match falls in the middle of "organization", so the output is cut
    mid-word. The text ends at ``"or"`` when ``include_in_output`` is off,
    and at ``"organi"`` when it is on.
    """
    _test_stopping(llm_engine,
                   stop=["gani"],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer or",
                   expected_reason="gani",
                   use_async_output_proc=is_async)

    _test_stopping(llm_engine,
                   stop=["gani"],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organi",
                   expected_reason="gani",
                   use_async_output_proc=is_async)


def _stop_token_id(llm_engine, is_async):
    """Stop on a token id instead of a string.

    ``stop_reason`` is the integer token id. The stopping token's text is
    omitted when ``include_in_output`` is off and kept when it is on.
    """
    # token id 13013 => " organization" (id is specific to MODEL's tokenizer)
    _test_stopping(llm_engine,
                   stop_token_ids=[13013],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer",
                   expected_reason=13013,
                   use_async_output_proc=is_async)

    _test_stopping(llm_engine,
                   stop_token_ids=[13013],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organization",
                   expected_reason=13013,
                   use_async_output_proc=is_async)


@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
    """Check the ``"."`` stop-string cases with async output on, then off."""
    engine = vllm_model.model.llm_engine
    # Async first, then sync, on the same session-scoped engine.
    for is_async in (True, False):
        _set_async_mode(engine, is_async)
        _stop_basic(engine, is_async=is_async)


@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
    """Check the multi-token stop-string cases with async output on, then off."""
    engine = vllm_model.model.llm_engine
    # Async first, then sync, on the same session-scoped engine.
    for is_async in (True, False):
        _set_async_mode(engine, is_async)
        _stop_multi_tokens(engine, is_async=is_async)


@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
    """Check the mid-word stop-string cases with async output on, then off."""
    engine = vllm_model.model.llm_engine
    # Async first, then sync, on the same session-scoped engine.
    for is_async in (True, False):
        _set_async_mode(engine, is_async)
        _stop_partial_token(engine, is_async=is_async)


@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
    """Check the stop-token-id cases with async output on, then off."""
    engine = vllm_model.model.llm_engine
    # Async first, then sync, on the same session-scoped engine.
    for is_async in (True, False):
        _set_async_mode(engine, is_async)
        _stop_token_id(engine, is_async=is_async)