# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import pytest

from vllm import LLM, SamplingParams
8
from ...utils import models_path_prefix
# This module only exercises the V1 engine: unless VLLM_USE_V1 is set to
# "1" in the environment, skip every test in the file at import time.
if os.environ.get("VLLM_USE_V1", "0") != "1":
    pytest.skip("Test package requires V1", allow_module_level=True)

# Model checkpoint path and shared prompt used by every test below.
MODEL = os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B")
PROMPT = "Hello my name is Robert and I"


@pytest.fixture(scope="module")
def llm() -> LLM:
    """Build the ``LLM`` for ``MODEL`` shared by all tests in this module.

    The fixture is module-scoped, so the model is loaded once and reused by
    every test that requests ``llm``. ``enforce_eager=True`` is passed to
    the constructor.
    """
    # Disable prefix caching so that we can test prompt logprobs.
    # TODO remove this after https://github.com/vllm-project/vllm/pull/13949
    # is merged
    return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False)


def test_n_gt_1(llm):
    """ParallelSampling is supported.

    Requesting ``n=3`` completions for one prompt must return exactly three
    completions on the single resulting request output.
    """
    params = SamplingParams(n=3)
    outputs = llm.generate(PROMPT, params)
    assert len(outputs[0].outputs) == 3


def test_best_of(llm):
    """Raise a ValueError since best_of is deprecated.

    Constructing ``SamplingParams`` with ``best_of`` must succeed; the
    rejection is expected when the request is submitted via ``generate``.
    """
    params = SamplingParams(n=2, best_of=3)
    with pytest.raises(ValueError):
        _ = llm.generate(PROMPT, params)


def test_penalties(llm):
    """Check that we do not get errors if applied.

    Smoke test: combining temperature, presence/frequency/repetition
    penalties and min_p/top_p/top_k filtering in one request must not raise.
    The generated text itself is not checked.
    """
    params = SamplingParams(
        temperature=1.2,
        presence_penalty=1.2,
        frequency_penalty=1.2,
        repetition_penalty=1.2,
        min_p=0.5,
        top_p=0.5,
        top_k=3,
    )
    _ = llm.generate(PROMPT, params)


def test_stop(llm):
    """Check that we respect the stop words.

    A temperature-0 baseline is generated first. Its whitespace-separated
    word at index ``STOP_IDX`` is then used as the stop string:

    * by default the stop string is dropped, leaving ``STOP_IDX`` words;
    * with ``include_stop_str_in_output=True`` it is kept, giving
      ``STOP_IDX + 1`` words.
    """
    output = llm.generate(PROMPT, SamplingParams(temperature=0))
    split_text = output[0].outputs[0].text.split()

    STOP_IDX = 5
    # Fail with a clear message instead of an IndexError if the baseline is
    # too short to pick a stop word from.
    assert len(split_text) > STOP_IDX, (
        f"baseline output too short for STOP_IDX={STOP_IDX}: {split_text!r}")
    stop_word = split_text[STOP_IDX]

    params = SamplingParams(temperature=0, stop=stop_word)
    output = llm.generate(PROMPT, params)
    new_split_text = output[0].outputs[0].text.split()

    # Output should not contain the stop word.
    assert len(new_split_text) == STOP_IDX

    params = SamplingParams(temperature=0,
                            stop=stop_word,
                            include_stop_str_in_output=True)
    output = llm.generate(PROMPT, params)
    new_split_text = output[0].outputs[0].text.split()

    # Output should contain the stop word.
    assert len(new_split_text) == STOP_IDX + 1


def test_stop_token_ids(llm):
    """Check that we respect the stop token ids.

    Tokens at positions 5 and 6 of a temperature-0 baseline are used as stop
    tokens. In both orderings of ``stop_token_ids``, the last generated
    token is expected to be the one from position 5. This also checks that
    the order of the list does not matter.
    """
    output = llm.generate(PROMPT, SamplingParams(temperature=0))
    baseline_ids = output[0].outputs[0].token_ids
    # Fail with a clear message instead of an IndexError if the baseline is
    # too short to pick both stop tokens from.
    assert len(baseline_ids) > 6, (
        f"baseline output too short to pick stop tokens: {baseline_ids!r}")

    stop_token_id_0 = baseline_ids[5]
    stop_token_id_1 = baseline_ids[6]

    stop_token_ids = [stop_token_id_1, stop_token_id_0]
    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
    output = llm.generate(PROMPT, params)
    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0

    stop_token_ids = [stop_token_id_0, stop_token_id_1]
    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
    output = llm.generate(PROMPT, params)
    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0


