test_detokenize.py 8.03 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Generator
5
from typing import Any
6

7
import os
8
import pytest
9
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
10

11
from vllm.sampling_params import SamplingParams
12
from vllm.tokenizers.mistral import MistralTokenizer
13
from vllm.v1.engine import EngineCoreRequest
14
15
16
17
18
from vllm.v1.engine.detokenizer import (
    FastIncrementalDetokenizer,
    IncrementalDetokenizer,
    SlowIncrementalDetokenizer,
)
19
from ..utils import models_path_prefix
20

21
22
23
# Text with adjacent special-token markup; appended to the end of TRUTH.
SPECIAL_TOKS_TRUTH = [
    "Some text with adjacent special tokens                <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
]

# Reference strings that the decode tests round-trip through each tokenizer.
TRUTH = [
    "Hello here, this is a simple test",
    (
        "vLLM is a high-throughput and memory-efficient inference and serving "
        "engine for LLMs. It is designed to be used in production environments, "
        "where inference and serving"
    ),
    "我很感谢你的热情",
    # Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg.
    # for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
    # incomplete UTF-8 characters
    # see https://github.com/vllm-project/vllm/pull/9625
    "ပုံပြင်လေးပြောပြပါ်",
    *SPECIAL_TOKS_TRUTH,
]

36
# Tokenizers exercised by the parametrized tests, each resolved under
# ``models_path_prefix``.
TOKENIZERS = [
    os.path.join(models_path_prefix, model_id)
    for model_id in (
        "facebook/opt-125m",
        "gpt2",
        "bigcode/tiny_starcoder_py",
        "EleutherAI/gpt-j-6b",
        "EleutherAI/pythia-70m",
        "bigscience/bloom-560m",
        # FIXME: mosaicml/mpt-7b has been deleted
        # "mosaicml/mpt-7b",
        "tiiuae/falcon-7b",
        "meta-llama/Llama-3.2-1B-Instruct",
        "codellama/CodeLlama-7b-hf",
        "mistralai/Pixtral-12B-2409",
    )
]


52
53
54
55
56
57
def _run_incremental_decode(
    tokenizer,
    all_input_ids,
    skip_special_tokens: bool,
    starting_index: int,
    spaces_between_special_tokens: bool = True,
    fast: bool | None = None,
):
    """Stream ``all_input_ids[starting_index:]`` through an incremental detokenizer.

    The first ``starting_index`` ids become the request's prompt. Every id
    after that is fed to the detokenizer one at a time, and the text deltas
    it reports are concatenated.

    Args:
        tokenizer: Tokenizer handed to the detokenizer.
        all_input_ids: Prompt ids followed by the ids to "generate".
        skip_special_tokens: Forwarded to ``SamplingParams``.
        starting_index: Number of leading ids treated as the prompt.
        spaces_between_special_tokens: Forwarded to ``SamplingParams``.
        fast: ``None`` lets ``IncrementalDetokenizer.from_new_request`` choose
            the implementation; ``True`` forces ``FastIncrementalDetokenizer``;
            ``False`` forces ``SlowIncrementalDetokenizer``.

    Returns:
        ``(output_text, output_token_ids)``: the concatenated text deltas and
        the detokenizer's ``output_token_ids``.
    """
    prompt_token_ids = all_input_ids[:starting_index]
    generated_token_ids = all_input_ids[starting_index:]

    params = SamplingParams(
        skip_special_tokens=skip_special_tokens,
        spaces_between_special_tokens=spaces_between_special_tokens,
    )
    request = EngineCoreRequest(
        request_id="",
        prompt_token_ids=prompt_token_ids,
        mm_features=None,
        sampling_params=params,
        pooling_params=None,
        eos_token_id=None,
        arrival_time=0.0,
        lora_request=None,
        cache_salt=None,
        data_parallel_rank=None,
    )

    if fast is None:
        detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
    elif fast:
        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
    else:
        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)

    # ``step`` indexes the generated slice, so the final token is at
    # len(generated) - 1. Comparing against len(all_input_ids) - 1 (as before)
    # never matched once a prompt was present, so "finished" was never sent.
    last_step = len(generated_token_ids) - 1
    chunks: list[str] = []
    for step, token_id in enumerate(generated_token_ids):
        detokenizer.update([token_id], False)
        finished = step == last_step
        delta_text = detokenizer.get_next_output_text(finished, delta=True)
        chunks.append(delta_text)

    return "".join(chunks), detokenizer.output_token_ids
93
94


95
96
@pytest.fixture
def tokenizer(tokenizer_name):
    """Load the tokenizer selected by the ``tokenizer_name`` parameter.

    Names containing "mistral" are loaded with vLLM's ``MistralTokenizer``;
    all others go through Hugging Face ``AutoTokenizer``.
    """
    loader = MistralTokenizer if "mistral" in tokenizer_name else AutoTokenizer
    return loader.from_pretrained(tokenizer_name)
