# test_detokenize.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Generator
from typing import Any, Optional
6

7
import pytest
8
9
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
10

11
from vllm.sampling_params import SamplingParams
12
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
13
14
15
16
17
18
19
20
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
                                        IncrementalDetokenizer,
                                        SlowIncrementalDetokenizer)

# Text containing several adjacent special tokens (e.g. ``<|padding|>`` and
# the ``<fim_*>`` tokens); exercises the skip_special_tokens /
# spaces_between_special_tokens handling of the incremental detokenizer.
SPECIAL_TOKS_TRUTH = [
    "Some text with adjacent special tokens                <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
]

# Reference strings that are tokenized and then detokenized incrementally;
# the streamed text is compared against a one-shot ``tokenizer.decode``.
TRUTH = [
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
    "我很感谢你的热情",
    # Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg.
    # for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
    # incomplete UTF-8 characters
    # see https://github.com/vllm-project/vllm/pull/9625
    "ပုံပြင်လေးပြောပြပါ်",
] + SPECIAL_TOKS_TRUTH

# Tokenizer names under test; names containing "mistral" are loaded with
# MistralTokenizer, all others with AutoTokenizer.
TOKENIZERS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-3.2-1B-Instruct",
    "codellama/CodeLlama-7b-hf",
    "mistralai/Pixtral-12B-2409",
]


def _run_incremental_decode(tokenizer,
                            all_input_ids,
                            skip_special_tokens: bool,
                            starting_index: int,
                            spaces_between_special_tokens: bool = True,
                            fast: Optional[bool] = None):
    """Stream ``all_input_ids[starting_index:]`` through a v1 detokenizer.

    The first ``starting_index`` ids become the prompt of a synthetic
    ``EngineCoreRequest``. The remaining ids are fed one per ``update`` call,
    and the per-step text deltas are concatenated.

    Args:
        tokenizer: Tokenizer handed to the detokenizer.
        all_input_ids: Prompt ids followed by the ids to detokenize.
        skip_special_tokens: Forwarded to ``SamplingParams``.
        starting_index: Number of leading ids treated as the prompt.
        spaces_between_special_tokens: Forwarded to ``SamplingParams``.
        fast: ``None`` lets ``IncrementalDetokenizer.from_new_request``
            choose the implementation; ``True``/``False`` force
            ``FastIncrementalDetokenizer``/``SlowIncrementalDetokenizer``.

    Returns:
        ``(streamed_text, detokenizer.output_token_ids)``.
    """
    prompt_token_ids = all_input_ids[:starting_index]
    generated_token_ids = all_input_ids[starting_index:]

    params = SamplingParams(
        skip_special_tokens=skip_special_tokens,
        spaces_between_special_tokens=spaces_between_special_tokens,
    )
    request = EngineCoreRequest(request_id="",
                                prompt_token_ids=prompt_token_ids,
                                mm_features=None,
                                sampling_params=params,
                                pooling_params=None,
                                eos_token_id=None,
                                arrival_time=0.0,
                                lora_request=None,
                                cache_salt=None,
                                data_parallel_rank=None)

    if fast is None:
        detokenizer = IncrementalDetokenizer.from_new_request(
            tokenizer, request)
    elif fast:
        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
    else:
        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)

    output_text = ""
    last_step = len(generated_token_ids) - 1
    for step, token_id in enumerate(generated_token_ids):
        detokenizer.update([token_id], False)
        # ``step`` indexes the generated ids only, so compare against their
        # count; comparing against len(all_input_ids) would never mark the
        # final step as finished whenever a prompt is present.
        finished = step == last_step
        output_text += detokenizer.get_next_output_text(finished, delta=True)

    return output_text, detokenizer.output_token_ids


@pytest.fixture
def tokenizer(tokenizer_name):
    """Load ``tokenizer_name``; Mistral models use ``MistralTokenizer``."""
    if "mistral" in tokenizer_name:
        return MistralTokenizer.from_pretrained(tokenizer_name)
    return AutoTokenizer.from_pretrained(tokenizer_name)


@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
@pytest.mark.parametrize(
    "truth",
    [
        # Burmese text triggers an edge-case where tokens may map to bytes with
        # incomplete UTF-8 characters
        "ပုံပြင်လေးပြောပြပါ",
        # Using "URGENCY" since "CY" has token id 130282
        "URGENCY🌶️",
    ])
