# test_detokenize.py (7.62 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Generator
5
from typing import Any
6

7
import pytest
8
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
9

10
from vllm.sampling_params import SamplingParams
11
from vllm.tokenizers.mistral import MistralTokenizer
12
from vllm.v1.engine import EngineCoreRequest
13
14
15
16
17
from vllm.v1.engine.detokenizer import (
    FastIncrementalDetokenizer,
    IncrementalDetokenizer,
    SlowIncrementalDetokenizer,
)
18
19
20
21

# Ground-truth text made of runs of adjacent special tokens; concatenated onto
# TRUTH so the streaming test also covers special-token handling.
SPECIAL_TOKS_TRUTH: list[str] = [
    "Some text with adjacent special tokens                <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
]
22
23

# Ground-truth strings parametrized into test_decode_streaming; each one is
# tokenized, streamed through an incremental detokenizer, and compared with a
# one-shot decode.
TRUTH = [
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
    "我很感谢你的热情",
    # Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg.
    # for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
    # incomplete UTF-8 characters
    # see https://github.com/vllm-project/vllm/pull/9625
    "ပုံပြင်လေးပြောပြပါ်",
] + SPECIAL_TOKS_TRUTH

34
35
36
37
38
39
40
# Tokenizer names passed to ``from_pretrained`` by the ``tokenizer`` fixture;
# names containing "mistral" are loaded with MistralTokenizer.
TOKENIZERS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    # FIXME: mosaicml/mpt-7b has been deleted
    # "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-3.2-1B-Instruct",
    "codellama/CodeLlama-7b-hf",
    "mistralai/Pixtral-12B-2409",
]


50
51
52
53
54
55
def _run_incremental_decode(
    tokenizer,
    all_input_ids,
    skip_special_tokens: bool,
    starting_index: int,
    spaces_between_special_tokens: bool = True,
    fast: bool | None = None,
):
    """Stream token ids through an incremental detokenizer one at a time.

    The first ``starting_index`` ids of ``all_input_ids`` become the request's
    prompt. The remaining ids are fed to the detokenizer one by one, and the
    text deltas it returns are concatenated.

    Args:
        tokenizer: Tokenizer handed to the detokenizer.
        all_input_ids: Prompt ids followed by the ids to stream.
        skip_special_tokens: Forwarded to ``SamplingParams``.
        starting_index: Number of leading ids treated as the prompt.
        spaces_between_special_tokens: Forwarded to ``SamplingParams``.
        fast: ``None`` lets ``IncrementalDetokenizer.from_new_request`` pick
            the implementation. ``True`` forces ``FastIncrementalDetokenizer``
            and ``False`` forces ``SlowIncrementalDetokenizer``.

    Returns:
        A tuple ``(output_text, output_token_ids)``. ``output_text`` is the
        concatenation of all text deltas. ``output_token_ids`` is the
        detokenizer's ``output_token_ids``.
    """
    prompt_token_ids = all_input_ids[:starting_index]
    generated_ids = all_input_ids[starting_index:]

    params = SamplingParams(
        skip_special_tokens=skip_special_tokens,
        spaces_between_special_tokens=spaces_between_special_tokens,
    )
    request = EngineCoreRequest(
        request_id="",
        prompt_token_ids=prompt_token_ids,
        mm_features=None,
        sampling_params=params,
        pooling_params=None,
        eos_token_id=None,
        arrival_time=0.0,
        lora_request=None,
        cache_salt=None,
        data_parallel_rank=None,
    )

    if fast is None:
        detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
    elif fast:
        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
    else:
        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)

    # ``i`` below indexes the generated slice, not all_input_ids, so the last
    # step is len(generated_ids) - 1. Comparing against len(all_input_ids)
    # would mean ``finished`` is never True whenever a prompt is present.
    last_index = len(generated_ids) - 1
    output_text = ""
    for i, token_id in enumerate(generated_ids):
        detokenizer.update([token_id], False)
        finished = i == last_index
        output_text += detokenizer.get_next_output_text(finished, delta=True)

    return output_text, detokenizer.output_token_ids
91
92


93
94
@pytest.fixture
def tokenizer(tokenizer_name):
    """Load the tokenizer named by the ``tokenizer_name`` parameter.

    Names containing "mistral" are loaded with ``MistralTokenizer``; every
    other name goes through ``AutoTokenizer``.
    """
    if "mistral" in tokenizer_name:
        return MistralTokenizer.from_pretrained(tokenizer_name)
    return AutoTokenizer.from_pretrained(tokenizer_name)
100
101
102
103
104
105
106
107
108
109
110


@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
@pytest.mark.parametrize(
    "truth",
    [
        # Burmese text triggers an edge-case where tokens may map to bytes with
        # incomplete UTF-8 characters
        "ပုံပြင်လေးပြောပြပါ",
        # Using "URGENCY" since "CY" has token id 130282
        "URGENCY🌶️",
    ],
)
def test_mistral_edge_case(tokenizer, truth):
    """Test specific edge cases of the V3-Tekken MistralTokenizer.

    The whole of ``truth`` is streamed (no prompt), and the concatenated output
    must equal ``truth`` exactly.

    See https://github.com/vllm-project/vllm/pull/9625
    """
    starting_index = 0
    all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids

    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=True,
        starting_index=starting_index,
    )
    assert decoded_text == truth
    assert out_ids == all_input_ids[starting_index:]
129
130
131
132
133
134


@pytest.fixture
def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
    """Yield the indirectly-parametrized ``skip_special_tokens`` value.

    For Mistral tokenizers the ``skip_special_tokens=False`` case is skipped,
    because it is not supported there.
    """
    if "mistral" in tokenizer_name and not request.param:
        pytest.skip("mistral doesn't support skip_special_tokens=False")
    yield bool(request.param)
141
142


143
@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(
    tokenizer,
    truth,
    with_prompt,
    skip_special_tokens,
    spaces_between_special_tokens,
    fast,
):
    """Incremental detokenization must match a one-shot ``tokenizer.decode``.

    ``truth`` is tokenized, prefixed with BOS (when the tokenizer has one) and
    suffixed with EOS. It is then decoded in one shot to get the expected
    text.

    When ``with_prompt`` is set, the tokens covering the first half of
    ``truth`` (plus BOS) are the prompt. In that case only the text decoded
    after the prompt is compared.
    """
    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
        pytest.skip()

    if skip_special_tokens and not spaces_between_special_tokens:
        pytest.skip()

    if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
        # Fix up inconsistency in fast/slow tokenizer behaviour.
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    at
                    for at in tokenizer._tokenizer.get_added_tokens_decoder().values()
                    if at.special
                ]
            }
        )

    # spaces_between_special_tokens is forwarded to decode() only for
    # PreTrainedTokenizer (slow HF) instances; other tokenizers get no extras.
    extra_decode_args = (
        {}
        if not isinstance(tokenizer, PreTrainedTokenizer)
        else {"spaces_between_special_tokens": spaces_between_special_tokens}
    )

    truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
    if tokenizer.bos_token_id is not None:
        truth_tokens.insert(0, tokenizer.bos_token_id)
    truth_tokens.append(tokenizer.eos_token_id)

    new_truth = tokenizer.decode(
        truth_tokens, skip_special_tokens=skip_special_tokens, **extra_decode_args
    )

    if with_prompt:
        num_prompt_tokens = len(
            tokenizer(truth[: len(truth) // 2], add_special_tokens=False).input_ids
        )
        # Account for the BOS token inserted at the front of truth_tokens.
        if tokenizer.bos_token_id is not None:
            num_prompt_tokens += 1

        prompt_input_ids = truth_tokens[:num_prompt_tokens]
        generated_input_ids = truth_tokens[num_prompt_tokens:]
        all_input_ids = prompt_input_ids + generated_input_ids
        starting_index = len(prompt_input_ids)
        prompt = tokenizer.decode(
            prompt_input_ids,
            skip_special_tokens=skip_special_tokens,
            **extra_decode_args,
        )

        # The detokenizer only emits text after the prompt, so strip the
        # decoded prompt prefix from the expected text.
        generated = new_truth[len(prompt) :]
    else:
        generated = new_truth
        starting_index = 0
        all_input_ids = truth_tokens

    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index,
        spaces_between_special_tokens=spaces_between_special_tokens,
        fast=fast,
    )

    assert decoded_text == generated
    assert out_ids == all_input_ids[starting_index:]
224

225
226
227
228
229
230
231
232

@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
    """Streaming the id ``len(tokenizer)`` must produce no text.

    The id must still be reported in the detokenizer's output token ids.
    """
    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
        pytest.skip()

    # Presumably one past the largest id in the vocabulary, i.e. out of vocab.
    oov_id = len(tokenizer)
    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        [oov_id],
        skip_special_tokens=True,
        starting_index=0,
        spaces_between_special_tokens=True,
        fast=fast,
    )

    assert decoded_text == ""
    assert out_ids == [oov_id]