102
103


104
@pytest.mark.parametrize(
    "tokenizer_name",
    [os.path.join(models_path_prefix, "mistralai/Pixtral-12B-2409")],
)
@pytest.mark.parametrize(
    "truth",
    [
        # Burmese text triggers an edge-case where tokens may map to bytes with
        # incomplete UTF-8 characters
        "ပုံပြင်လေးပြောပြပါ",
        # Using "URGENCY" since "CY" has token id 130282
        "URGENCY🌶️",
    ],
)
def test_mistral_edge_case(tokenizer, truth):
    """Test for a specific edge cases with V3-Tekken MistralTokenizer.

    The whole encoded text is streamed with no prompt, and both the decoded
    text and the reported output token ids must match the input exactly.

    See https://github.com/vllm-project/vllm/pull/9625
    """
    prompt_len = 0
    token_ids = tokenizer(truth, add_special_tokens=False).input_ids

    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        token_ids,
        skip_special_tokens=True,
        starting_index=prompt_len,
    )

    assert decoded_text == truth
    assert out_ids == token_ids[prompt_len:]
131
132
133
134
135
136


@pytest.fixture
def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
    """Resolve the indirectly parametrized ``skip_special_tokens`` value.

    Mistral tokenizers are only exercised with ``skip_special_tokens=True``;
    the ``False`` combination is skipped for them.
    """
    requested = bool(request.param)
    if not requested and "mistral" in tokenizer_name:
        pytest.skip("mistral doesn't support skip_special_tokens=False")
    yield requested
143
144


145
@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(
    tokenizer,
    truth,
    with_prompt,
    skip_special_tokens,
    spaces_between_special_tokens,
    fast,
):
    """Check that streaming detokenization matches a one-shot ``decode``.

    The expected text is ``tokenizer.decode`` over BOS (when the tokenizer
    defines one), the encoded ``truth`` and EOS. With ``with_prompt`` set, the
    tokens for the first half of ``truth`` (plus BOS) form the prompt. Only
    the text following the decoded prompt is then expected from the
    incremental detokenizer.
    """
    is_hf_fast = isinstance(tokenizer, PreTrainedTokenizerFast)

    if fast and not is_hf_fast:
        pytest.skip()
    if skip_special_tokens and not spaces_between_special_tokens:
        pytest.skip()

    if is_hf_fast and not fast:
        # Fix up inconsistency in fast/slow tokenizer behaviour.
        special_tokens = [
            added
            for added in tokenizer._tokenizer.get_added_tokens_decoder().values()
            if added.special
        ]
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

    # spaces_between_special_tokens is passed to decode() only for
    # PreTrainedTokenizer (slow HF) instances.
    decode_kwargs: dict[str, Any] = {"skip_special_tokens": skip_special_tokens}
    if isinstance(tokenizer, PreTrainedTokenizer):
        decode_kwargs["spaces_between_special_tokens"] = spaces_between_special_tokens

    truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
    has_bos = tokenizer.bos_token_id is not None
    if has_bos:
        truth_tokens.insert(0, tokenizer.bos_token_id)
    truth_tokens.append(tokenizer.eos_token_id)

    full_text = tokenizer.decode(truth_tokens, **decode_kwargs)

    if with_prompt:
        half_truth = truth[: len(truth) // 2]
        prompt_len = len(tokenizer(half_truth, add_special_tokens=False).input_ids)
        if has_bos:
            prompt_len += 1

        prompt_ids = truth_tokens[:prompt_len]
        starting_index = len(prompt_ids)
        all_input_ids = prompt_ids + truth_tokens[prompt_len:]

        prompt_text = tokenizer.decode(prompt_ids, **decode_kwargs)
        expected = full_text[len(prompt_text) :]
    else:
        starting_index = 0
        all_input_ids = truth_tokens
        expected = full_text

    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index,
        spaces_between_special_tokens=spaces_between_special_tokens,
        fast=fast,
    )

    assert decoded_text == expected
    assert out_ids == all_input_ids[starting_index:]
226

227
228
229
230
231
232

@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
    """Stream a single token id equal to ``len(tokenizer)``.

    That id is presumably out of vocabulary. It must decode to no text while
    still being reported in the detokenizer's output token ids.
    """
    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
        pytest.skip()

    oov_token_id = len(tokenizer)
    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        [oov_token_id],
        skip_special_tokens=True,
        starting_index=0,
        spaces_between_special_tokens=True,
        fast=fast,
    )

    assert decoded_text == ""
    assert out_ids == [oov_token_id]