test_no_bad_words.py 6.33 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
"""Make sure bad_words works.

Run `pytest tests/samplers/test_no_bad_words.py`.

"""
8

9
from typing import Optional
10

11
import pytest
12
13
14
15
16
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams


17
@pytest.fixture(autouse=True)
def v1(monkeypatch):
    """Only run on vLLM v1.

    Declared with ``autouse=True``, so it applies to every test in this
    module without being requested explicitly. It sets the ``VLLM_USE_V1``
    environment variable to ``"1"`` through ``monkeypatch``.
    """
    monkeypatch.setenv("VLLM_USE_V1", "1")
21
22


23
def _generate(
    llm: LLM,
    prompt: str,
    num_prompt_tokens: int,
    temperature: float = 0,
    bad_words: Optional[list[str]] = None,
) -> list[int]:
    """Generate one completion for ``prompt`` and return its new token ids.

    Args:
        llm: Object whose ``generate`` is called with a list of prompts and
            ``sampling_params``.
            NOTE(review): annotated as ``LLM``, but the tests in this module
            pass the object yielded by the ``vllm_runner`` fixture; the
            indexing below relies on the output layout described in the
            inline comment -- confirm the annotation.
        prompt: Text prompt to complete.
        num_prompt_tokens: Number of leading token ids to drop from the
            returned sequence (the tokenized prompt length, as computed by
            the callers).
        temperature: Sampling temperature forwarded to ``SamplingParams``.
        bad_words: Optional list of words/phrases forwarded to
            ``SamplingParams`` as ``bad_words``.

    Returns:
        The token ids of the first completion of the first (and only)
        request, with the first ``num_prompt_tokens`` entries removed.
    """
    sampling_params = SamplingParams(
        temperature=temperature,
        bad_words=bad_words,
    )

    # [([output_token_ids, ], [output_text, ]), ]
    output = llm.generate([prompt], sampling_params=sampling_params)

    output_token_ids = output[0][0][0][num_prompt_tokens:]
    # [0] first (and only) request output
    # [0] token_ids (not text)
    # [0] first (and only) output completion

    return output_token_ids


