detokenizer.py 12.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4

5
6
import tokenizers
from packaging import version
7
8
9
10
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast

11
from vllm.logger import init_logger
12
from vllm.tokenizers import TokenizerLike
13
from vllm.tokenizers.detokenizer_utils import (
14
15
16
    convert_prompt_ids_to_tokens,
    detokenize_incrementally,
)
17
from vllm.utils import length_from_prompt_token_ids_or_embeds
18
from vllm.v1.engine import EngineCoreRequest
19
20
21

logger = init_logger(__name__)

# Only tokenizers >= 0.21.1 supports DecodeStream used for
# FastIncrementalDetokenizer.
USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.21.1")

# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
# Compared as a prefix of the exception text in
# FastIncrementalDetokenizer._protected_step.
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"

class IncrementalDetokenizer:
    """Detokenizer that records token ids and never produces any text.

    ``from_new_request`` returns an instance of this class directly when no
    tokenizer is supplied; the subclasses in this module add the actual
    incremental decoding.
    """

    def __init__(self):
        # All token ids recorded so far via update().
        self.token_ids: list[int] = []

    @property
    def output_token_ids(self) -> list[int]:
        """Return the list of recorded token ids (not a copy)."""
        return self.token_ids

    def num_output_tokens(self) -> int:
        """Return how many token ids have been recorded."""
        return len(self.token_ids)

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
        """Record ``new_token_ids``; no stop string is ever matched here.

        ``stop_terminated`` is accepted for interface compatibility and is
        not used by this implementation.
        """
        self.token_ids.extend(new_token_ids)
        return None

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return the empty string, since this class does not detokenize."""
        return ""

    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Choose a detokenizer implementation for ``request``.

        Returns a plain ``IncrementalDetokenizer`` when ``tokenizer`` is
        None, a ``FastIncrementalDetokenizer`` when ``USE_FAST_DETOKENIZER``
        is true and the tokenizer is a ``PreTrainedTokenizerFast``, and a
        ``SlowIncrementalDetokenizer`` otherwise.
        """
        assert request.sampling_params is not None

        if tokenizer is None:
            # No tokenizer => skipping detokenization.
            return IncrementalDetokenizer()

        if USE_FAST_DETOKENIZER and isinstance(tokenizer, PreTrainedTokenizerFast):
            # Fast tokenizer => use tokenizers library DecodeStream.
            return FastIncrementalDetokenizer(tokenizer, request)

        # Fall back to slow python-based incremental detokenization.
        return SlowIncrementalDetokenizer(tokenizer, request)


