"vscode:/vscode.git/clone" did not exist on "8958217ad5a6830c4d911e5f15e6eb791df337b6"
test_stop_strings.py 4.12 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional

5
import os
6
7
8
import pytest

from vllm import LLM, SamplingParams, envs
9
from ..utils import models_path_prefix
10

11
# Path of the model loaded by test_stop_strings, located under the shared
# models_path_prefix directory.
MODEL = os.path.join(models_path_prefix, "meta-llama/llama-2-7b-hf")
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Value passed as SamplingParams.max_tokens for every generation request
# issued by _test_stopping.
MAX_TOKENS = 200


def _test_stopping(llm: LLM,
                   expected_output: str,
                   expected_reason: Any,
                   stop: Optional[list[str]] = None,
                   stop_token_ids: Optional[list[int]] = None,
                   include_in_output: bool = False) -> None:
    """Run one generation with the given stop settings and verify it.

    A fixed prompt is generated with temperature 0.0 and at most
    MAX_TOKENS tokens. The first completion of the first returned
    request output is then checked against the expected text and the
    expected stop reason.

    Args:
        llm: Engine used to generate the completion.
        expected_output: Text the completion must equal exactly.
        expected_reason: Value the completion's ``stop_reason`` must equal.
        stop: Forwarded as ``SamplingParams.stop``.
        stop_token_ids: Forwarded as ``SamplingParams.stop_token_ids``.
        include_in_output: Forwarded as
            ``SamplingParams.include_stop_str_in_output``.

    Raises:
        AssertionError: If the completion is missing, or its text or stop
            reason differs from the expected values.
    """
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=MAX_TOKENS,
        stop=stop,
        stop_token_ids=stop_token_ids,
        include_stop_str_in_output=include_in_output,
    )
    request_outputs = llm.generate("A story about vLLM:\n", sampling_params)
    completion = request_outputs[0].outputs[0]

    assert completion is not None
    assert completion.text == expected_output
    assert completion.stop_reason == expected_reason


def _set_async_mode(llm, is_async):
    llm.llm_engine.scheduler[0].use_async_output_proc = is_async


def _stop_basic(llm):
    """Check stopping on the stop string "." with and without keeping it.

    The first case expects the "." to be absent from the output text and
    the second expects it to be present; both expect "." as stop reason.
    """
    cases = (
        (False, "VLLM is a 100% volunteer organization"),
        (True, "VLLM is a 100% volunteer organization."),
    )
    for keep_stop, expected_text in cases:
        _test_stopping(llm,
                       expected_output=expected_text,
                       expected_reason=".",
                       stop=["."],
                       include_in_output=keep_stop)


def _stop_multi_tokens(llm):
    """Check stopping with the stop strings "group of peo" and "short".

    Both cases expect "group of peo" to be reported as the stop reason;
    they differ only in whether that string is kept in the output text.
    """
    cases = (
        (False, "VLLM is a 100% volunteer organization. We are a "),
        (True,
         "VLLM is a 100% volunteer organization. We are a group of peo"),
    )
    for keep_stop, expected_text in cases:
        _test_stopping(llm,
                       expected_output=expected_text,
                       expected_reason="group of peo",
                       stop=["group of peo", "short"],
                       include_in_output=keep_stop)


def _stop_partial_token(llm):
    """Check stopping on "gani", which falls inside the word "organization".

    Without the stop string the expected text ends at "or"; with it the
    expected text ends at "organi". Both cases expect "gani" as the stop
    reason.
    """
    cases = (
        (False, "VLLM is a 100% volunteer or"),
        (True, "VLLM is a 100% volunteer organi"),
    )
    for keep_stop, expected_text in cases:
        _test_stopping(llm,
                       expected_output=expected_text,
                       expected_reason="gani",
                       stop=["gani"],
                       include_in_output=keep_stop)


def _stop_token_id(llm):
    """Check stopping on a stop token id with and without keeping it.

    Both cases expect the integer token id itself as the stop reason;
    they differ only in whether the token's text is kept in the output.
    """
    # token id 13013 => " organization"
    stop_id = 13013
    cases = (
        (False, "VLLM is a 100% volunteer"),
        (True, "VLLM is a 100% volunteer organization"),
    )
    for keep_stop, expected_text in cases:
        _test_stopping(llm,
                       expected_output=expected_text,
                       expected_reason=stop_id,
                       stop_token_ids=[stop_id],
                       include_in_output=keep_stop)


@pytest.mark.skip_global_cleanup
def test_stop_strings():
    """Run every stop-condition scenario against a single LLM instance.

    When ``envs.VLLM_USE_V1`` is truthy, the basic, multi-token and
    partial-token scenarios each run once and the stop-token-id scenario
    is skipped. Otherwise every scenario, including stop-token-id, runs
    twice: first with async output processing switched on, then with it
    switched off.
    """
    # If V0, must set enforce_eager=False since we use
    # async output processing below.
    vllm_model = LLM(MODEL, enforce_eager=envs.VLLM_USE_V1)

    string_scenarios = (_stop_basic, _stop_multi_tokens, _stop_partial_token)

    # The flag is read again here, after the engine has been built, just as
    # the original per-scenario checks did.
    if envs.VLLM_USE_V1:
        for scenario in string_scenarios:
            scenario(vllm_model)
        # FIXME: this does not respect include_in_output=False
        # _stop_token_id(vllm_model)
        return

    for scenario in (*string_scenarios, _stop_token_id):
        # Exercise each scenario in async mode first, then in sync mode.
        for is_async in (True, False):
            _set_async_mode(vllm_model, is_async)
            scenario(vllm_model)