test_detokenize.py 16.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Generator
from typing import Any, Optional
6

7
import os
8
import pytest
9
10
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
11

12
from vllm.inputs import token_inputs
13
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
14
from vllm.transformers_utils.detokenizer import Detokenizer
15
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
16
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
17
18
19
20
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
                                        IncrementalDetokenizer,
                                        SlowIncrementalDetokenizer)
21
from vllm.platforms import current_platform
22
from ..utils import models_path_prefix
23

24
25
26
# Text in which several special-token strings appear back to back.
SPECIAL_TOKS_TRUTH = [
    "Some text with adjacent special tokens                <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
]

# Reference strings used by the round-trip detokenization tests; the
# special-token sample is appended at the end.
TRUTH = [
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
    "我很感谢你的热情",
    # Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg.
    # for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
    # incomplete UTF-8 characters
    # see https://github.com/vllm-project/vllm/pull/9625
    "ပုံပြင်လေးပြောပြပါ်",
    *SPECIAL_TOKS_TRUTH,
]

# Model identifiers whose tokenizers are exercised by the tests below; each is
# resolved relative to ``models_path_prefix``. Paths containing "mistral" are
# loaded as Mistral tokenizers by the fixtures in this module.
_TOKENIZER_MODEL_NAMES = (
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-3.2-1B-Instruct",
    "codellama/CodeLlama-7b-hf",
    "mistralai/Pixtral-12B-2409",
)

TOKENIZERS = [
    os.path.join(models_path_prefix, model_name)
    for model_name in _TOKENIZER_MODEL_NAMES
]


def _run_incremental_decode(tokenizer,
                            all_input_ids,
                            skip_special_tokens: bool,
                            starting_index: int,
                            spaces_between_special_tokens: bool = True,
                            fast: Optional[bool] = None):
    """Stream token ids through a V1 incremental detokenizer one at a time.

    Args:
        tokenizer: Tokenizer handed to the incremental detokenizer.
        all_input_ids: Prompt token ids followed by generated token ids.
        skip_special_tokens: Forwarded to ``SamplingParams``.
        starting_index: Number of leading ids treated as the prompt; the
            remaining ids ``all_input_ids[starting_index:]`` are fed to the
            detokenizer one by one as generated tokens.
        spaces_between_special_tokens: Forwarded to ``SamplingParams``.
        fast: ``None`` lets ``IncrementalDetokenizer.from_new_request`` pick
            the implementation; ``True`` / ``False`` force
            ``FastIncrementalDetokenizer`` / ``SlowIncrementalDetokenizer``.

    Returns:
        A ``(text, token_ids)`` tuple: the concatenation of every delta text
        returned by the detokenizer, and its ``output_token_ids``.
    """
    prompt_token_ids = all_input_ids[:starting_index]
    generated_token_ids = all_input_ids[starting_index:]

    params = SamplingParams(
        skip_special_tokens=skip_special_tokens,
        spaces_between_special_tokens=spaces_between_special_tokens,
    )
    # NOTE(review): the positional arguments are tied to this vLLM build's
    # EngineCoreRequest field order — confirm when upgrading.
    request = EngineCoreRequest("",
                                prompt_token_ids,
                                None,
                                params,
                                None,
                                None,
                                0.0,
                                None,
                                cache_salt=None,
                                data_parallel_rank=None)

    if fast is None:
        detokenizer = IncrementalDetokenizer.from_new_request(
            tokenizer, request)
    elif fast:
        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
    else:
        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)

    output_text = ""
    last_index = len(generated_token_ids) - 1
    for i, token_id in enumerate(generated_token_ids):
        detokenizer.update([token_id], False)
        # ``i`` indexes the generated slice, so the final step is
        # ``len(generated_token_ids) - 1``. Comparing against
        # ``len(all_input_ids)`` would never match once a prompt is present.
        finished = i == last_index
        output_text += detokenizer.get_next_output_text(finished, delta=True)

    return output_text, detokenizer.output_token_ids


@pytest.fixture
def tokenizer(tokenizer_name):
    """Load the tokenizer selected by the ``tokenizer_name`` parameter.

    Names containing ``"mistral"`` are loaded with vLLM's
    ``MistralTokenizer``; every other name goes through ``AutoTokenizer``.
    """
    if "mistral" in tokenizer_name:
        return MistralTokenizer.from_pretrained(tokenizer_name)
    return AutoTokenizer.from_pretrained(tokenizer_name)


@pytest.mark.parametrize(
    "tokenizer_name",
    [os.path.join(models_path_prefix, "mistralai/Pixtral-12B-2409")])
@pytest.mark.parametrize(
    "truth",
    [
        # Burmese text triggers an edge-case where tokens may map to bytes with
        # incomplete UTF-8 characters
        "ပုံပြင်လေးပြောပြပါ",
        # Using "URGENCY" since "CY" has token id 130282
        "URGENCY🌶️",
    ])
def test_mistral_edge_case(tokenizer, truth):
    """Incremental decoding must round-trip V3-Tekken MistralTokenizer edge
    cases exactly.

    See https://github.com/vllm-project/vllm/pull/9625
    """
    token_ids = tokenizer(truth, add_special_tokens=False).input_ids

    # No prompt: every token is streamed through the detokenizer.
    decoded_text, out_ids = _run_incremental_decode(tokenizer,
                                                    token_ids,
                                                    skip_special_tokens=True,
                                                    starting_index=0)

    assert decoded_text == truth
    assert out_ids == token_ids[0:]


@pytest.fixture
def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
    """Indirect fixture yielding the parametrized ``skip_special_tokens``.

    Mistral tokenizers only work with ``skip_special_tokens=True``, so the
    ``False`` case is skipped for any ``tokenizer_name`` containing
    ``"mistral"``.
    """
    if "mistral" not in tokenizer_name:
        yield bool(request.param)
        return

    if not request.param:
        pytest.skip("mistral doesn't support skip_special_tokens=False")
    yield True


@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
                          spaces_between_special_tokens, fast):
    """Streaming detokenization must match a one-shot ``tokenizer.decode``.

    ``truth`` is encoded (plus BOS when the tokenizer defines one, and EOS)
    and decoded in a single call to get the reference text. The same ids are
    then streamed through ``_run_incremental_decode``. With ``with_prompt``,
    the ids covering roughly the first half of ``truth`` act as an
    already-seen prompt, and only the text after the decoded prompt is
    expected from the stream.
    """
    is_fast_tokenizer = isinstance(tokenizer, PreTrainedTokenizerFast)
    if fast and not is_fast_tokenizer:
        pytest.skip()
    if skip_special_tokens and not spaces_between_special_tokens:
        pytest.skip()

    if is_fast_tokenizer and not fast:
        # Fix up inconsistency in fast/slow tokenizer behaviour: register the
        # backend's special added tokens as additional special tokens.
        special_added_tokens = [
            added_token for added_token in
            tokenizer._tokenizer.get_added_tokens_decoder().values()
            if added_token.special
        ]
        tokenizer.add_special_tokens(
            {"additional_special_tokens": special_added_tokens})

    # spaces_between_special_tokens is only passed to decode() for
    # PreTrainedTokenizer instances.
    if isinstance(tokenizer, PreTrainedTokenizer):
        decode_kwargs = {
            "spaces_between_special_tokens": spaces_between_special_tokens
        }
    else:
        decode_kwargs = {}

    text_ids = tokenizer(truth, add_special_tokens=False).input_ids
    bos_ids = ([] if tokenizer.bos_token_id is None else
               [tokenizer.bos_token_id])
    truth_tokens = [*bos_ids, *text_ids, tokenizer.eos_token_id]

    reference_text = tokenizer.decode(truth_tokens,
                                      skip_special_tokens=skip_special_tokens,
                                      **decode_kwargs)

    if not with_prompt:
        starting_index = 0
        expected_text = reference_text
    else:
        half_truth_ids = tokenizer(truth[:len(truth) // 2],
                                   add_special_tokens=False).input_ids
        # The prompt also covers the BOS token, if one was prepended.
        starting_index = len(half_truth_ids) + len(bos_ids)
        prompt_text = tokenizer.decode(truth_tokens[:starting_index],
                                       skip_special_tokens=skip_special_tokens,
                                       **decode_kwargs)
        expected_text = reference_text[len(prompt_text):]

    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        truth_tokens,
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index,
        spaces_between_special_tokens=spaces_between_special_tokens,
        fast=fast)

    assert decoded_text == expected_text
    assert out_ids == truth_tokens[starting_index:]


@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
    """Streaming a single id equal to ``len(tokenizer)`` (treated here as out
    of vocabulary) must produce empty text while still reporting the id."""
    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
        pytest.skip()

    oov_token_id = len(tokenizer)
    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        [oov_token_id],
        skip_special_tokens=True,
        starting_index=0,
        spaces_between_special_tokens=True,
        fast=fast,
    )

    assert decoded_text == ''
    assert out_ids == [oov_token_id]


@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
    """Build a ``Detokenizer`` over a ``TokenizerGroup`` for ``tokenizer_name``.

    Names containing ``"mistral"`` use the ``"mistral"`` tokenizer mode;
    all others use ``"auto"``.
    """
    tokenizer_mode = "mistral" if "mistral" in tokenizer_name else "auto"
    group = TokenizerGroup(
        tokenizer_id=tokenizer_name,
        enable_lora=False,
        max_num_seqs=100,
        max_input_length=None,
        tokenizer_mode=tokenizer_mode,
        trust_remote_code=False,
        revision=None,
    )
    return Detokenizer(group)


@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
                                       tokenizer) -> list[int]:
    """Token ids for ``complete_sequence``, encoded without special tokens."""
    encoding = tokenizer(complete_sequence, add_special_tokens=False)
    return encoding.input_ids


def create_sequence(prompt_token_ids=None):
    """Return a ``Sequence`` with ``seq_id=0`` over ``prompt_token_ids``.

    A falsy ``prompt_token_ids`` (``None`` or empty) becomes an empty prompt.
    The block size is 64 on ROCm and 16 on every other platform.
    """
    inputs = token_inputs(prompt_token_ids or [])
    block_size = 64 if current_platform.is_rocm() else 16
    return Sequence(seq_id=0, inputs=inputs, block_size=block_size)


def create_dummy_logprobs(
        complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
    """Build one two-entry logprob dict per token id.

    Each dict maps the actual token id to ``Logprob(logprob=0.0)`` and the
    neighbouring id ``token_id + 1`` to ``Logprob(logprob=0.1)``. Tests can
    then check the decoded text of both a chosen and a non-chosen token.
    """
    return [
        {
            chosen_id: Logprob(logprob=0.0),
            chosen_id + 1: Logprob(logprob=0.1),
        }
        for chosen_id in complete_sequence_token_ids
    ]


def create_dummy_prompt_logprobs(
        complete_sequence_token_ids: list[int]
) -> list[Optional[dict[int, Any]]]:
    """Dummy prompt logprobs: ``None`` for the first prompt token, followed by
    the ``create_dummy_logprobs`` entries for every later token."""
    per_token_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
    # The first prompt token has no logprob, hence the leading None.
    prompt_logprobs: list[Optional[dict[int, Any]]] = [
        None, *per_token_logprobs[1:]
    ]
    return prompt_logprobs


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
def test_decode_sequence_logprobs(complete_sequence: str,
                                  complete_sequence_token_ids: list[int],
                                  detokenizer: Detokenizer,
                                  skip_special_tokens: bool):
    """Verify Detokenizer decodes logprobs correctly.

    The sequence is grown one token at a time and decoded in place after each
    step. Concatenating the decoded text of every chosen token must reproduce
    the sequence's output text. Concatenating the decoded text of the
    alternative ``token_id + 1`` entries must not.
    """
    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                     logprobs=2)

    # Run sequentially.
    seq = create_sequence()
    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)

    sequential_logprobs_text_chosen_token: list[str] = []
    sequential_logprobs_text_other_token: list[str] = []
    for new_token, logprobs in zip(complete_sequence_token_ids,
                                   dummy_logprobs):
        seq.append_token_id(new_token, logprobs)
        detokenizer.decode_sequence_inplace(seq, sampling_params)
        # Read back the text decoded for this step's chosen token and for
        # the alternative (token_id + 1) entry.
        latest_logprobs = seq.output_logprobs[-1]
        sequential_logprobs_text_chosen_token.append(
            latest_logprobs[new_token].decoded_token)
        sequential_logprobs_text_other_token.append(
            latest_logprobs[new_token + 1].decoded_token)
    sequential_result = seq.output_text

    assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
    assert sequential_result != "".join(sequential_logprobs_text_other_token)

    if not skip_special_tokens:
        # With special tokens kept, the output text should reproduce the
        # original string exactly. This only holds when special tokens are
        # NOT skipped: with skip_special_tokens=True the special-token text
        # (e.g. in SPECIAL_TOKS_TRUTH) is expected to be dropped, so the
        # output can legitimately differ from complete_sequence.
        assert sequential_result == complete_sequence


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence: str,
                                complete_sequence_token_ids: list[int],
                                detokenizer: Detokenizer):
    """Verify Detokenizer decodes prompt logprobs correctly.

    After ``decode_prompt_logprobs_inplace``, the decoded text of each prompt
    token's own logprob entry (every position after the first, which is
    ``None``) must concatenate to the prompt text following the first token.
    The decoded text of the alternative ``token_id + 1`` entries must not.
    """
    # Prompts containing special tokens are checked with
    # skip_special_tokens=False so the special-token text is verified too.
    # Mistral tokenizers don't support that, so that case is skipped.
    if complete_sequence not in SPECIAL_TOKS_TRUTH:
        skip_special_tokens = True
    elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
                        MistralTokenizer):
        skip_special_tokens = False
    else:
        # pytest.skip raises, so nothing below runs for this case.
        pytest.skip("MistralTokenizers don't support "
                    "skip_special_tokens=False")

    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                     prompt_logprobs=1)

    # Run sequentially.
    seq = create_sequence(complete_sequence_token_ids)
    seq_group = SequenceGroup(request_id="1",
                              seqs=[seq],
                              sampling_params=sampling_params,
                              arrival_time=0.0)
    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
    detokenizer.decode_prompt_logprobs_inplace(seq_group,
                                               dummy_logprobs,
                                               position_offset=0)
    # First logprob is None.
    decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
        1:]  # type: ignore

    # decoded_prompt_logprobs doesn't contain the first token, so compare
    # against the full decode with the first token's text removed.
    token_ids = complete_sequence_token_ids
    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
    text_full = tokenizer.decode(token_ids,
                                 skip_special_tokens=skip_special_tokens)
    text_first = tokenizer.decode(token_ids[0],
                                  skip_special_tokens=skip_special_tokens)
    text = text_full[len(text_first):]

    # Text for logprobs for the chosen token should be the same as the
    # prompt text. Note that the first logprob is None.
    assert text == "".join([
        logprobs[token_id].decoded_token
        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
    ])
    assert text != "".join([
        logprobs[token_id + 1].decoded_token
        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
    ])


@pytest.mark.parametrize(
    "model", [os.path.join(models_path_prefix, "facebook/opt-125m")])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
def test_decode_prompt_logprobs_chunked_prefill(
    vllm_runner,
    model,
    chunked_prefill_token_size: int,
    example_prompts,
    monkeypatch,
):
    """Prompt logprobs decoded under (optionally chunked) prefill must
    reassemble the original prompts.

    ``chunked_prefill_token_size == -1`` disables chunked prefill. Any other
    value enables it and uses that size for ``max_num_batched_tokens``; it
    also caps ``max_num_seqs`` (default 256).
    """
    # VLLM V1 does not use incremental detokenization for
    # prompt logprobs, so this test strategy is irrelevant.
    monkeypatch.setenv("VLLM_USE_V1", "0")

    if chunked_prefill_token_size == -1:
        enable_chunked_prefill = False
        max_num_batched_tokens = None
        max_num_seqs = 256
    else:
        enable_chunked_prefill = True
        max_num_batched_tokens = chunked_prefill_token_size
        max_num_seqs = min(chunked_prefill_token_size, 256)

    with vllm_runner(model,
                     dtype="half",
                     max_logprobs=5,
                     gpu_memory_utilization=0.5,
                     enable_chunked_prefill=enable_chunked_prefill,
                     max_num_batched_tokens=max_num_batched_tokens,
                     max_num_seqs=max_num_seqs) as vllm_model:

        greedy_params = SamplingParams(max_tokens=10,
                                       logprobs=5,
                                       prompt_logprobs=5,
                                       temperature=0.0)
        vllm_results = vllm_model.llm.generate(example_prompts,
                                               sampling_params=greedy_params)

        for idx, result in enumerate(vllm_results):
            assert result.prompt_logprobs is not None
            assert result.prompt_logprobs[0] is None

            # Rebuild the prompt from the decoded text of each actual prompt
            # token, skipping position 0 whose logprob entry is None.
            generated_string = "".join(
                position_logprobs[prompt_token].decoded_token
                for prompt_token, position_logprobs in zip(
                    result.prompt_token_ids[1:], result.prompt_logprobs[1:]))

            assert generated_string == example_prompts[idx], (
                "Detokenized prompt logprobs do not match original prompt")