# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm import LLM, SamplingParams

# Small checkpoint shared by every test in this module; outputs are only
# compared against each other, never against fixed expected text.
MODEL = "hmellor/tiny-random-LlamaForCausalLM"
# Single prompt used by every test below.
PROMPT = "Hello my name is Robert and I"


@pytest.fixture(scope="module")
def llm() -> LLM:
    """Build the engine shared by every test in this module.

    Returns:
        An ``LLM`` instance for ``MODEL`` running in eager mode.
    """
    # Module scope: the engine is constructed once and reused by all tests.
    # enforce_eager=True disables CUDA graph capture (presumably to keep test
    # startup fast).
    return LLM(MODEL, enforce_eager=True)


def test_n_gt_1(llm):
    """ParallelSampling is supported.

    Requesting ``n=3`` samples for one prompt must yield three completions
    on the single returned request output.
    """
    params = SamplingParams(n=3)
    outputs = llm.generate(PROMPT, params)
    # One prompt -> one request output, which holds all n completions.
    assert len(outputs) == 1
    assert len(outputs[0].outputs) == 3


def test_penalties(llm):
    """Check that we do not get errors if applied.

    Combines every penalty with temperature, min_p, top_p and top_k in a
    single request; the test only verifies that generation completes.
    """
    params = SamplingParams(
        temperature=1.2,
        presence_penalty=1.2,
        frequency_penalty=1.2,
        repetition_penalty=1.2,
        min_p=0.5,
        top_p=0.5,
        top_k=3,
    )
    outputs = llm.generate(PROMPT, params)
    # Sanity check that the request actually produced a result.
    assert len(outputs) == 1


def test_stop(llm):
    """Check that we respect the stop words.

    A greedy (temperature=0) run without stops gives a reference output. Its
    word at ``STOP_IDX`` is then used as the stop string, and greedy decoding
    makes the constrained runs reproduce the reference up to that word.
    NOTE: assumes the stop word does not also occur earlier in the text
    (e.g. as a substring of a preceding word).
    """
    output = llm.generate(PROMPT, SamplingParams(temperature=0))
    split_text = output[0].outputs[0].text.split()

    STOP_IDX = 5
    # Fail with a clear message rather than an IndexError when the reference
    # output is too short to pick a stop word from.
    assert len(split_text) > STOP_IDX, (
        f"reference output has only {len(split_text)} words"
    )
    stop_word = split_text[STOP_IDX]

    params = SamplingParams(temperature=0, stop=stop_word)
    output = llm.generate(PROMPT, params)
    new_split_text = output[0].outputs[0].text.split()

    # Output should not contain the stop word.
    assert len(new_split_text) == STOP_IDX

    params = SamplingParams(
        temperature=0, stop=stop_word, include_stop_str_in_output=True
    )
    output = llm.generate(PROMPT, params)
    new_split_text = output[0].outputs[0].text.split()

    # Output should contain the stop word.
    assert len(new_split_text) == STOP_IDX + 1


def test_stop_token_ids(llm):
    """Check that we respect the stop token ids.

    Two token ids from a greedy reference run (positions 5 and 6) are used as
    stop tokens. Generation must end on the one that appears first (position
    5), whatever order the ids are listed in.
    NOTE: assumes the position-6 id does not already occur before position 5
    in the reference output; otherwise generation would stop on it first.
    """
    output = llm.generate(PROMPT, SamplingParams(temperature=0))
    reference_ids = output[0].outputs[0].token_ids

    stop_token_id_0 = reference_ids[5]
    stop_token_id_1 = reference_ids[6]

    # The list order of stop_token_ids must not affect which one triggers.
    for stop_token_ids in (
        [stop_token_id_1, stop_token_id_0],
        [stop_token_id_0, stop_token_id_1],
    ):
        params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
        output = llm.generate(PROMPT, params)
        # The triggering stop token is the last generated id.
        assert output[0].outputs[0].token_ids[-1] == stop_token_id_0, (
            f"stop_token_ids={stop_token_ids}"
        )