def test_detokenize_false(llm):
    """Check that detokenize=False option works.

    With detokenization disabled, token ids are still returned but the
    output text is empty. When sample and prompt logprobs are requested,
    none of their entries carry a decoded token string.
    """
    output = llm.generate(PROMPT, SamplingParams(detokenize=False))
    assert len(output[0].outputs[0].token_ids) > 0
    assert len(output[0].outputs[0].text) == 0

    output = llm.generate(
        PROMPT, SamplingParams(detokenize=False, logprobs=3,
                               prompt_logprobs=3))
    assert len(output[0].outputs[0].token_ids) > 0
    assert len(output[0].outputs[0].text) == 0

    prompt_logprobs = output[0].prompt_logprobs
    sampled_logprobs = output[0].outputs[0].logprobs
    assert len(prompt_logprobs) > 1
    assert len(sampled_logprobs) > 1
    # The first prompt position is skipped (presumably it has no logprobs
    # because nothing precedes it). Up to 4 entries are accepted per
    # position, presumably the top-3 plus the chosen token when it is not
    # among them.
    for all_logprobs in (prompt_logprobs[1:], sampled_logprobs):
        for logprobs in all_logprobs:
            assert 3 <= len(logprobs) <= 4
            assert all(lp.decoded_token is None for lp in logprobs.values())


def test_bad_words(llm):
    """Check that we respect bad words.

    The first two words of a temperature-0 baseline are banned first. Then
    the last word of the resulting output is banned as well, and neither
    banned phrase may appear in the new text.
    """
    output = llm.generate(PROMPT, SamplingParams(temperature=0))
    split_text = output[0].outputs[0].text.split()

    bad_words_1 = " ".join(split_text[:2])
    params = SamplingParams(temperature=0, bad_words=[bad_words_1])
    output = llm.generate(PROMPT, params)
    new_text = output[0].outputs[0].text
    assert bad_words_1 not in new_text

    # Fail with a clear message instead of an IndexError if banning the
    # first phrase left no words to pick the second phrase from.
    new_words = new_text.split()
    assert new_words, "generation with bad_words produced no words"
    bad_words_2 = new_words[-1]
    params = SamplingParams(temperature=0,
                            bad_words=[bad_words_1, bad_words_2])
    output = llm.generate(PROMPT, params)
    new_text = output[0].outputs[0].text
    assert bad_words_1 not in new_text
    assert bad_words_2 not in new_text


def test_logits_processor(llm):
    """Check that we reject logits processor.

    Passing a per-request callable through
    ``SamplingParams.logits_processors`` must make ``generate`` raise
    ``ValueError``.
    """

    # Sample logits processor: gives infinite score to the token whose id
    # equals ``len(token_ids)``. If it were applied, the output token
    # sequence would be [0, 1, 2, ...]; here it only serves to check that
    # such processors are rejected.
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    with pytest.raises(ValueError):
        _ = llm.generate(PROMPT, SamplingParams(logits_processors=[pick_ith]))


def test_allowed_token_ids(llm):
    """Check that we can use allowed_token_ids.

    With a single allowed id, every generated token must be that id.
    Empty, negative and out-of-vocabulary ``allowed_token_ids`` must be
    rejected with ``ValueError``.
    """
    TOKEN_ID = 10
    allowed_token_ids = [TOKEN_ID]
    output = llm.generate(PROMPT,
                          SamplingParams(allowed_token_ids=allowed_token_ids))
    token_ids = output[0].outputs[0].token_ids
    # Only one token is allowed, so check every generated token, not just
    # the last one.
    assert len(token_ids) > 0
    assert all(tok == TOKEN_ID for tok in token_ids), token_ids

    # Reject empty allowed_token_ids.
    with pytest.raises(ValueError):
        _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[]))

    # Reject negative token id.
    with pytest.raises(ValueError):
        _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[-1]))

    # Reject out of vocabulary.
    with pytest.raises(ValueError):
        _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[10000000]))


def test_priority(llm):
    """Check that we reject requests with priority."""

    # Passing a request priority to generate() is expected to be rejected
    # with ValueError.
    with pytest.raises(ValueError):
        _ = llm.generate(PROMPT, priority=[1])


def test_seed(llm):
    """Check that seed impacts randomness.

    Two requests with the same seed must produce identical text. A request
    with a different seed must produce different text.
    """
    out_1 = llm.generate(PROMPT, SamplingParams(seed=42))
    out_2 = llm.generate(PROMPT, SamplingParams(seed=42))
    out_3 = llm.generate(PROMPT, SamplingParams(seed=43))

    assert out_1[0].outputs[0].text == out_2[0].outputs[0].text
    assert out_1[0].outputs[0].text != out_3[0].outputs[0].text