# test_detokenize.py: tests for incremental detokenization and for
# Detokenizer's decoding of sampled and prompt logprobs.
from typing import Any, Dict, List, Optional
2

3
import pytest
4
5
from transformers import AutoTokenizer

6
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
7
8
from vllm.transformers_utils.detokenizer import (Detokenizer,
                                                 detokenize_incrementally)
9
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
10
11

# Reference texts used as test inputs. Each test tokenizes one of these and
# checks that decoding reproduces it. The last entry is non-ASCII (Chinese).
TRUTH = [
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
    "我很感谢你的热情",
]

# Hugging Face tokenizer ids every parametrized test is run against.
TOKENIZERS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-2-7b-hf",
    "codellama/CodeLlama-7b-hf",
]


def _run_incremental_decode(tokenizer, all_input_ids,
                            skip_special_tokens: bool, starting_index: int):
    """Stream-decode ``all_input_ids`` one token at a time.

    For every index ``i`` from ``starting_index`` to the end, the prefix
    ``all_input_ids[:i + 1]`` is passed to ``detokenize_incrementally``.
    The token history and offsets it returns are threaded into the next
    call, and the text fragments it emits are concatenated.

    Args:
        tokenizer: Tokenizer forwarded to ``detokenize_incrementally``.
        all_input_ids: Full token-id sequence to decode.
        skip_special_tokens: Forwarded to ``detokenize_incrementally``.
        starting_index: First index decoded incrementally. Ids before it
            are only ever passed as part of the prefix (presumably treated
            as prompt context). If it is ``>= len(all_input_ids)`` nothing
            is decoded and ``""`` is returned.

    Returns:
        The concatenation of every emitted text fragment.
    """
    fragments = []
    offset = 0
    token_offset = 0
    prev_tokens = None
    for i in range(starting_index, len(all_input_ids)):
        new_tokens, text, offset, token_offset = detokenize_incrementally(
            tokenizer,
            all_input_ids[:i + 1],
            prev_tokens,
            offset,
            token_offset,
            skip_special_tokens=skip_special_tokens)
        fragments.append(text)
        # No history exists before the first call, so its tokens seed the
        # history; tokens returned by later calls extend it.
        if prev_tokens is None:
            prev_tokens = new_tokens
        else:
            prev_tokens += new_tokens
    return "".join(fragments)


