test_detokenize.py 7.57 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Generator
5
from typing import Any
6

7
import pytest
8
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
9

10
from vllm.sampling_params import SamplingParams
11
from vllm.tokenizers.mistral import MistralTokenizer
12
from vllm.v1.engine import EngineCoreRequest
13
14
15
16
17
from vllm.v1.engine.detokenizer import (
    FastIncrementalDetokenizer,
    IncrementalDetokenizer,
    SlowIncrementalDetokenizer,
)
18
19
20
21

# Text containing adjacent special-token strings; appended to TRUTH so every
# streaming test also exercises special-token handling.
SPECIAL_TOKS_TRUTH = [
    "Some text with adjacent special tokens                <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
]

# Reference strings that incremental detokenization must reproduce.
TRUTH = [
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
    "我很感谢你的热情",
    # Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg.
    # for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
    # incomplete UTF-8 characters
    # see https://github.com/vllm-project/vllm/pull/9625
    "ပုံပြင်လေးပြောပြပါ်",
] + SPECIAL_TOKS_TRUTH

# Tokenizer names loaded by the `tokenizer` fixture (names containing
# "mistral" go through MistralTokenizer, the rest through AutoTokenizer).
TOKENIZERS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-3.2-1B-Instruct",
    "codellama/CodeLlama-7b-hf",
    "mistralai/Pixtral-12B-2409",
]


49
50
51
52
53
54
def _run_incremental_decode(
    tokenizer,
    all_input_ids,
    skip_special_tokens: bool,
    starting_index: int,
    spaces_between_special_tokens: bool = True,
    fast: bool | None = None,
):
    """Stream token ids through an incremental detokenizer one at a time.

    The first ``starting_index`` ids of ``all_input_ids`` become the request's
    prompt. The remaining ids are fed to the detokenizer individually, and the
    delta text returned after each step is concatenated.

    Args:
        tokenizer: Tokenizer handed to the detokenizer.
        all_input_ids: Prompt ids followed by the ids to detokenize.
        skip_special_tokens: Forwarded to ``SamplingParams``.
        starting_index: Number of leading ids treated as the prompt.
        spaces_between_special_tokens: Forwarded to ``SamplingParams``.
        fast: ``None`` lets ``IncrementalDetokenizer.from_new_request`` choose
            the implementation; ``True`` / ``False`` force
            ``FastIncrementalDetokenizer`` / ``SlowIncrementalDetokenizer``.

    Returns:
        A ``(output_text, output_token_ids)`` tuple: the concatenated delta
        text and the detokenizer's ``output_token_ids``.
    """
    prompt_token_ids = all_input_ids[:starting_index]
    generated_token_ids = all_input_ids[starting_index:]

    params = SamplingParams(
        skip_special_tokens=skip_special_tokens,
        spaces_between_special_tokens=spaces_between_special_tokens,
    )
    request = EngineCoreRequest(
        request_id="",
        prompt_token_ids=prompt_token_ids,
        mm_features=None,
        sampling_params=params,
        pooling_params=None,
        eos_token_id=None,
        arrival_time=0.0,
        lora_request=None,
        cache_salt=None,
        data_parallel_rank=None,
    )

    if fast is None:
        detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
    elif fast:
        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
    else:
        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)

    output_text = ""
    last_step = len(generated_token_ids) - 1
    for step, token_id in enumerate(generated_token_ids):
        detokenizer.update([token_id], False)
        # `step` indexes the generated slice, not `all_input_ids`, so the last
        # step must be judged against the slice length. Comparing against the
        # full sequence length would never flag completion when
        # starting_index > 0.
        finished = step == last_step
        output_text += detokenizer.get_next_output_text(finished, delta=True)

    return output_text, detokenizer.output_token_ids
90
91


92
93
@pytest.fixture
def tokenizer(tokenizer_name):
    """Load the tokenizer named by the ``tokenizer_name`` parameter.

    Names containing "mistral" are loaded with vLLM's ``MistralTokenizer``;
    every other name goes through Hugging Face ``AutoTokenizer``.
    """
    if "mistral" in tokenizer_name:
        return MistralTokenizer.from_pretrained(tokenizer_name)
    return AutoTokenizer.from_pretrained(tokenizer_name)
99
100
101
102
103
104
105
106
107
108
109