def test_mistral_edge_case(tokenizer, truth):
    """Test specific edge cases of the V3-Tekken MistralTokenizer.

    ``truth`` is tokenized without special tokens and streamed through the
    incremental detokenizer from the first token; the streamed text and the
    recorded output ids must round-trip exactly.

    See https://github.com/vllm-project/vllm/pull/9625
    """
    starting_index = 0
    all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids

    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=True,
        starting_index=starting_index)

    assert decoded_text == truth
    assert out_ids == all_input_ids[starting_index:]


@pytest.fixture
def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
    """Resolve the indirectly parametrized ``skip_special_tokens`` value.

    ``request.param`` is supplied via ``indirect=True`` parametrization.
    For tokenizer names containing "mistral", the ``False`` case is skipped
    because it is not supported there.
    """
    if "mistral" in tokenizer_name and not request.param:
        # pytest.skip raises, so nothing is yielded for this combination.
        pytest.skip("mistral doesn't support skip_special_tokens=False")
    yield bool(request.param)


@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
                          spaces_between_special_tokens, fast):
    """Streamed detokenization must match one-shot ``tokenizer.decode``.

    ``truth`` is tokenized (with BOS, if defined, and EOS added) and decoded
    once as the reference. With ``with_prompt``, the ids covering roughly the
    first half of ``truth`` become the prompt, and only the decoded remainder
    is expected from the stream.
    """
    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
        pytest.skip("fast detokenizer requires a PreTrainedTokenizerFast")

    if skip_special_tokens and not spaces_between_special_tokens:
        pytest.skip("skip_special_tokens=True is only tested with "
                    "spaces_between_special_tokens=True")

    if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
        # Fix up inconsistency in fast/slow tokenizer behaviour.
        tokenizer.add_special_tokens({
            "additional_special_tokens": [
                at for at in
                tokenizer._tokenizer.get_added_tokens_decoder().values()
                if at.special
            ]
        })

    # spaces_between_special_tokens is only forwarded to decode() for slow
    # (PreTrainedTokenizer) tokenizers.
    extra_decode_args = ({
        "spaces_between_special_tokens": spaces_between_special_tokens
    } if isinstance(tokenizer, PreTrainedTokenizer) else {})

    truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
    if tokenizer.bos_token_id is not None:
        truth_tokens.insert(0, tokenizer.bos_token_id)
    truth_tokens.append(tokenizer.eos_token_id)

    new_truth = tokenizer.decode(truth_tokens,
                                 skip_special_tokens=skip_special_tokens,
                                 **extra_decode_args)

    if with_prompt:
        num_prompt_tokens = len(
            tokenizer(truth[:len(truth) // 2],
                      add_special_tokens=False).input_ids)
        if tokenizer.bos_token_id is not None:
            num_prompt_tokens += 1

        prompt_input_ids = truth_tokens[:num_prompt_tokens]
        generated_input_ids = truth_tokens[num_prompt_tokens:]
        all_input_ids = prompt_input_ids + generated_input_ids
        starting_index = len(prompt_input_ids)
        prompt = tokenizer.decode(prompt_input_ids,
                                  skip_special_tokens=skip_special_tokens,
                                  **extra_decode_args)

        # The stream only emits text after the prompt.
        generated = new_truth[len(prompt):]
    else:
        generated = new_truth
        starting_index = 0
        all_input_ids = truth_tokens

    decoded_text, out_ids = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index,
        spaces_between_special_tokens=spaces_between_special_tokens,
        fast=fast)

    assert decoded_text == generated
    assert out_ids == all_input_ids[starting_index:]


@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
    """An out-of-vocabulary id streams as empty text but is still recorded
    in the detokenizer's output token ids."""
    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
        pytest.skip("fast detokenizer requires a PreTrainedTokenizerFast")

    # len(tokenizer) is used as an id outside the vocabulary (hence "oov").
    oov_token_id = len(tokenizer)
    decoded_text, out_ids = _run_incremental_decode(
        tokenizer, [oov_token_id],
        skip_special_tokens=True,
        starting_index=0,
        spaces_between_special_tokens=True,
        fast=fast)

    assert decoded_text == ''
    assert out_ids == [oov_token_id]