# SPDX-License-Identifier: Apache-2.0

3
4
from collections.abc import Generator
from typing import Any, Optional
5

6
import pytest
7
import os
8
9
from transformers import AutoTokenizer

10
from vllm.inputs import token_inputs
11
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
12
13
from vllm.transformers_utils.detokenizer import (Detokenizer,
                                                 detokenize_incrementally)
14
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
zhuwenwen's avatar
zhuwenwen committed
15

16
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
17
from ..utils import models_path_prefix
18
19

# Reference texts used as the ``truth`` / ``complete_sequence`` parameter of
# the parametrized tests in this file.
TRUTH = [
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
    "我很感谢你的热情",
    # Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg.
    # for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
    # incomplete UTF-8 characters
    # see https://github.com/vllm-project/vllm/pull/9625
    "ပုံပြင်လေးပြောပြပါ်",
]
# Tokenizer locations used as the ``tokenizer_name`` parameter of the tests in
# this file. Every model id is joined onto ``models_path_prefix``.
TOKENIZERS = [
    os.path.join(models_path_prefix, model_id) for model_id in (
        "facebook/opt-125m",
        "gpt2",
        "bigcode/tiny_starcoder_py",
        "EleutherAI/gpt-j-6b",
        "EleutherAI/pythia-70m",
        "bigscience/bloom-560m",
        "mosaicml/mpt-7b",
        "tiiuae/falcon-7b",
        "meta-llama/Llama-3.2-1B-Instruct",
        "codellama/CodeLlama-7b-hf",
        "mistralai/Pixtral-12B-2409",
    )
]


44
def _run_incremental_decode(tokenizer, all_input_ids,
                            skip_special_tokens: bool, starting_index: int):
    """Decode ``all_input_ids`` step by step and return the combined text.

    For every position from ``starting_index`` to the end of
    ``all_input_ids``, the prefix ending at that position is passed to
    ``detokenize_incrementally`` together with the tokens and the two offsets
    carried over from the previous step.

    Args:
        tokenizer: Tokenizer forwarded unchanged to
            ``detokenize_incrementally``.
        all_input_ids: Full list of token ids to decode.
        skip_special_tokens: Forwarded to ``detokenize_incrementally``.
        starting_index: Index of the first token id that is decoded
            incrementally.

    Returns:
        The concatenation of the text returned by each incremental step
        (an empty string when no step runs).
    """
    pieces = []
    cur_offset = 0
    cur_token_offset = 0
    seen_tokens = None
    for end in range(starting_index + 1, len(all_input_ids) + 1):
        step_tokens, step_text, cur_offset, cur_token_offset = (
            detokenize_incrementally(
                tokenizer,
                all_input_ids[:end],
                seen_tokens,
                cur_offset,
                cur_token_offset,
                skip_special_tokens=skip_special_tokens))
        pieces.append(step_text)
        if seen_tokens is None:
            # The first step has no history: adopt the returned tokens.
            seen_tokens = step_tokens
        else:
            seen_tokens += step_tokens
    return "".join(pieces)


66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
@pytest.fixture
def tokenizer(tokenizer_name):
    """Load the tokenizer for ``tokenizer_name``.

    Names containing ``"mistral"`` are loaded with ``MistralTokenizer``;
    every other name is loaded with ``AutoTokenizer``.
    """
    if "mistral" in tokenizer_name:
        return MistralTokenizer.from_pretrained(tokenizer_name)
    return AutoTokenizer.from_pretrained(tokenizer_name)


# The tokenizer location is resolved through ``models_path_prefix`` like every
# other model reference in this file (see TOKENIZERS).
@pytest.mark.parametrize(
    "tokenizer_name",
    [os.path.join(models_path_prefix, "mistralai/Pixtral-12B-2409")])
@pytest.mark.parametrize(
    "truth",
    [
        # Burmese text triggers an edge-case where tokens may map to bytes with
        # incomplete UTF-8 characters
        "ပုံပြင်လေးပြောပြပါ",
        # Using "URGENCY" since "CY" has token id 130282
        "URGENCY🌶️",
    ])
def test_mistral_edge_case(tokenizer, truth):
    """Test for a specific edge cases with V3-Tekken MistralTokenizer.

    The text is tokenized without special tokens, decoded incrementally from
    the first token, and the result must equal the original text.

    See https://github.com/vllm-project/vllm/pull/9625
    """
    all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids

    decoded_text = _run_incremental_decode(tokenizer,
                                           all_input_ids,
                                           skip_special_tokens=True,
                                           starting_index=0)
    assert decoded_text == truth


@pytest.fixture
def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
    """Yield the requested ``skip_special_tokens`` value as a bool.

    The value comes from ``request.param`` (indirect parametrization). When
    ``tokenizer_name`` contains ``"mistral"`` and the requested value is
    falsy, the test is skipped instead.
    """
    if "mistral" in tokenizer_name and not request.param:
        pytest.skip("mistral doesn't support skip_special_tokens=False")
    yield bool(request.param)
106
107


108
@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
    """Check that incremental decoding reproduces the expected text.

    With ``with_prompt`` the token ids are split into a prompt part and a
    generated part, and only the generated part is decoded incrementally.
    Otherwise the whole text is decoded from the first token. Finally, the
    one-element id list ``[len(tokenizer)]`` must decode to an empty string.
    """
    all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
    if with_prompt:
        # NOTE(review): the split point is derived from the character length
        # of ``truth`` rather than the number of tokens - confirm intended.
        split_at = len(truth) // 2
        prompt_input_ids = all_input_ids[:split_at]
        starting_index = len(prompt_input_ids)
        prompt = tokenizer.decode(prompt_input_ids,
                                  skip_special_tokens=skip_special_tokens)
        # Expected output is whatever follows the decoded prompt text.
        generated = truth[len(prompt):]
    else:
        generated = truth
        starting_index = 0

    if skip_special_tokens:
        # Surround the ids with BOS/EOS; the assertion below still expects
        # ``generated`` without any special-token text.
        if tokenizer.bos_token_id is not None:
            all_input_ids = [tokenizer.bos_token_id] + all_input_ids
            starting_index += 1
        all_input_ids = all_input_ids + [tokenizer.eos_token_id]

    decoded_text = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index)

    assert decoded_text == generated

    # Presumably ``len(tokenizer)`` is an out-of-vocabulary id - TODO confirm.
    # NOTE(review): with ``starting_index >= 1`` the decode loop over this
    # one-element list does not run, so the check is trivially true then.
    decoded_text = _run_incremental_decode(
        tokenizer, [len(tokenizer)],
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index)

    assert decoded_text == ''

147
148
149
150
151
152
153
154

@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
    """Build a ``Detokenizer`` on a tokenizer group for ``tokenizer_name``.

    ``tokenizer_mode`` is ``"mistral"`` when the name contains ``"mistral"``
    and ``"auto"`` otherwise.
    """
    tokenizer_mode = "mistral" if "mistral" in tokenizer_name else "auto"
    tokenizer_group = get_tokenizer_group(
        None,
        tokenizer_id=tokenizer_name,
        enable_lora=False,
        max_num_seqs=100,
        max_input_length=None,
        tokenizer_mode=tokenizer_mode,
        trust_remote_code=False,
        revision=None,
    )
    return Detokenizer(tokenizer_group)


@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
                                       tokenizer) -> list[int]:
    """Return the ``input_ids`` the tokenizer yields for ``complete_sequence``.

    The tokenizer is called with its default arguments.
    """
    return tokenizer(complete_sequence).input_ids


def create_sequence(prompt_token_ids=None):
    """Create a ``Sequence`` with id 0 and block size 16.

    Args:
        prompt_token_ids: Prompt token ids for the sequence. A missing or
            empty value is replaced by ``[1]``.

    Returns:
        A ``Sequence`` whose inputs carry the prompt ids and the prompt
        text ``"<s>"``.
    """
    if not prompt_token_ids:
        prompt_token_ids = [1]
    return Sequence(
        seq_id=0,
        inputs=token_inputs(prompt_token_ids, prompt="<s>"),
        block_size=16,
    )


def create_dummy_logprobs(
        complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
    """Build one dummy logprob dict per token id.

    Each dict maps the token id to ``Logprob(logprob=0.0)`` and the next
    integer id (``token_id + 1``) to ``Logprob(logprob=0.1)``.
    """
    dummy_logprobs = []
    for token_id in complete_sequence_token_ids:
        dummy_logprobs.append({
            token_id: Logprob(logprob=0.0),
            token_id + 1: Logprob(logprob=0.1),
        })
    return dummy_logprobs


192
def create_dummy_prompt_logprobs(
        complete_sequence_token_ids: list[int]
) -> list[Optional[dict[int, Any]]]:
    """Build dummy prompt logprobs for the given token ids.

    The result has one entry per token id: ``None`` for the first position,
    followed by the dummy logprob dicts of all remaining positions.
    """
    # logprob for the first prompt token is None.
    prompt_logprobs: list[Optional[dict[int, Any]]] = [
        None,
        *create_dummy_logprobs(complete_sequence_token_ids)[1:],
    ]
    return prompt_logprobs


201
202
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
def test_decode_sequence_logprobs(complete_sequence: str,
                                  complete_sequence_token_ids: list[int],
                                  detokenizer: Detokenizer,
                                  skip_special_tokens: bool):
    """Verify Detokenizer decodes logprobs correctly.

    Tokens are appended to a fresh sequence one at a time and decoded in
    place after each append. The decoded text of the chosen tokens must add
    up to the sequence output text, while the text of the alternative
    (``token_id + 1``) entries must not.
    """
    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                     logprobs=2)

    # Run sequentially.
    seq = create_sequence()
    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
    chosen_token_texts: list[str] = []
    other_token_texts: list[str] = []
    for new_token, logprobs in zip(complete_sequence_token_ids,
                                   dummy_logprobs):
        seq.append_token_id(new_token, logprobs)
        detokenizer.decode_sequence_inplace(seq, sampling_params)
        # Read the decoded text from the most recent output logprobs entry.
        latest_logprobs = seq.output_logprobs[-1]
        chosen_token_texts.append(latest_logprobs[new_token].decoded_token)
        other_token_texts.append(
            latest_logprobs[new_token + 1].decoded_token)
    sequential_result = seq.output_text

    assert sequential_result == "".join(chosen_token_texts)
    assert sequential_result != "".join(other_token_texts)

    if skip_special_tokens:
        # Text for logprobs for the chosen token should be the same as the
        # generated text. Note that this will only be true if we skip
        # special tokens.
        assert sequential_result == complete_sequence


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
                                detokenizer: Detokenizer):
    """Verify Detokenizer decodes prompt logprobs correctly.

    The dummy prompt logprobs are decoded in place for a single-sequence
    group. The decoded text of the chosen tokens (all but the first) must
    equal the full decoded prompt with the first token's text removed, while
    the text of the alternative (``token_id + 1``) entries must differ.
    """
    sampling_params = SamplingParams(skip_special_tokens=True,
                                     prompt_logprobs=1)

    # Run sequentially.
    seq = create_sequence(complete_sequence_token_ids)
    seq_group = SequenceGroup(request_id="1",
                              seqs=[seq],
                              sampling_params=sampling_params,
                              arrival_time=0.0)
    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
    detokenizer.decode_prompt_logprobs_inplace(seq_group,
                                               dummy_logprobs,
                                               position_offset=0)
    # First logprob is None.
    decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
        1:]  # type: ignore

    # decoded_prompt_logprobs doesn't contain the first token.
    token_ids = complete_sequence_token_ids
    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
    text = text_full[len(text_first):]

    # Pair every token after the first with its decoded logprobs dict.
    paired = list(zip(token_ids[1:], decoded_prompt_logprobs))

    # Text for logprobs for the chosen token should be the same as the
    # prompt text. Note that the first logprob is None.
    assert text == "".join(
        [logprobs[token_id].decoded_token for token_id, logprobs in paired])
    assert text != "".join([
        logprobs[token_id + 1].decoded_token for token_id, logprobs in paired
    ])


278
@pytest.mark.parametrize(
    "model", [os.path.join(models_path_prefix, "facebook/opt-125m")])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
def test_decode_prompt_logprobs_chunked_prefill(
    vllm_runner,
    model,
    chunked_prefill_token_size: int,
    example_prompts,
    monkeypatch,
):
    """Check detokenized prompt logprobs against the original prompts.

    A ``chunked_prefill_token_size`` of ``-1`` runs with chunked prefill
    disabled. Any other value enables it, and is used as the batched-token
    limit and as an upper bound on ``max_num_seqs``.
    """
    # VLLM V1 does not use incremental detokenization for
    # prompt logprobs, so this test strategy is irrelevant.
    monkeypatch.setenv("VLLM_USE_V1", "0")

    enable_chunked_prefill = chunked_prefill_token_size != -1
    if enable_chunked_prefill:
        max_num_seqs = min(chunked_prefill_token_size, 256)
        max_num_batched_tokens = chunked_prefill_token_size
    else:
        max_num_seqs = 256
        max_num_batched_tokens = None

    with vllm_runner(model,
                     dtype="half",
                     max_logprobs=5,
                     gpu_memory_utilization=0.5,
                     enable_chunked_prefill=enable_chunked_prefill,
                     max_num_batched_tokens=max_num_batched_tokens,
                     max_num_seqs=max_num_seqs) as vllm_model:

        vllm_sampling_params = SamplingParams(max_tokens=10,
                                              logprobs=5,
                                              prompt_logprobs=5,
                                              temperature=0.0)
        vllm_results = vllm_model.model.generate(
            example_prompts, sampling_params=vllm_sampling_params)

        for idx, result in enumerate(vllm_results):
            assert result.prompt_logprobs is not None
            assert result.prompt_logprobs[0] is None

            # Compare detokenized prompt ids to the original prompt.
            # prompt_logprobs is a dict of the token_id: logprob. We select
            # the entry for the actual prompt token and take the decoded
            # token in the detokenized string corresponding to it.
            generated_string = "".join(
                prompt_logprobs[prompt_token].decoded_token
                for prompt_token, prompt_logprobs in zip(
                    result.prompt_token_ids[1:], result.prompt_logprobs[1:]))

            assert generated_string == example_prompts[idx], (
                "Detokenized prompt logprobs do not match original prompt")