@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False))
def test_decode_streaming(tokenizer_id, truth, with_prompt,
                          skip_special_tokens):
    """Incremental decoding must reproduce the text that follows the prompt.

    With ``with_prompt`` the token ids are split into a prompt part, which
    is not decoded incrementally, and a generated part, which is. Without a
    prompt, every id is decoded incrementally. When ``skip_special_tokens``
    is set, BOS (if the tokenizer defines one) and EOS are wrapped around the
    ids and must not show up in the output. Finally, an id one past the
    vocabulary must decode to the empty string.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    if with_prompt:
        truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"]
        # NOTE(review): the token list is split at half the *character*
        # length of ``truth``, not half the token count. When a tokenizer
        # packs several characters per token, this index lies past the end
        # of ``truth_tokens``. The generated part is then empty and the
        # check below passes trivially. ``len(truth_tokens) // 2`` looks
        # intended, but it may split a multi-byte character for byte-level
        # tokenizers, so confirm before changing.
        prompt_input_ids = truth_tokens[:len(truth) // 2]
        generated_input_ids = truth_tokens[len(truth) // 2:]
        all_input_ids = prompt_input_ids + generated_input_ids
        starting_index = len(prompt_input_ids)
        prompt = tokenizer.decode(prompt_input_ids,
                                  skip_special_tokens=skip_special_tokens)
        # The expected streamed output is whatever follows the decoded prompt.
        generated = truth[len(prompt):]
    else:
        generated = truth
        starting_index = 0
        all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]

    if skip_special_tokens:
        # Special tokens are skipped, so wrapping the ids in them must not
        # change the decoded text. A prepended BOS shifts the first
        # incrementally decoded position by one.
        if tokenizer.bos_token_id is not None:
            all_input_ids = [tokenizer.bos_token_id] + all_input_ids
            starting_index += 1
        all_input_ids = all_input_ids + [tokenizer.eos_token_id]

    decoded_text = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index)

    assert decoded_text == generated

    # An id just past the end of the vocabulary must decode to nothing.
    # Decode it from index 0. The list has a single element, so reusing the
    # prompt-based ``starting_index`` (>= 1 whenever there is a prompt or a
    # BOS) would skip the loop entirely and make this check vacuous.
    decoded_text = _run_incremental_decode(
        tokenizer, [len(tokenizer)],
        skip_special_tokens=skip_special_tokens,
        starting_index=0)

    assert decoded_text == ''

@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
    """Return a ``Detokenizer`` built on the tokenizer group for
    ``tokenizer_name``.

    The group comes from ``get_tokenizer_group`` with ``None`` as its first
    argument, LoRA disabled, ``"auto"`` tokenizer mode, no remote code and
    no pinned revision.
    """
    tokenizer_group = get_tokenizer_group(
        None,
        tokenizer_id=tokenizer_name,
        enable_lora=False,
        max_num_seqs=100,
        max_input_length=None,
        tokenizer_mode="auto",
        trust_remote_code=False,
        revision=None,
    )
    return Detokenizer(tokenizer_group)


@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
                                       tokenizer_name: str) -> List[int]:
    """Tokenize ``complete_sequence`` with the ``tokenizer_name`` tokenizer.

    The tokenizer is called with its defaults; unlike
    ``test_decode_streaming``, ``add_special_tokens=False`` is not passed.
    Any special tokens the tokenizer adds by default are therefore
    presumably included in the returned ids.
    """
    encoding = AutoTokenizer.from_pretrained(tokenizer_name)(complete_sequence)
    return encoding["input_ids"]


def create_sequence(prompt_token_ids: Optional[List[int]] = None):
    """Build the single ``Sequence`` (id 0, block size 16) the tests use.

    Args:
        prompt_token_ids: Prompt token ids. Because ``or`` is used, both
            ``None`` and an empty list fall back to ``[1]``.

    Returns:
        A ``Sequence`` whose prompt text is ``"<s>"``.
    """
    prompt_token_ids = prompt_token_ids or [1]
    return Sequence(
        seq_id=0,
        inputs={
            "prompt": "<s>",
            "prompt_token_ids": prompt_token_ids,
        },
        block_size=16,
    )


def create_dummy_logprobs(
        complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
    """Build one fake logprob dict per token id.

    Each dict has two entries: the token id itself with logprob ``0.0``, and
    the neighbouring id ``token_id + 1`` with logprob ``0.1``. Tests use the
    second entry as the "other" (not chosen) candidate.
    """
    chosen = 0.0
    other = 0.1
    return [{tok: Logprob(logprob=chosen), tok + 1: Logprob(logprob=other)}
            for tok in complete_sequence_token_ids]


def create_dummy_prompt_logprobs(
        complete_sequence_token_ids: List[int]
) -> List[Optional[Dict[int, Any]]]:
    """Build fake prompt logprobs: ``None`` for the first position, then the
    per-token dicts from ``create_dummy_logprobs`` for the remaining ones.
    """
    per_token = create_dummy_logprobs(complete_sequence_token_ids)
    # There is no logprob for the very first prompt token.
    return [None, *per_token[1:]]


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False])
def test_decode_sequence_logprobs(complete_sequence: str,
                                  complete_sequence_token_ids: List[int],
                                  detokenizer: Detokenizer,
                                  skip_special_tokens: bool):
    """Verify Detokenizer decodes logprobs correctly.

    Token ids are appended to a sequence one at a time. Each comes with two
    dummy logprob candidates: the chosen id and ``id + 1``. The sequence is
    decoded in place after every append. The decoded texts recorded for the
    chosen ids must concatenate to the sequence's output text. Those for
    the ``id + 1`` candidates must not.
    """
    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                     logprobs=2)

    # Run sequentially.
    seq = create_sequence()
    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)

    sequential_logprobs_text_chosen_token: List[str] = []
    sequential_logprobs_text_other_token: List[str] = []
    for new_token, logprobs in zip(complete_sequence_token_ids,
                                   dummy_logprobs):
        seq.append_token_id(new_token, logprobs)
        detokenizer.decode_sequence_inplace(seq, sampling_params)
        latest = seq.output_logprobs[-1]
        sequential_logprobs_text_chosen_token.append(
            latest[new_token].decoded_token)
        sequential_logprobs_text_other_token.append(
            latest[new_token + 1].decoded_token)
    sequential_result = seq.output_text

    assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
    assert sequential_result != "".join(sequential_logprobs_text_other_token)

    if skip_special_tokens:
        # The output text should equal the original string only when special
        # tokens are skipped. Presumably this is because the fixture's ids
        # may include tokenizer-added specials such as BOS.
        assert sequential_result == complete_sequence


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
                                detokenizer: Detokenizer):
    """Verify Detokenizer decodes prompt logprobs correctly.

    Dummy prompt logprobs are decoded in place. The ``decoded_token`` of each
    actual prompt token (all but the first) must concatenate to the prompt
    text after its first token. Those of the ``id + 1`` candidates must not.
    """
    sampling_params = SamplingParams(skip_special_tokens=True,
                                     prompt_logprobs=1)

    # Run sequentially.
    seq = create_sequence(complete_sequence_token_ids)
    seq_group = SequenceGroup(request_id="1",
                              seqs=[seq],
                              sampling_params=sampling_params,
                              arrival_time=0.0)
    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
    # The assertions below read ``decoded_token`` back from ``dummy_logprobs``,
    # so this call is expected to populate those entries in place.
    detokenizer.decode_prompt_logprobs_inplace(seq_group,
                                               dummy_logprobs,
                                               position_offset=0)
    # The first prompt position has no logprobs (None), so drop it.
    decoded_prompt_logprobs: List[Dict[int, Any]]
    decoded_prompt_logprobs = dummy_logprobs[1:]  # type: ignore

    # decoded_prompt_logprobs doesn't contain the first token, so strip the
    # first token's text from the fully decoded prompt before comparing.
    token_ids = complete_sequence_token_ids
    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
    text = text_full[len(text_first):]

    # Text for logprobs for the chosen token should be the same as the
    # prompt text. Note that the first logprob is None.
    assert text == "".join([
        logprobs[token_id].decoded_token
        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
    ])
    assert text != "".join([
        logprobs[token_id + 1].decoded_token
        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
    ])


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
def test_decode_prompt_logprobs_chunked_prefill(
    vllm_runner,
    model,
    chunked_prefill_token_size: int,
    example_prompts,
):
    """Prompt logprobs must detokenize back to the original prompts.

    ``chunked_prefill_token_size == -1`` runs without chunked prefill. Any
    other value enables it, passes the value as ``max_num_batched_tokens``
    and caps ``max_num_seqs`` at that value (or 256, whichever is smaller).
    """
    default_max_num_seqs = 256
    enable_chunked_prefill = chunked_prefill_token_size != -1
    if enable_chunked_prefill:
        max_num_seqs = min(chunked_prefill_token_size, default_max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size
    else:
        max_num_seqs = default_max_num_seqs
        max_num_batched_tokens = None

    with vllm_runner(model,
                     dtype="half",
                     max_logprobs=5,
                     gpu_memory_utilization=0.5,
                     enable_chunked_prefill=enable_chunked_prefill,
                     max_num_batched_tokens=max_num_batched_tokens,
                     max_num_seqs=max_num_seqs) as vllm_model:

        sampling_params = SamplingParams(max_tokens=10,
                                         logprobs=5,
                                         prompt_logprobs=5,
                                         temperature=0.0)
        outputs = vllm_model.model.generate(example_prompts,
                                            sampling_params=sampling_params)

        for idx, output in enumerate(outputs):
            prompt_logprobs = output.prompt_logprobs
            assert prompt_logprobs is not None
            # No logprobs are reported for the first prompt token.
            assert prompt_logprobs[0] is None

            # Each position maps token_id -> logprob. Pick the entry for the
            # token actually in the prompt and join its decoded text. The
            # result should rebuild the original prompt string.
            detokenized = "".join(
                position[token_id].decoded_token
                for token_id, position in zip(output.prompt_token_ids[1:],
                                              prompt_logprobs[1:]))

            assert detokenized == example_prompts[idx], (
                "Detokenized prompt logprobs do not match original prompt")