# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Make sure bad_words works.

Run `pytest tests/samplers/test_no_bad_words.py`.

"""
from typing import Optional
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams


def _generate(
    llm: LLM,
    prompt: str,
    num_prompt_tokens: int,
    temperature: float = 0,
    bad_words: Optional[list[str]] = None,
) -> list[int]:
    """Generate one completion for ``prompt`` and return its new token ids.

    Args:
        llm: Object whose ``generate`` is called. Callers in this module pass
            the object yielded by the ``vllm_runner`` fixture; its result is
            indexed below as ``[([output_token_ids, ], [output_text, ]), ]``.
            NOTE(review): the ``LLM`` annotation does not match that indexing
            format — confirm the intended type.
        prompt: Text prompt to complete.
        num_prompt_tokens: Number of leading ids to drop from the returned
            sequence (callers pass the encoded prompt's length).
        temperature: Forwarded to ``SamplingParams(temperature=...)``.
        bad_words: Forwarded to ``SamplingParams(bad_words=...)``.

    Returns:
        Token ids of the first completion of the single request, with the
        first ``num_prompt_tokens`` ids removed.
    """
    params = SamplingParams(
        temperature=temperature,
        bad_words=bad_words,
    )

    # Result layout: [([output_token_ids, ], [output_text, ]), ]
    results = llm.generate([prompt], sampling_params=params)

    first_request = results[0]  # only one prompt was submitted
    completions_token_ids = first_request[0]  # token ids, not text
    first_completion = completions_token_ids[0]  # first (and only) completion

    # NOTE(review): the ids appear to begin with the prompt's own tokens,
    # which is why the first ``num_prompt_tokens`` entries are sliced away.
    return first_completion[num_prompt_tokens:]


class TestOneTokenBadWord:
    """Check ``bad_words`` with a single-word target ("you")."""

    MODEL = "TheBloke/Llama-2-7B-fp16"

    PROMPT = "Hi! How are"
    TARGET_TOKEN = "you"

    def setup_method(self, method):
        # Local tokenizer, used only to compute reference token ids.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.MODEL, add_prefix_space=True
        )

        prompt_ids = self._encode(self.PROMPT)
        self.num_prompt_tokens = len(prompt_ids)

        # First id of the target word, encoded without special tokens.
        target_ids = self._encode(self.TARGET_TOKEN, add_special_tokens=False)
        self.target_token_id = target_ids[0]

    def test_one_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            # With no bad words, the continuation starts with the target.
            unrestricted = self._generate(llm)
            assert unrestricted[0] == self.target_token_id

            # Once the target is banned, its id must not appear at all.
            restricted = self._generate(llm, bad_words=[self.TARGET_TOKEN])
            assert self.target_token_id not in restricted

    def _generate(self, llm: LLM, bad_words: Optional[list[str]] = None) -> list[int]:
        """Call the module-level ``_generate`` with this class's prompt."""
        return _generate(
            llm=llm,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    def _encode(self, prompt: str, add_special_tokens: bool = True) -> list[int]:
        """Return the tokenizer's ``input_ids`` for ``prompt``."""
        encoding = self.tokenizer(prompt, add_special_tokens=add_special_tokens)
        return encoding.input_ids


class TestTwoTokenBadWord:
    """Check ``bad_words`` entries that span two words ("years old").

    Banning the two-word phrase must block only that exact id sequence,
    while the individual words may still be generated on their own.
    """

    # Another model (with a different tokenizer behaviour)
    MODEL = "distilbert/distilgpt2"

    PROMPT = "How old are you? I am 10"
    TARGET_TOKEN1 = "years"
    TARGET_TOKEN2 = "old"
    NEIGHBOUR_TOKEN2 = "older"

    def setup_method(self, method):
        # Local tokenizer, used only to compute reference token ids.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.MODEL, add_prefix_space=True
        )

        self.num_prompt_tokens = len(self._encode(self.PROMPT))

        # First id of each word, encoded without special tokens, in order:
        # TARGET_TOKEN1, TARGET_TOKEN2, NEIGHBOUR_TOKEN2.
        first_ids = [
            self._encode(word, add_special_tokens=False)[0]
            for word in (
                self.TARGET_TOKEN1,
                self.TARGET_TOKEN2,
                self.NEIGHBOUR_TOKEN2,
            )
        ]
        (
            self.target_token_id1,
            self.target_token_id2,
            self.neighbour_token_id2,
        ) = first_ids

    def test_two_token_bad_word(self, vllm_runner):
        years = self.target_token_id1
        old = self.target_token_id2
        older = self.neighbour_token_id2
        years_old = [years, old]
        years_older = [years, older]

        with vllm_runner(self.MODEL, dtype="half") as llm:
            # With no bad words, the continuation starts with "years old".
            baseline = self._generate(llm)
            assert baseline[:2] == years_old

            # Banning the first word removes its id everywhere.
            without_years = self._generate(llm, bad_words=[self.TARGET_TOKEN1])
            assert years not in without_years

            # Banning the second word keeps the first, drops the second.
            without_old = self._generate(llm, bad_words=[self.TARGET_TOKEN2])
            assert without_old[0] == years
            assert old not in without_old

            # Banning the phrase blocks only the exact two-id sequence.
            phrase = f"{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}"
            without_phrase = self._generate(llm, bad_words=[phrase])
            assert without_phrase[0] == years
            assert without_phrase[:2] != years_old
            assert not self._contains(without_phrase, years_old)
            # Model dependent behaviour
            assert without_phrase[:2] == years_older

            # Banning both phrases blocks both sequences, but the second
            # words may still appear elsewhere in the output.
            neighbour_phrase = f"{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}"
            without_both = self._generate(
                llm,
                bad_words=[phrase, neighbour_phrase],
            )
            assert without_both[0] == years
            assert without_both[:2] != years_old
            assert not self._contains(without_both, years_old)
            assert without_both[:2] != years_older
            assert not self._contains(without_both, years_older)
            assert (old in without_both) or (older in without_both)

    def _generate(self, llm: LLM, bad_words: Optional[list[str]] = None) -> list[int]:
        """Call the module-level ``_generate`` with this class's prompt."""
        return _generate(
            llm=llm,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    @staticmethod
    def _contains(sequence: list[int], subsequence: list[int]) -> bool:
        """Return True if ``subsequence`` occurs contiguously in ``sequence``.

        Raises AssertionError when ``sequence`` offers no start position with
        a full-length window, i.e. the check could not actually be performed.
        """
        size = len(subsequence)
        # Start positions inside ``sequence`` whose window is full-length.
        starts = range(min(len(sequence), len(sequence) - size + 1))

        assert len(starts) > 0, "All subsequences did not match in length..."

        return any(sequence[s : s + size] == subsequence for s in starts)

    def _encode(self, prompt: str, add_special_tokens: bool = True) -> list[int]:
        """Return the tokenizer's ``input_ids`` for ``prompt``."""
        encoding = self.tokenizer(prompt, add_special_tokens=add_special_tokens)
        return encoding.input_ids