test_no_bad_words.py 6.64 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
"""Make sure bad_words works.

Run `pytest tests/samplers/test_no_bad_words.py`.

"""
7
from typing import Optional
8

9
import pytest
10
11
12
13
14
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams


15
16
17
18
19
20
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    """Autouse hook that opts every test in this module into both engines.

    Requesting the ``run_with_both_engines`` fixture (defined elsewhere,
    presumably in a conftest) is the whole point; this fixture adds no
    setup or teardown of its own.
    """


21
22
23
24
25
def _generate(
    model: LLM,
    prompt: str,
    num_prompt_tokens: int,
    temperature: float = 0,
    bad_words: Optional[list[str]] = None,
) -> list[int]:
    """Generate one completion for ``prompt`` and return its new token ids.

    Args:
        model: The runner object yielded by the ``vllm_runner`` fixture.
            NOTE(review): annotated as ``LLM``, but the indexing below expects
            ``generate`` to return ``[([token_ids, ...], [text, ...]), ...]``
            rather than a list of ``RequestOutput`` — confirm and tighten the
            annotation.
        prompt: The single prompt to complete.
        num_prompt_tokens: Number of leading token ids to strip from the
            returned ids (the tokenized length of ``prompt``).
        temperature: Sampling temperature passed to ``SamplingParams``; the
            default ``0`` presumably selects greedy decoding, which the tests
            rely on for reproducible outputs.
        bad_words: Optional words the sampler is asked never to generate.

    Returns:
        Token ids of the first completion with the first
        ``num_prompt_tokens`` ids removed.
    """
    sampling_params = SamplingParams(
        temperature=temperature,
        bad_words=bad_words,
    )

    # Return shape: [([output_token_ids, ...], [output_text, ...]), ...]
    outputs = model.generate([prompt], sampling_params=sampling_params)

    # outputs[0] -> first (and only) request output
    #        [0] -> its token ids (not the text)
    #        [0] -> first (and only) output completion
    completion_ids = outputs[0][0][0]

    # Strip the prompt prefix so callers see only newly generated tokens
    # (presumably the runner returns prompt + completion ids; verify).
    return completion_ids[num_prompt_tokens:]


class TestOneTokenBadWord:
    """A bad word that tokenizes to a single token must never be generated.

    The prompt is chosen so that, unrestricted, the first generated token is
    ``TARGET_TOKEN``; banning that word must remove it from the completion.
    """

    MODEL = "TheBloke/Llama-2-7B-fp16"

    PROMPT = "Hi! How are"
    TARGET_TOKEN = "you"

    def setup_method(self, method):
        """Load the tokenizer and precompute prompt length and target id."""
        # add_prefix_space=True, presumably so a standalone word encodes like
        # it does mid-sentence (" you"), matching the generated token.
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id = self._encode(self.TARGET_TOKEN,
                                            add_special_tokens=False)[0]

    def test_one_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            # Baseline: without restrictions the target is generated first.
            output_token_ids = self._generate(llm)
            assert output_token_ids[0] == self.target_token_id

            # With the word banned it must not appear anywhere.
            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN])
            assert self.target_token_id not in output_token_ids

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[list[str]] = None) -> list[int]:
        """Complete ``PROMPT`` and return only the newly generated ids."""
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> list[int]:
        """Tokenize ``prompt`` with this test's tokenizer into input ids."""
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids


class TestTwoTokenBadWord:
    """Bad words spanning two tokens block only the full phrase.

    Unrestricted, the model continues the prompt with "years old". The test
    checks single-word bans, a two-word phrase ban, and banning two phrases
    that share a first word.
    """

    # Another model (with a different tokenizer behaviour)
    MODEL = "distilbert/distilgpt2"

    PROMPT = "How old are you? I am 10"
    TARGET_TOKEN1 = "years"
    TARGET_TOKEN2 = "old"
    NEIGHBOUR_TOKEN2 = "older"

    def setup_method(self, method):
        """Load the tokenizer and precompute prompt length and token ids."""
        # add_prefix_space=True, presumably so standalone words encode like
        # they do mid-sentence (" years", " old"), matching generation.
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id1 = self._encode(self.TARGET_TOKEN1,
                                             add_special_tokens=False)[0]
        self.target_token_id2 = self._encode(self.TARGET_TOKEN2,
                                             add_special_tokens=False)[0]
        self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2,
                                                add_special_tokens=False)[0]

    def test_two_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL, dtype="half") as llm:
            # Baseline: unrestricted, the model continues with "years old".
            output_token_ids = self._generate(llm)
            assert output_token_ids[:2] == [
                self.target_token_id1, self.target_token_id2
            ]

            # Banning the first word removes it from the whole completion.
            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN1])
            assert self.target_token_id1 not in output_token_ids

            # Banning the second word: the first is still generated, but the
            # second never appears.
            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN2])
            assert output_token_ids[0] == self.target_token_id1
            assert self.target_token_id2 not in output_token_ids

            # Banning the phrase: "years" may still start the completion, but
            # it must never be followed by "old".
            output_token_ids = self._generate(
                llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}'])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            # Model dependent behaviour
            assert output_token_ids[:2] == [
                self.target_token_id1, self.neighbour_token_id2
            ]

            # Banning both phrases that start with "years": neither second word
            # may follow it, yet the second words themselves are not banned
            # and one of them still appears somewhere in the completion.
            output_token_ids = self._generate(
                llm,
                bad_words=[
                    f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}',
                    f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}'
                ])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            assert output_token_ids[:2] != [
                self.target_token_id1, self.neighbour_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.neighbour_token_id2])
            assert ((self.target_token_id2 in output_token_ids)
                    or (self.neighbour_token_id2 in output_token_ids))

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[list[str]] = None) -> list[int]:
        """Complete ``PROMPT`` and return only the newly generated ids."""
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    @staticmethod
    def _contains(sequence: list[int], subsequence: list[int]) -> bool:
        """Return True if ``subsequence`` occurs contiguously in ``sequence``.

        Asserts that ``sequence`` is non-empty and at least as long as
        ``subsequence``. A too-short generation then fails loudly instead of
        trivially "not containing" the banned phrase.
        """
        window = len(subsequence)
        assert sequence and len(sequence) >= window, (
            "All subsequences did not match in length...")

        # Check every start position that leaves room for a full window.
        return any(sequence[start:start + window] == subsequence
                   for start in range(len(sequence) - window + 1))

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> list[int]:
        """Tokenize ``prompt`` with this test's tokenizer into input ids."""
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids