# test_detokenize.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Generator
from typing import Any, Optional
6

7
import os
8
import pytest
9
10
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
11

12
from vllm.inputs import token_inputs
13
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
14
from vllm.transformers_utils.detokenizer import Detokenizer
15
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
16
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
17
18
19
20
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
                                        IncrementalDetokenizer,
                                        SlowIncrementalDetokenizer)
21
from vllm.platforms import current_platform
22
from ..utils import models_path_prefix
23

24
25
26
# Truth strings containing adjacent special-token markup (e.g. "<|padding|>",
# "<fim_prefix>"); whether each marker maps to a special token depends on the
# tokenizer under test. These are appended to TRUTH, and
# test_decode_prompt_logprobs keeps special tokens only for these entries.
SPECIAL_TOKS_TRUTH = [
    "Some text with adjacent special tokens                <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
]
27
28

# Reference strings that every detokenization test must round-trip: plain
# ASCII, a long English sentence, CJK text, a Burmese edge case, and finally
# the special-token strings from SPECIAL_TOKS_TRUTH.
TRUTH = [
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
    "我很感谢你的热情",
    # Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg.
    # for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
    # incomplete UTF-8 characters
    # see https://github.com/vllm-project/vllm/pull/9625
    "ပုံပြင်လေးပြောပြပါ်",
    *SPECIAL_TOKS_TRUTH,
]

39
# Local paths (rooted at models_path_prefix) of every tokenizer the
# parametrized tests run against.
TOKENIZERS = [
    os.path.join(models_path_prefix, model_id) for model_id in (
        "facebook/opt-125m",
        "gpt2",
        "bigcode/tiny_starcoder_py",
        "EleutherAI/gpt-j-6b",
        "EleutherAI/pythia-70m",
        "bigscience/bloom-560m",
        "mosaicml/mpt-7b",
        "tiiuae/falcon-7b",
        "meta-llama/Llama-3.2-1B-Instruct",
        "codellama/CodeLlama-7b-hf",
        "mistralai/Pixtral-12B-2409",
    )
]


54
55
56
57
58
59
60
61
62
63
64
65
66
def _run_incremental_decode(tokenizer,
                            all_input_ids,
                            skip_special_tokens: bool,
                            starting_index: int,
                            spaces_between_special_tokens: bool = True,
                            fast: Optional[bool] = None):
    """Stream ``all_input_ids[starting_index:]`` through a V1 incremental
    detokenizer, one token per update.

    Args:
        tokenizer: Tokenizer handed to the incremental detokenizer.
        all_input_ids: Prompt token ids followed by the ids to "generate".
        skip_special_tokens: Forwarded to ``SamplingParams``.
        starting_index: Number of leading ids treated as the prompt; only
            the ids after it are streamed.
        spaces_between_special_tokens: Forwarded to ``SamplingParams``.
        fast: ``None`` lets ``IncrementalDetokenizer.from_new_request``
            pick the implementation; ``True`` / ``False`` force
            ``FastIncrementalDetokenizer`` / ``SlowIncrementalDetokenizer``.

    Returns:
        ``(text, output_token_ids)``: the concatenation of every delta text
        returned by the detokenizer, and its ``output_token_ids``.
    """
    prompt_token_ids = all_input_ids[:starting_index]
    generated_ids = all_input_ids[starting_index:]

    params = SamplingParams(
        skip_special_tokens=skip_special_tokens,
        spaces_between_special_tokens=spaces_between_special_tokens,
    )
    # Positional arguments must follow EngineCoreRequest's field order
    # (defined in vllm.v1.engine, not visible here). Besides the request id,
    # prompt ids and sampling params, every slot is None except the 0.0,
    # which is presumably the arrival time -- re-check if the struct changes.
    request = EngineCoreRequest("",
                                prompt_token_ids,
                                None,
                                None,
                                None,
                                params,
                                None,
                                None,
                                0.0,
                                None,
                                cache_salt=None,
                                data_parallel_rank=None)

    if fast is None:
        detokenizer = IncrementalDetokenizer.from_new_request(
            tokenizer, request)
    elif fast:
        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
    else:
        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)

    output_text = ""
    last_index = len(generated_ids) - 1
    for i, token_id in enumerate(generated_ids):
        detokenizer.update([token_id], False)
        # `i` indexes generated_ids, so compare against that slice's length;
        # comparing against len(all_input_ids) meant `finished` was never
        # True whenever a prompt was present (starting_index > 0).
        finished = i == last_index
        output_text += detokenizer.get_next_output_text(finished, delta=True)

    return output_text, detokenizer.output_token_ids