@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
@pytest.mark.parametrize(
    "truth",
    [
        # Burmese text triggers an edge-case where tokens may map to bytes with
        # incomplete UTF-8 characters
        "ပုံပြင်လေးပြောပြပါ",
        # Using "URGENCY" since "CY" has token id 130282
        "URGENCY🌶️",
    ],
)
def test_mistral_edge_case(tokenizer, truth):
    """Test specific edge cases of the V3-Tekken MistralTokenizer.

    The whole sequence is streamed with no prompt, and the concatenated
    output must equal the original text exactly.

    See https://github.com/vllm-project/vllm/pull/9625
    """
    starting_index = 0
    all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids

    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=True,
        starting_index=starting_index,
    )

    assert decoded_text == truth
    assert out_ids == all_input_ids[starting_index:]
128
129
130
131
132
133


@pytest.fixture
def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
    """Provide the indirectly-parametrized ``skip_special_tokens`` flag.

    Mistral tokenizers are only exercised with ``skip_special_tokens=True``;
    the ``False`` case is skipped for them.
    """
    # Explicit guard instead of `True if param else pytest.skip(...)`, which
    # only worked because pytest.skip raises before the value is yielded.
    if "mistral" in tokenizer_name and not request.param:
        pytest.skip("mistral doesn't support skip_special_tokens=False")
    yield bool(request.param)
140
141


142
@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(
    tokenizer,
    truth,
    with_prompt,
    skip_special_tokens,
    spaces_between_special_tokens,
    fast,
):
    """Incremental detokenization must match a one-shot ``tokenizer.decode``.

    ``truth`` is tokenized (with BOS/EOS when the tokenizer defines them) and
    decoded in one shot to obtain the expected text. The tokens are then
    streamed through ``_run_incremental_decode``. When ``with_prompt`` is set,
    the ids covering the first half of ``truth`` are used as the prompt, and
    only the text after the decoded prompt is expected from streaming.
    """
    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
        # The fast detokenizer is only exercised with HF fast tokenizers.
        pytest.skip()

    if skip_special_tokens and not spaces_between_special_tokens:
        # NOTE(review): presumably redundant — with special tokens skipped,
        # their spacing should not matter. Confirm before re-enabling.
        pytest.skip()

    if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
        # Fix up inconsistency in fast/slow tokenizer behaviour.
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    at
                    for at in tokenizer._tokenizer.get_added_tokens_decoder().values()
                    if at.special
                ]
            }
        )

    # Only slow HF tokenizers take spaces_between_special_tokens in decode().
    extra_decode_args = (
        {}
        if not isinstance(tokenizer, PreTrainedTokenizer)
        else {"spaces_between_special_tokens": spaces_between_special_tokens}
    )

    truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
    if tokenizer.bos_token_id is not None:
        truth_tokens.insert(0, tokenizer.bos_token_id)
    # Guarded like BOS above so a tokenizer without an EOS id does not append
    # None to the token list.
    if tokenizer.eos_token_id is not None:
        truth_tokens.append(tokenizer.eos_token_id)

    new_truth = tokenizer.decode(
        truth_tokens, skip_special_tokens=skip_special_tokens, **extra_decode_args
    )

    if with_prompt:
        num_prompt_tokens = len(
            tokenizer(truth[: len(truth) // 2], add_special_tokens=False).input_ids
        )
        if tokenizer.bos_token_id is not None:
            # The BOS id inserted above also belongs to the prompt.
            num_prompt_tokens += 1

        prompt_input_ids = truth_tokens[:num_prompt_tokens]
        generated_input_ids = truth_tokens[num_prompt_tokens:]
        all_input_ids = prompt_input_ids + generated_input_ids
        starting_index = len(prompt_input_ids)

        prompt = tokenizer.decode(
            prompt_input_ids,
            skip_special_tokens=skip_special_tokens,
            **extra_decode_args,
        )
        # Streaming only emits text for ids after the prompt.
        generated = new_truth[len(prompt) :]
    else:
        generated = new_truth
        starting_index = 0
        all_input_ids = truth_tokens

    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index,
        spaces_between_special_tokens=spaces_between_special_tokens,
        fast=fast,
    )

    assert decoded_text == generated
    assert out_ids == all_input_ids[starting_index:]
223

224
225
226
227
228
229
230
231

@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
    """Streaming the id ``len(tokenizer)`` must yield no text.

    The id must still be reported back unchanged in the detokenizer's output
    token ids.
    """
    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
        pytest.skip()

    # Presumably one past the largest valid id, i.e. out of vocabulary.
    oov_id = len(tokenizer)
    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        [oov_id],
        skip_special_tokens=True,
        starting_index=0,
        spaces_between_special_tokens=True,
        fast=fast,
    )

    assert decoded_text == ""
    assert out_ids == [oov_id]