def test_detokenize_false(llm):
    """Check that detokenize=False option works.

    With detokenization disabled, token ids are still produced, the text is
    empty, and logprob entries carry no decoded token strings.
    """
    output = llm.generate(PROMPT, SamplingParams(detokenize=False))
    assert len(output[0].outputs[0].token_ids) > 0
    assert len(output[0].outputs[0].text) == 0

    output = llm.generate(
        PROMPT, SamplingParams(detokenize=False, logprobs=3, prompt_logprobs=3)
    )
    assert len(output[0].outputs[0].token_ids) > 0
    assert len(output[0].outputs[0].text) == 0

    prompt_logprobs = output[0].prompt_logprobs
    sampled_logprobs = output[0].outputs[0].logprobs
    assert len(prompt_logprobs) > 1
    assert len(sampled_logprobs) > 1
    # The first prompt position is skipped; presumably it has no logprobs
    # because there is no preceding context to predict it from.
    for all_logprobs in (prompt_logprobs[1:], sampled_logprobs):
        for logprobs in all_logprobs:
            # Top-3 candidates, plus possibly one more entry for the actual
            # token when it falls outside the top-3.
            assert 3 <= len(logprobs) <= 4
            assert all(lp.decoded_token is None for lp in logprobs.values())


def test_bad_words(llm):
    """Check that we respect bad words.

    A banned word counts as generated only if it appears both in the decoded
    text and as a contiguous run of its token ids (tokenized with or without
    a leading space) in the generated token sequence.
    """
    from collections.abc import Sequence

    tokenizer = llm.get_tokenizer()

    def contains_bad_word(text: str, tokens: Sequence[int], bad_word: str) -> bool:
        """Check if word appears in BOTH text and token sequence."""
        if bad_word not in text:
            return False

        # Normalize to a list: slicing a non-list sequence (tuple, array, ...)
        # and comparing it with the tokenizer's list is never equal, which
        # would silently hide a real match.
        token_list = list(tokens)
        # The word tokenizes differently with and without a preceding space,
        # so check both spellings.
        for prefix in ("", " "):
            bad_word_ids = tokenizer.encode(
                prefix + bad_word.lstrip(), add_special_tokens=False
            )
            if not bad_word_ids:
                continue
            width = len(bad_word_ids)
            if any(
                token_list[i : i + width] == bad_word_ids
                for i in range(len(token_list) - width + 1)
            ):
                return True
        return False

    output = llm.generate(PROMPT, SamplingParams(temperature=0))
    split_text = output[0].outputs[0].text.split()

    # Ban the first two words of the unconstrained greedy output.
    bad_words_1 = " ".join(split_text[:2])
    params = SamplingParams(temperature=0, bad_words=[bad_words_1])
    output = llm.generate(PROMPT, params)
    new_text = output[0].outputs[0].text
    new_tokens = output[0].outputs[0].token_ids
    assert not contains_bad_word(new_text, new_tokens, bad_words_1)

    # Additionally ban the last word of the constrained output and check that
    # both bans hold together.
    new_words = new_text.split()
    assert new_words, "constrained generation produced no words to ban"
    bad_words_2 = new_words[-1]
    params = SamplingParams(temperature=0, bad_words=[bad_words_1, bad_words_2])
    output = llm.generate(PROMPT, params)
    new_text = output[0].outputs[0].text
    new_tokens = output[0].outputs[0].token_ids
    assert not contains_bad_word(new_text, new_tokens, bad_words_1)
    assert not contains_bad_word(new_text, new_tokens, bad_words_2)


def test_allowed_token_ids(llm):
    """Check that we can use allowed_token_ids.

    With a single allowed id, every generated token must be that id. Empty,
    negative, and out-of-vocabulary allow-lists are rejected with ValueError.
    """
    TOKEN_ID = 10
    allowed_token_ids = [TOKEN_ID]
    output = llm.generate(PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids))
    token_ids = output[0].outputs[0].token_ids
    assert len(token_ids) > 0
    # The restriction applies at every step, so check all sampled tokens
    # rather than only the last one.
    assert all(token_id == TOKEN_ID for token_id in token_ids)

    # Reject empty allowed_token_ids.
    with pytest.raises(ValueError):
        _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[]))

    # Reject negative token id.
    with pytest.raises(ValueError):
        _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[-1]))

    # Reject out of vocabulary.
    with pytest.raises(ValueError):
        _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[10000000]))


def test_seed(llm):
    """Check that seed impacts randomness.

    Sampling uses the default (non-zero) temperature, so outputs vary with
    the seed: the same seed must reproduce the same text, and a different
    seed is expected to give different text.
    """
    out_1 = llm.generate(PROMPT, SamplingParams(seed=42))
    out_2 = llm.generate(PROMPT, SamplingParams(seed=42))
    out_3 = llm.generate(PROMPT, SamplingParams(seed=43))

    # Same seed -> identical sample.
    assert out_1[0].outputs[0].text == out_2[0].outputs[0].text
    # Different seed -> different sample (a collision is possible in
    # principle but unlikely).
    assert out_1[0].outputs[0].text != out_3[0].outputs[0].text