95
96


97
98
99
100
101
102
103
@pytest.fixture
def tokenizer(tokenizer_name):
    """Load ``tokenizer_name``: vLLM's ``MistralTokenizer`` when the name
    contains "mistral", Hugging Face ``AutoTokenizer`` otherwise."""
    if "mistral" not in tokenizer_name:
        return AutoTokenizer.from_pretrained(tokenizer_name)
    return MistralTokenizer.from_pretrained(tokenizer_name)


104
@pytest.mark.parametrize(
    "tokenizer_name",
    [os.path.join(models_path_prefix, "mistralai/Pixtral-12B-2409")])
@pytest.mark.parametrize(
    "truth",
    [
        # Burmese text triggers an edge-case where tokens may map to bytes with
        # incomplete UTF-8 characters
        "ပုံပြင်လေးပြောပြပါ",
        # Using "URGENCY" since "CY" has token id 130282
        "URGENCY🌶️",
    ])
def test_mistral_edge_case(tokenizer, truth):
    """Stream edge-case strings through the V3-Tekken MistralTokenizer and
    check that both the text and the token ids round-trip unchanged.

    See https://github.com/vllm-project/vllm/pull/9625
    """
    token_ids = tokenizer(truth, add_special_tokens=False).input_ids

    # No prompt: every id is streamed through the detokenizer.
    text, emitted_ids = _run_incremental_decode(tokenizer,
                                                token_ids,
                                                skip_special_tokens=True,
                                                starting_index=0)

    assert text == truth
    assert emitted_ids == token_ids
129
130
131
132
133
134


@pytest.fixture
def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
    """Indirectly parametrized flag; the ``False`` case is skipped for
    Mistral tokenizers, which do not support keeping special tokens."""
    wants_skip = bool(request.param)
    if "mistral" in tokenizer_name and not wants_skip:
        pytest.skip("mistral doesn't support skip_special_tokens=False")
    yield wants_skip
139
140


141
@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
                          spaces_between_special_tokens, fast):
    """Incremental decoding must match a one-shot ``tokenizer.decode``.

    The reference ids are ``truth`` tokenized without special tokens, then
    wrapped in BOS (when the tokenizer has one) and EOS. With
    ``with_prompt`` the ids covering roughly the first half of ``truth``
    (plus BOS) form the prompt and only the remainder is streamed; the
    expected text is the full decode minus the decoded prompt.
    """
    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
        pytest.skip("fast=True only tested with PreTrainedTokenizerFast "
                    "tokenizers")

    if skip_special_tokens and not spaces_between_special_tokens:
        pytest.skip("spaces_between_special_tokens=False only tested with "
                    "skip_special_tokens=False")

    if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
        # Fix up inconsistency in fast/slow tokenizer behaviour: register
        # every backend added token flagged `special` as an additional
        # special token so the slow path presumably treats them alike.
        tokenizer.add_special_tokens({
            "additional_special_tokens": [
                at for at in
                tokenizer._tokenizer.get_added_tokens_decoder().values()
                if at.special
            ]
        })

    # spaces_between_special_tokens is only passed to decode() for slow
    # (PreTrainedTokenizer) tokenizers.
    extra_decode_args = ({
        "spaces_between_special_tokens": spaces_between_special_tokens
    } if isinstance(tokenizer, PreTrainedTokenizer) else {})

    truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
    if tokenizer.bos_token_id is not None:
        truth_tokens.insert(0, tokenizer.bos_token_id)
    truth_tokens.append(tokenizer.eos_token_id)

    new_truth = tokenizer.decode(truth_tokens,
                                 skip_special_tokens=skip_special_tokens,
                                 **extra_decode_args)

    if with_prompt:
        num_prompt_tokens = len(
            tokenizer(truth[:len(truth) // 2],
                      add_special_tokens=False).input_ids)
        if tokenizer.bos_token_id is not None:
            num_prompt_tokens += 1

        prompt_input_ids = truth_tokens[:num_prompt_tokens]
        generated_input_ids = truth_tokens[num_prompt_tokens:]
        all_input_ids = prompt_input_ids + generated_input_ids
        starting_index = len(prompt_input_ids)
        prompt = tokenizer.decode(prompt_input_ids,
                                  skip_special_tokens=skip_special_tokens,
                                  **extra_decode_args)

        # Expected streamed text: everything after the decoded prompt.
        generated = new_truth[len(prompt):]
    else:
        generated = new_truth
        starting_index = 0
        all_input_ids = truth_tokens

    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index,
        spaces_between_special_tokens=spaces_between_special_tokens,
        fast=fast)

    assert decoded_text == generated
    assert out_ids == all_input_ids[starting_index:]
208

209
210
211
212
213
214

@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
    """Streaming the id ``len(tokenizer)`` (one past the last id the
    tokenizer reports) must produce no text while still being recorded in
    the detokenizer's output token ids."""
    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
        pytest.skip("fast=True only tested with PreTrainedTokenizerFast "
                    "tokenizers")

    oov_token_id = len(tokenizer)
    decoded_text, out_ids = _run_incremental_decode(
        tokenizer, [oov_token_id],
        skip_special_tokens=True,
        starting_index=0,
        spaces_between_special_tokens=True,
        fast=fast)

    assert decoded_text == ''
    assert out_ids == [oov_token_id]
225

226
227
228

@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
    """Wrap a ``TokenizerGroup`` for ``tokenizer_name`` in a ``Detokenizer``,
    using tokenizer_mode "mistral" for Mistral names and "auto" otherwise."""
    mode = "mistral" if "mistral" in tokenizer_name else "auto"
    group = TokenizerGroup(
        tokenizer_id=tokenizer_name,
        enable_lora=False,
        max_num_seqs=100,
        max_input_length=None,
        tokenizer_mode=mode,
        trust_remote_code=False,
        revision=None,
    )
    return Detokenizer(group)


@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
                                       tokenizer) -> list[int]:
    """Token ids of ``complete_sequence``, tokenized with
    ``add_special_tokens=False``."""
    encoding = tokenizer(complete_sequence, add_special_tokens=False)
    return encoding.input_ids
246
247
248


def create_sequence(prompt_token_ids=None):
    """Build ``Sequence`` 0 over ``prompt_token_ids`` (an empty prompt when
    omitted or falsy); block size is 64 on ROCm and 16 elsewhere."""
    block_size = 64 if current_platform.is_rocm() else 16
    return Sequence(
        seq_id=0,
        inputs=token_inputs(prompt_token_ids or []),
        block_size=block_size,
    )


def create_dummy_logprobs(
        complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
    """For each token id return ``{id: Logprob(0.0), id + 1: Logprob(0.1)}``,
    i.e. the chosen token plus one alternative per position."""
    dummy: list[dict[int, Logprob]] = []
    for token_id in complete_sequence_token_ids:
        dummy.append({
            token_id: Logprob(logprob=0.0),
            token_id + 1: Logprob(logprob=0.1),
        })
    return dummy


265
def create_dummy_prompt_logprobs(
        complete_sequence_token_ids: list[int]
) -> list[Optional[dict[int, Any]]]:
    """Prompt-logprob variant of ``create_dummy_logprobs``: position 0 holds
    ``None`` (the first prompt token has no logprob) and the remaining
    positions reuse the dummy logprob dicts for tokens 1..N-1."""
    return [None, *create_dummy_logprobs(complete_sequence_token_ids)[1:]]


274
275
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
def test_decode_sequence_logprobs(complete_sequence: str,
                                  complete_sequence_token_ids: list[int],
                                  detokenizer: Detokenizer,
                                  skip_special_tokens: bool):
    """Verify Detokenizer decodes logprobs correctly.

    Tokens are appended one at a time, each with a dummy logprob dict holding
    the chosen id and ``id + 1``, and the sequence is decoded in place after
    every append. The decoded text of the chosen-token logprobs must
    concatenate to the sequence's output text; the ``id + 1`` alternatives
    must not.
    """
    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                     logprobs=2)

    # Run sequentially.
    seq = create_sequence()
    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
    sequential_logprobs_text_chosen_token: list[str] = []
    sequential_logprobs_text_other_token: list[str] = []
    for new_token, logprobs in zip(complete_sequence_token_ids,
                                   dummy_logprobs):
        seq.append_token_id(new_token, logprobs)
        detokenizer.decode_sequence_inplace(seq, sampling_params)
        sequential_logprobs_text_chosen_token.append(
            seq.output_logprobs[-1][new_token].decoded_token)
        sequential_logprobs_text_other_token.append(
            seq.output_logprobs[-1][new_token + 1].decoded_token)
    sequential_result = seq.output_text

    assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
    assert sequential_result != "".join(sequential_logprobs_text_other_token)

    if not skip_special_tokens:
        # The generated text should reproduce the original sequence exactly.
        # This only holds when special tokens are NOT skipped; skipping them
        # would drop any special tokens in the input (e.g. the
        # SPECIAL_TOKS_TRUTH entries).
        assert sequential_result == complete_sequence


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence: str,
                                complete_sequence_token_ids: list[int],
                                detokenizer: Detokenizer):
    """Verify Detokenizer decodes prompt logprobs correctly.

    The decoded text of the chosen-token prompt logprobs (positions 1..N-1)
    must equal the full decoded prompt minus the decoded first token, and
    the ``id + 1`` alternatives must not.
    """
    # We want to use skip_special_tokens=False here but Mistral tokenizers
    # don't support that.
    if complete_sequence not in SPECIAL_TOKS_TRUTH:
        skip_special_tokens = True
    elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
                        MistralTokenizer):
        skip_special_tokens = False
    else:
        # pytest.skip raises, so nothing below runs in this case.
        pytest.skip("MistralTokenizers don't support "
                    "skip_special_tokens=False")

    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                     prompt_logprobs=1)

    # Run sequentially.
    seq = create_sequence(complete_sequence_token_ids)
    seq_group = SequenceGroup(request_id="1",
                              seqs=[seq],
                              sampling_params=sampling_params,
                              arrival_time=0.0)
    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
    detokenizer.decode_prompt_logprobs_inplace(seq_group,
                                               dummy_logprobs,
                                               position_offset=0)
    # First logprob is None.
    decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
        1:]  # type: ignore

    # decoded_prompt_logprobs doesn't contain the first token.
    token_ids = complete_sequence_token_ids
    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
    text_full = tokenizer.decode(token_ids,
                                 skip_special_tokens=skip_special_tokens)
    text_first = tokenizer.decode(token_ids[0],
                                  skip_special_tokens=skip_special_tokens)
    text = text_full[len(text_first):]

    # Text for logprobs for the chosen token should be the same as the
    # prompt text. Note that the first logprob is None.
    assert text == "".join([
        logprobs[token_id].decoded_token
        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
    ])
    assert text != "".join([
        logprobs[token_id + 1].decoded_token
        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
    ])


366
@pytest.mark.parametrize(
    "model", [os.path.join(models_path_prefix, "facebook/opt-125m")])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
def test_decode_prompt_logprobs_chunked_prefill(
    vllm_runner,
    model,
    chunked_prefill_token_size: int,
    example_prompts,
    monkeypatch,
):
    """Decoded prompt logprobs must concatenate back to each prompt, with
    and without chunked prefill.

    ``chunked_prefill_token_size == -1`` disables chunked prefill; any other
    value enables it and caps both ``max_num_batched_tokens`` and
    ``max_num_seqs`` (otherwise 256) at that size.
    """
    # VLLM V1 does not use incremental detokenization for
    # prompt logprobs, so this test strategy is irrelevant.
    monkeypatch.setenv("VLLM_USE_V1", "0")

    chunked = chunked_prefill_token_size != -1
    max_num_seqs = min(chunked_prefill_token_size, 256) if chunked else 256
    max_num_batched_tokens = chunked_prefill_token_size if chunked else None

    with vllm_runner(model,
                     dtype="half",
                     max_logprobs=5,
                     gpu_memory_utilization=0.5,
                     enable_chunked_prefill=chunked,
                     max_num_batched_tokens=max_num_batched_tokens,
                     max_num_seqs=max_num_seqs) as vllm_model:

        vllm_sampling_params = SamplingParams(max_tokens=10,
                                              logprobs=5,
                                              prompt_logprobs=5,
                                              temperature=0.0)
        vllm_results = vllm_model.model.generate(
            example_prompts, sampling_params=vllm_sampling_params)

        for idx, result in enumerate(vllm_results):
            assert result.prompt_logprobs is not None
            assert result.prompt_logprobs[0] is None

            # Skip position 0 (its prompt logprob is None, asserted above)
            # and pick, at each later position, the decoded text of the
            # token that actually appears in the prompt.
            detokenized_prompt = "".join(
                logprobs[token_id].decoded_token
                for token_id, logprobs in zip(result.prompt_token_ids[1:],
                                              result.prompt_logprobs[1:]))

            assert detokenized_prompt == example_prompts[idx], (
                "Detokenized prompt logprobs do not match original prompt")