class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
    """Common text accumulation and stop-string handling.

    Subclasses implement ``decode_next`` to convert a single token id into
    its text contribution; this class appends that text to ``output_text``,
    checks it against the request's stop strings, and serves full or delta
    output text.
    """

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        # Stop strings
        params = request.sampling_params
        assert params is not None
        # Normalize params.stop (None, a single string, or a list) to a list.
        stop_list: list[str]
        if params.stop is None:
            stop_list = []
        elif isinstance(params.stop, str):
            stop_list = [params.stop]
        else:
            stop_list = params.stop
        self.stop = stop_list
        self.min_tokens = params.min_tokens
        self.include_stop_str_in_output = params.include_stop_str_in_output

        # Number of chars to hold back when stop strings are to be excluded
        # from streamed output.
        if self.stop and not self.include_stop_str_in_output:
            self.stop_buffer_length = max(len(s) for s in self.stop) - 1
        else:
            self.stop_buffer_length = 0
        # Offset into output_text up to which delta text was already returned.
        self._last_output_text_offset: int = 0

        # Generation data
        self.output_text = ""

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
        """
        Update RequestState for the request_id by:
            1) Detokenize the new token ids incrementally.
            2) Evaluate stop criteria.

        Return matched stop string or None.
        """
        if not new_token_ids:
            # Skip detokenization if no new token ids.
            return None

        if stop_terminated and not self.include_stop_str_in_output:
            # If stop-terminated, exclude last token from detokenization
            # based on include_stop_str_in_output parameter.
            skipped_stop_token_id = new_token_ids[-1]
            new_token_ids = new_token_ids[:-1]
        else:
            skipped_stop_token_id = None

        # 1) Detokenize the new token ids incrementally.
        stop_check_offset = len(self.output_text)
        for new_token_id in new_token_ids:
            self.token_ids.append(new_token_id)
            self.output_text += self.decode_next(new_token_id)
            # Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014
            # Text produced while still within min_tokens is excluded from
            # the stop-string search by moving the search offset past it.
            if self.min_tokens and self.num_output_tokens() <= self.min_tokens:
                stop_check_offset = len(self.output_text)

        if skipped_stop_token_id is not None:
            # Cleanup after skipping detokenization.
            self.token_ids.append(skipped_stop_token_id)

        # 2) Evaluate stop strings.
        stop_string = None
        if self.stop and self.num_output_tokens() > self.min_tokens:
            stop = check_stop_strings(
                output_text=self.output_text,
                new_char_count=len(self.output_text) - stop_check_offset,
                stop=self.stop,
                include_in_output=self.include_stop_str_in_output,
            )
            if stop is not None:
                stop_string, truncate_to = stop
                # -1 means the text already ends at the stop string.
                if truncate_to != -1:
                    self.output_text = self.output_text[:truncate_to]

        return stop_string

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        """Return the text contributed by ``next_token_id``."""
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned"""

        # We return the full output text if the sequence is finished.
        buffer_length = 0 if finished else self.stop_buffer_length
        if not delta:
            return (
                self.output_text[:-buffer_length]
                if buffer_length
                else (self.output_text)
            )
        length = len(self.output_text) - buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""


class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer built on the tokenizers ``DecodeStream``."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest):
        super().__init__(request)

        sampling_params = request.sampling_params
        assert sampling_params is not None

        self.request_id = request.request_id
        self.skip_special_tokens = sampling_params.skip_special_tokens
        self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)

        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Find a safe place to start.
        # For prompts longer than 4 tokens, use the shortest suffix of
        # 4..23 tokens whose decoded text has no replacement character;
        # otherwise (or if none qualifies) the whole prompt is used.
        prompt_token_ids = request.prompt_token_ids or []
        prompt_suffix = prompt_token_ids
        prompt_len = len(prompt_suffix)
        if prompt_len > 4:
            for i in range(4, min(prompt_len + 1, 24)):
                suffix = prompt_token_ids[-i:]
                if "�" not in self.tokenizer.decode(suffix):
                    prompt_suffix = suffix
                    break

        # Prime the stream.
        for tid in prompt_suffix:
            self._protected_step(tid)

        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens
        )

        if not self.spaces_between_special_tokens:
            # Store dict of added token ids so that we can suppress
            # the spaces between them.
            # The mapping is cached as an attribute on the tokenizer object
            # and reused when already present.
            if (
                added_token_ids := getattr(self.tokenizer, "added_token_ids", None)
            ) is None:
                self.tokenizer.added_token_ids = added_token_ids = {
                    tid: tok.content
                    for tid, tok in self.tokenizer.get_added_tokens_decoder().items()
                }

            if added_token_ids:
                self.last_special = False
                self.added_token_ids = added_token_ids
            else:
                # No added tokens.
                self.spaces_between_special_tokens = True

    def decode_next(self, next_token_id: int) -> str:
        """Return the text for ``next_token_id`` ("" if the stream yields none)."""
        token = self._protected_step(next_token_id)

        if not self.spaces_between_special_tokens:
            special_token = self.added_token_ids.get(next_token_id)
            is_special = special_token is not None
            if is_special and self.last_special:
                # Return raw token string without any prefixed spaces.
                token = special_token
            self.last_special = is_special

        return token or ""

    def _protected_step(self, next_token_id: int) -> str | None:
        """Advance the decode stream by one token id.

        Returns the stream's result, or None when the token id triggers an
        OverflowError/TypeError. On an "Invalid prefix encountered" error
        the stream is recreated and the step is retried once; any other
        exception is re-raised.
        """
        try:
            token = self.stream.step(self.tokenizer, next_token_id)
        except (OverflowError, TypeError):
            # Handle rare observed overflow, still to be diagnosed.
            # See https://github.com/vllm-project/vllm/issues/21951.
            logger.exception("Encountered invalid token id: %r", next_token_id)
            token = None
        except Exception as e:
            if not str(e).startswith(INVALID_PREFIX_ERR_MSG):
                # Bare raise keeps the original traceback intact.
                raise
            # Recover from edge case where tokenizer can produce non-monotonic,
            # invalid UTF-8 output, which breaks the internal state of
            # tokenizers' DecodeStream.
            # See https://github.com/vllm-project/vllm/issues/17448.
            logger.warning(
                "Encountered invalid prefix detokenization error"
                " for request %s, resetting decode stream.",
                self.request_id,
            )
            self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
            token = self.stream.step(self.tokenizer, next_token_id)
        return token

class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer using ``detokenize_incrementally``.

    Unlike the base classes, ``token_ids`` here also holds the prompt ids
    (or zero placeholders for prompt-embedding requests), so output
    accessors skip the first ``prompt_len`` entries.
    """

    def __init__(self, tokenizer: TokenizerLike, request: EngineCoreRequest):
        super().__init__(request)

        self.tokenizer = tokenizer
        params = request.sampling_params
        assert params is not None

        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )

        # Metadata for incremental detokenization.
        if request.prompt_token_ids is not None:
            self.tokens, self.prefix_offset, self.read_offset = (
                convert_prompt_ids_to_tokens(
                    tokenizer=tokenizer,
                    prompt_ids=request.prompt_token_ids,
                    skip_special_tokens=params.skip_special_tokens,
                )
            )
        else:
            # Prompt embedding requests cannot be detokenized, in general.
            self.tokens = [""] * self.prompt_len
            self.prefix_offset = 0
            self.read_offset = 0

        # Seed token_ids with the prompt ids, or with zeros of the same
        # length when there are no prompt ids.
        self.token_ids.extend(request.prompt_token_ids or [0] * self.prompt_len)

        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = params.spaces_between_special_tokens

    @property
    def output_token_ids(self) -> list[int]:
        """Return the generated token ids, excluding the prompt portion."""
        return (
            self.token_ids
            if not self.prompt_len
            else (self.token_ids[self.prompt_len :])
        )

    def num_output_tokens(self) -> int:
        """Return the number of generated (non-prompt) token ids."""
        return len(self.token_ids) - self.prompt_len

    def decode_next(self, next_token_id: int) -> str:
        """Return the newly decoded text and update the decoding offsets.

        ``next_token_id`` itself is not passed on: the helper is given the
        full ``self.token_ids`` list, to which the caller has already
        appended the new id.
        """
        new_tokens, decoded_text, prefix_offset, read_offset = detokenize_incrementally(
            tokenizer=self.tokenizer,
            all_input_ids=self.token_ids,
            prev_tokens=self.tokens,
            prefix_offset=self.prefix_offset,
            read_offset=self.read_offset,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
        )

        self.tokens.extend(new_tokens)
        self.prefix_offset = prefix_offset
        self.read_offset = read_offset

        return decoded_text


def check_stop_strings(
    output_text: str,
    new_char_count: int,
    stop: list[str],
    include_in_output: bool,
325
) -> tuple[str, int] | None:
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
    """Check if any stop strings are matched and truncate sequence
    output text accordingly.

    Returns tuple (stop_string, offset) if matched or else None.

    Where stop_string is the matched stop string and offset is the
    length to which output_text should be truncated, or -1 for no
    truncation.
    """
    if not new_char_count or not stop:
        return None

    for stop_str in stop:
        stop_string_len = len(stop_str)
        # Avoid searching already-searched text.
341
        stop_index = output_text.find(stop_str, 1 - new_char_count - stop_string_len)
342
343
344
345
346
347
348
349
350
351
352
353
354
355
        if stop_index == -1:
            continue

        if include_in_output:
            # Truncate to end of stop string.
            stop_index += stop_string_len
            if stop_index >= len(output_text):
                # No truncation required.
                return stop_str, -1

        # Truncate the output text to either the beginning
        # or end of the stop string.
        return stop_str, stop_index
    return None