class TestOneTokenBadWord:
    """Check that a single-token bad word is never generated."""

    # Model used for both tokenization and generation.
    MODEL = "TheBloke/Llama-2-7B-fp16"

    # Prompt whose first generated token is asserted to be TARGET_TOKEN
    # when no bad words are set.
    PROMPT = "Hi! How are"
    TARGET_TOKEN = "you"

    def setup_method(self, method):
        """Load the tokenizer and precompute the ids the test compares with."""
        # NOTE(review): add_prefix_space=True presumably makes a bare word
        # tokenize the way it would in the middle of a sentence -- confirm.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.MODEL, add_prefix_space=True
        )

        # Prompt length in tokens (special tokens included); used to strip
        # the prompt from the generated ids.
        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        # First token id of the target word, encoded without special tokens.
        self.target_token_id = self._encode(
            self.TARGET_TOKEN, add_special_tokens=False
        )[0]

    def test_one_token_bad_word(self, vllm_runner):
        """Target token appears by default and disappears once it is banned."""
        with vllm_runner(self.MODEL) as llm:
            # Baseline: without bad_words the first new token is the target.
            output_token_ids = self._generate(llm)
            assert output_token_ids[0] == self.target_token_id

            # With the target banned it must not appear anywhere in the output.
            output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN])
            assert self.target_token_id not in output_token_ids

    def _generate(
        self, llm: LLM, bad_words: Optional[list[str]] = None
    ) -> list[int]:
        """Run the module-level ``_generate`` with this class's prompt."""
        return _generate(
            llm=llm,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    def _encode(self, prompt: str, add_special_tokens: bool = True) -> list[int]:
        """Tokenize ``prompt`` and return its ``input_ids``."""
        return self.tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids
80
81
82
83


class TestTwoTokenBadWord:
    """Check bad-word handling for single tokens and a two-token phrase."""

    # Another model (with a different tokenizer behaviour)
    MODEL = "distilbert/distilgpt2"

    # Without bad words, the first two generated tokens are asserted to be
    # TARGET_TOKEN1 followed by TARGET_TOKEN2.
    PROMPT = "How old are you? I am 10"
    TARGET_TOKEN1 = "years"
    TARGET_TOKEN2 = "old"
    # Token the model is asserted to produce in second position once the
    # "years old" phrase is banned.
    NEIGHBOUR_TOKEN2 = "older"

    def setup_method(self, method):
        """Load the tokenizer and precompute the ids the test compares with."""
        # NOTE(review): add_prefix_space=True presumably makes a bare word
        # tokenize the way it would in the middle of a sentence -- confirm.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.MODEL, add_prefix_space=True
        )

        # Prompt length in tokens (special tokens included); used to strip
        # the prompt from the generated ids.
        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        # First token id of each word, encoded without special tokens.
        self.target_token_id1 = self._encode(
            self.TARGET_TOKEN1, add_special_tokens=False
        )[0]
        self.target_token_id2 = self._encode(
            self.TARGET_TOKEN2, add_special_tokens=False
        )[0]
        self.neighbour_token_id2 = self._encode(
            self.NEIGHBOUR_TOKEN2, add_special_tokens=False
        )[0]

    def test_two_token_bad_word(self, vllm_runner):
        """Ban single tokens, a two-token phrase, then two phrases at once."""
        with vllm_runner(self.MODEL, dtype="half") as llm:
            # Baseline: the output starts with token1 followed by token2.
            output_token_ids = self._generate(llm)
            assert output_token_ids[:2] == [
                self.target_token_id1,
                self.target_token_id2,
            ]

            # Banning token1 alone removes it from the whole output.
            output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN1])
            assert self.target_token_id1 not in output_token_ids

            # Banning token2 alone keeps token1 first but removes token2.
            output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN2])
            assert output_token_ids[0] == self.target_token_id1
            assert self.target_token_id2 not in output_token_ids

            # Banning the phrase "token1 token2": token1 still comes first,
            # but the two ids never appear back to back.
            output_token_ids = self._generate(
                llm, bad_words=[f"{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}"]
            )
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1,
                self.target_token_id2,
            ]
            assert not self._contains(
                output_token_ids, [self.target_token_id1, self.target_token_id2]
            )
            # Model dependent behaviour
            assert output_token_ids[:2] == [
                self.target_token_id1,
                self.neighbour_token_id2,
            ]

            # Banning both "token1 token2" and "token1 neighbour2": neither
            # pair may appear contiguously anywhere in the output.
            output_token_ids = self._generate(
                llm,
                bad_words=[
                    f"{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}",
                    f"{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}",
                ],
            )
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1,
                self.target_token_id2,
            ]
            assert not self._contains(
                output_token_ids, [self.target_token_id1, self.target_token_id2]
            )
            assert output_token_ids[:2] != [
                self.target_token_id1,
                self.neighbour_token_id2,
            ]
            assert not self._contains(
                output_token_ids, [self.target_token_id1, self.neighbour_token_id2]
            )
            # The second tokens themselves are not banned on their own, so at
            # least one of them is still expected somewhere in the output.
            assert (self.target_token_id2 in output_token_ids) or (
                self.neighbour_token_id2 in output_token_ids
            )

    def _generate(
        self, llm: LLM, bad_words: Optional[list[str]] = None
    ) -> list[int]:
        """Run the module-level ``_generate`` with this class's prompt."""
        return _generate(
            llm=llm,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    @staticmethod
    def _contains(sequence: list[int], subsequence: list[int]) -> bool:
        """Return whether ``subsequence`` occurs contiguously in ``sequence``.

        Raises:
            AssertionError: If ``sequence`` is empty or shorter than
                ``subsequence``, i.e. there is no full-length window to
                compare, so a ``False`` result would be vacuous.
        """
        width = len(subsequence)
        last_start = len(sequence) - width

        # Checked up front: a match can only exist when at least one
        # full-length window exists, so this fails in exactly the cases
        # where nothing could have been searched.
        assert len(sequence) > 0 and last_start >= 0, (
            "All subsequences did not match in length..."
        )

        return any(
            sequence[start:start + width] == subsequence
            for start in range(last_start + 1)
        )

    def _encode(self, prompt: str, add_special_tokens: bool = True) -> list[int]:
        """Tokenize ``prompt`` and return its ``input_ids``."""
        return self.tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids