# test_detokenize.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Generator
from typing import Any, Optional
6

7
import pytest
8
9
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
10

11
from vllm.inputs import token_inputs
12
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
13
from vllm.transformers_utils.detokenizer import Detokenizer
14
from vllm.transformers_utils.tokenizer import get_tokenizer
15
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
16
17
18
19
20
21
22
23
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
                                        IncrementalDetokenizer,
                                        SlowIncrementalDetokenizer)

# Inputs containing literal special-token markers. Kept separate so tests can
# check membership and choose skip_special_tokens accordingly.
SPECIAL_TOKS_TRUTH = [
    "Some text with adjacent special tokens                <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
]

# Reference strings that are tokenized and then detokenized by the tests.
TRUTH = [
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
    "我很感谢你的热情",
    # Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg.
    # for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
    # incomplete UTF-8 characters
    # see https://github.com/vllm-project/vllm/pull/9625
    "ပုံပြင်လေးပြောပြပါ်",
] + SPECIAL_TOKS_TRUTH

36
37
38
39
40
41
42
43
44
# Tokenizer names that the parametrized tests in this module run against.
TOKENIZERS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-3.2-1B-Instruct",
    "codellama/CodeLlama-7b-hf",
    "mistralai/Pixtral-12B-2409",
]


51
52
53
54
55
56
57
58
59
60
61
62
63
def _run_incremental_decode(tokenizer,
                            all_input_ids,
                            skip_special_tokens: bool,
                            starting_index: int,
                            spaces_between_special_tokens: bool = True,
                            fast: Optional[bool] = None):
    """Stream tokens one at a time through an incremental detokenizer.

    ``all_input_ids[:starting_index]`` is passed as the request's prompt;
    every remaining id is pushed individually via ``update`` and the text
    deltas returned by ``get_next_output_text`` are concatenated.

    Args:
        tokenizer: Tokenizer handed to the detokenizer constructor.
        all_input_ids: Prompt ids followed by the ids to decode.
        skip_special_tokens: Forwarded to ``SamplingParams``.
        starting_index: Number of leading ids treated as the prompt.
        spaces_between_special_tokens: Forwarded to ``SamplingParams``.
        fast: ``None`` lets ``IncrementalDetokenizer.from_new_request``
            choose the implementation; ``True`` forces
            ``FastIncrementalDetokenizer`` and ``False`` forces
            ``SlowIncrementalDetokenizer``.

    Returns:
        Tuple ``(output_text, output_token_ids)`` where the second item is
        the detokenizer's ``output_token_ids`` attribute.
    """
    prompt_token_ids = all_input_ids[:starting_index]
    new_token_ids = all_input_ids[starting_index:]

    params = SamplingParams(
        skip_special_tokens=skip_special_tokens,
        spaces_between_special_tokens=spaces_between_special_tokens,
    )
    request = EngineCoreRequest(request_id="",
                                prompt_token_ids=prompt_token_ids,
                                mm_features=None,
                                sampling_params=params,
                                pooling_params=None,
                                eos_token_id=None,
                                arrival_time=0.0,
                                lora_request=None,
                                cache_salt=None,
                                data_parallel_rank=None)

    if fast is None:
        detokenizer = IncrementalDetokenizer.from_new_request(
            tokenizer, request)
    elif fast:
        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
    else:
        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)

    # The loop index is relative to the generated suffix, so the "last token"
    # check must use the suffix length rather than len(all_input_ids);
    # otherwise `finished` is never True when a prompt prefix is present.
    last_index = len(new_token_ids) - 1
    text_chunks = []
    for i, token_id in enumerate(new_token_ids):
        detokenizer.update([token_id], False)
        finished = i == last_index
        text_chunks.append(
            detokenizer.get_next_output_text(finished, delta=True))

    return "".join(text_chunks), detokenizer.output_token_ids
90
91


92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
@pytest.fixture
def tokenizer(tokenizer_name):
    """Load the tokenizer for ``tokenizer_name``.

    Names containing "mistral" are loaded through ``MistralTokenizer``;
    everything else goes through ``AutoTokenizer``.
    """
    if "mistral" in tokenizer_name:
        return MistralTokenizer.from_pretrained(tokenizer_name)
    return AutoTokenizer.from_pretrained(tokenizer_name)


@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
@pytest.mark.parametrize(
    "truth",
    [
        # Burmese text triggers an edge-case where tokens may map to bytes with
        # incomplete UTF-8 characters
        "ပုံပြင်လေးပြောပြပါ",
        # Using "URGENCY" since "CY" has token id 130282
        "URGENCY🌶️",
    ])
def test_mistral_edge_case(tokenizer, truth):
    """Test specific edge cases with the V3-Tekken MistralTokenizer.

    The text is tokenized without special tokens, streamed through the
    incremental detokenizer from index 0, and must round-trip exactly.

    See https://github.com/vllm-project/vllm/pull/9625
    """
    starting_index = 0
    all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids

    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=True,
        starting_index=starting_index)

    assert decoded_text == truth
    assert out_ids == all_input_ids[starting_index:]
124
125
126
127
128
129


@pytest.fixture
def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
    """Yield the indirectly-parametrized ``skip_special_tokens`` flag.

    The raw value arrives through ``request.param``. For tokenizer names
    containing "mistral", a falsy value skips the test instead of yielding.
    """
    if "mistral" in tokenizer_name and not request.param:
        # pytest.skip raises, so nothing is yielded for this combination.
        pytest.skip("mistral doesn't support skip_special_tokens=False")
    yield bool(request.param)
134
135


136
@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
                          spaces_between_special_tokens, fast):
    """Check streamed detokenization against ``tokenizer.decode``.

    ``truth`` is tokenized, wrapped with BOS (when the tokenizer defines one)
    and EOS, and decoded in one shot to obtain the reference text. The same
    ids are then streamed through ``_run_incremental_decode``, optionally
    treating the first part as a prompt, and the streamed text must equal the
    reference text that follows the decoded prompt.
    """
    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
        pytest.skip("fast detokenizer is only exercised with "
                    "PreTrainedTokenizerFast tokenizers")

    if skip_special_tokens and not spaces_between_special_tokens:
        pytest.skip("skip_special_tokens=True combined with "
                    "spaces_between_special_tokens=False is not exercised")

    if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
        # Fix up inconsistency in fast/slow tokenizer behaviour.
        tokenizer.add_special_tokens({
            "additional_special_tokens": [
                at for at in
                tokenizer._tokenizer.get_added_tokens_decoder().values()
                if at.special
            ]
        })

    # spaces_between_special_tokens is only forwarded to decode() for
    # PreTrainedTokenizer instances.
    if isinstance(tokenizer, PreTrainedTokenizer):
        extra_decode_args = {
            "spaces_between_special_tokens": spaces_between_special_tokens
        }
    else:
        extra_decode_args = {}

    truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
    if tokenizer.bos_token_id is not None:
        truth_tokens.insert(0, tokenizer.bos_token_id)
    truth_tokens.append(tokenizer.eos_token_id)

    new_truth = tokenizer.decode(truth_tokens,
                                 skip_special_tokens=skip_special_tokens,
                                 **extra_decode_args)

    if with_prompt:
        # Use the token count of the first half of the text as prompt length.
        num_prompt_tokens = len(
            tokenizer(truth[:len(truth) // 2],
                      add_special_tokens=False).input_ids)
        if tokenizer.bos_token_id is not None:
            num_prompt_tokens += 1

        prompt_input_ids = truth_tokens[:num_prompt_tokens]
        generated_input_ids = truth_tokens[num_prompt_tokens:]
        all_input_ids = prompt_input_ids + generated_input_ids
        starting_index = len(prompt_input_ids)
        prompt = tokenizer.decode(prompt_input_ids,
                                  skip_special_tokens=skip_special_tokens,
                                  **extra_decode_args)

        # Expected streamed text is whatever follows the decoded prompt.
        generated = new_truth[len(prompt):]
    else:
        generated = new_truth
        starting_index = 0
        all_input_ids = truth_tokens

    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index,
        spaces_between_special_tokens=spaces_between_special_tokens,
        fast=fast)

    assert decoded_text == generated
    assert out_ids == all_input_ids[starting_index:]
203

204
205
206
207
208
209
210
211

@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
    """Stream a single id equal to ``len(tokenizer)`` and expect empty text.

    The id must still be reported back in the detokenizer's output ids.
    """
    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
        pytest.skip("fast detokenizer is only exercised with "
                    "PreTrainedTokenizerFast tokenizers")

    oov_token_id = len(tokenizer)
    decoded_text, out_ids = _run_incremental_decode(
        tokenizer, [oov_token_id],
        skip_special_tokens=True,
        starting_index=0,
        spaces_between_special_tokens=True,
        fast=fast)

    assert decoded_text == ''
    assert out_ids == [oov_token_id]
220

221
222
223

@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
    """Build a ``Detokenizer`` around the tokenizer for ``tokenizer_name``.

    Names containing "mistral" are loaded with ``tokenizer_mode="mistral"``;
    all others use ``"auto"``.
    """
    tokenizer = get_tokenizer(
        tokenizer_name,
        tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
        trust_remote_code=False,
        revision=None,
    )

    return Detokenizer(tokenizer)
232
233
234
235


@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
                                       tokenizer) -> list[int]:
    """Tokenize ``complete_sequence`` without adding special tokens."""
    return tokenizer(complete_sequence, add_special_tokens=False).input_ids
238
239
240


def create_sequence(prompt_token_ids=None):
    """Create a ``Sequence`` (seq_id 0, block size 16) for the given prompt.

    Args:
        prompt_token_ids: Prompt token ids; ``None`` or an empty value yields
            a sequence with an empty prompt.
    """
    prompt_token_ids = prompt_token_ids or []
    return Sequence(
        seq_id=0,
        inputs=token_inputs(prompt_token_ids),
        block_size=16,
    )


def create_dummy_logprobs(
        complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
    """Build one fake logprob dict per token.

    Each dict maps the token id to ``Logprob(logprob=0.0)`` and the next id
    (``token_id + 1``) to ``Logprob(logprob=0.1)``, giving the tests a
    "chosen" and an "other" candidate at every position.
    """
    return [{
        token_id: Logprob(logprob=0.0),
        token_id + 1: Logprob(logprob=0.1)
    } for token_id in complete_sequence_token_ids]


257
def create_dummy_prompt_logprobs(
        complete_sequence_token_ids: list[int]
) -> list[Optional[dict[int, Any]]]:
    """Build fake prompt logprobs: ``None`` first, then one dict per token.

    The entries after the first are the dummy logprobs for tokens 1..n-1, so
    the result has the same length as ``complete_sequence_token_ids`` when
    that list is non-empty.
    """
    # logprob for the first prompt token is None.
    logprobs: list[Optional[dict[int, Any]]] = [None]
    logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
    return logprobs


266
267
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
def test_decode_sequence_logprobs(complete_sequence: str,
                                  complete_sequence_token_ids: list[int],
                                  detokenizer: Detokenizer,
                                  skip_special_tokens: bool):
    """Verify Detokenizer decodes logprobs correctly."""
    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                     logprobs=2)

    # Run sequentially.
    seq = create_sequence()
    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
    sequential_logprobs_text_chosen_token: list[str] = []
    sequential_logprobs_text_other_token: list[str] = []
    for new_token, logprobs in zip(complete_sequence_token_ids,
                                   dummy_logprobs):
        seq.append_token_id(new_token, logprobs)
        detokenizer.decode_sequence_inplace(seq, sampling_params)
        # Record the decoded text of the chosen token and of the alternative
        # candidate (token id + 1) for the position just appended.
        sequential_logprobs_text_chosen_token.append(
            seq.output_logprobs[-1][new_token].decoded_token)
        sequential_logprobs_text_other_token.append(
            seq.output_logprobs[-1][new_token + 1].decoded_token)
    sequential_result = seq.output_text

    assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
    assert sequential_result != "".join(sequential_logprobs_text_other_token)

    if not skip_special_tokens:
        # Text for logprobs for the chosen token should be the same as the
        # generated text. Note that this is only checked when special tokens
        # are NOT skipped; presumably skipping would drop the special tokens
        # present in some inputs -- TODO confirm.
        assert sequential_result == complete_sequence


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence: str,
                                complete_sequence_token_ids: list[int],
                                detokenizer: Detokenizer):
    """Verify Detokenizer decodes prompt logprobs correctly."""
    # We want to use skip_special_tokens=False here but Mistral tokenizers
    # don't support that.
    if complete_sequence not in SPECIAL_TOKS_TRUTH:
        skip_special_tokens = True
    elif not isinstance(detokenizer.tokenizer, MistralTokenizer):
        skip_special_tokens = False
    else:
        # pytest.skip raises, so execution never continues past this branch.
        pytest.skip("MistralTokenizers don't support "
                    "skip_special_tokens=False")

    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                     prompt_logprobs=1)

    # Run sequentially.
    seq = create_sequence(complete_sequence_token_ids)
    seq_group = SequenceGroup(request_id="1",
                              seqs=[seq],
                              sampling_params=sampling_params,
                              arrival_time=0.0)
    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
    detokenizer.decode_prompt_logprobs_inplace(seq_group,
                                               dummy_logprobs,
                                               position_offset=0)
    # First logprob is None.
    decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
        1:]  # type: ignore

    # decoded_prompt_logprobs doesn't contain the first token.
    token_ids = complete_sequence_token_ids
    tokenizer = detokenizer.tokenizer
    text_full = tokenizer.decode(token_ids,
                                 skip_special_tokens=skip_special_tokens)
    text_first = tokenizer.decode(token_ids[0],
                                  skip_special_tokens=skip_special_tokens)
    # Drop the first token's text so it lines up with the logprob entries.
    text = text_full[len(text_first):]

    # Text for logprobs for the chosen token should be the same as the
    # prompt text. Note that the first logprob is None.
    assert text == "".join([
        logprobs[token_id].decoded_token
        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
    ])
    assert text != "".join([
        logprobs[token_id + 1].decoded_token
        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
    ])


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
def test_decode_prompt_logprobs_chunked_prefill(
    vllm_runner,
    model,
    chunked_prefill_token_size: int,
    example_prompts,
    monkeypatch,
):
    """Check that decoded prompt logprobs rebuild the original prompt text.

    A ``chunked_prefill_token_size`` of -1 disables chunked prefill; any
    other value enables it and is used as ``max_num_batched_tokens`` and as
    an upper bound on ``max_num_seqs``.
    """
    # VLLM V1 does not use incremental detokenization for
    # prompt logprobs, so this test strategy is irrelevant.
    monkeypatch.setenv("VLLM_USE_V1", "0")

    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size

    with vllm_runner(model,
                     dtype="half",
                     max_logprobs=5,
                     gpu_memory_utilization=0.5,
                     enable_chunked_prefill=enable_chunked_prefill,
                     max_num_batched_tokens=max_num_batched_tokens,
                     max_num_seqs=max_num_seqs) as vllm_model:

        vllm_sampling_params = SamplingParams(max_tokens=10,
                                              logprobs=5,
                                              prompt_logprobs=5,
                                              temperature=0.0)
        vllm_results = vllm_model.llm.generate(
            example_prompts, sampling_params=vllm_sampling_params)

        for idx, result in enumerate(vllm_results):
            assert result.prompt_logprobs is not None
            assert result.prompt_logprobs[0] is None

            # Compare the detokenized prompt token ids to the original
            # prompt. Each prompt_logprobs entry is a dict of
            # token_id -> logprob; pick the entry for the actual prompt token
            # and take its decoded text.
            generated_string = "".join(
                prompt_logprobs[prompt_token].decoded_token
                for prompt_token, prompt_logprobs in zip(
                    result.prompt_token_ids[1:], result.prompt_logprobs[1:]))

            assert generated_string == example_prompts[idx], (
                "Detokenized prompt logprobs do not match original prompt")