detokenizer.py 12.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4

5
6
import tokenizers
from packaging import version
7
8
9
10
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast

11
12
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
13
14
15
16
    AnyTokenizer,
    convert_prompt_ids_to_tokens,
    detokenize_incrementally,
)
17
from vllm.utils import length_from_prompt_token_ids_or_embeds
18
from vllm.v1.engine import EngineCoreRequest
19
20
21

logger = init_logger(__name__)

# DecodeStream, which FastIncrementalDetokenizer is built on, is only available
# in tokenizers 0.21.1 and newer; with an older install the factory falls back
# to SlowIncrementalDetokenizer.
USE_FAST_DETOKENIZER = (
    version.parse("0.21.1") <= version.parse(tokenizers.__version__)
)

# Prefix of the DecodeStream error message (taken from the tokenizers Rust
# source linked below) that FastIncrementalDetokenizer matches with
# ``startswith`` to decide whether to reset its decode stream.
# https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"

29
30

class IncrementalDetokenizer:
    """Token-id tracker used when a request has no tokenizer.

    It records generated token ids but never produces text and never reports
    a stop string. Subclasses layer real incremental detokenization on top.
    """

    def __init__(self):
        # Token ids tracked for the request (subclasses may prepend the prompt).
        self.token_ids: list[int] = []

    @property
    def output_token_ids(self) -> list[int]:
        """Generated token ids; in this base class, every tracked id."""
        return self.token_ids

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
        """Record ``new_token_ids``; no stop string can match, so return None."""
        self.token_ids += new_token_ids
        return None

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """No detokenization happens here, so there is never any text."""
        return ""

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer | None,
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Pick the detokenizer implementation for a new request.

        - ``tokenizer is None``: a plain :class:`IncrementalDetokenizer`
          (detokenization is skipped).
        - A ``PreTrainedTokenizerFast`` when ``USE_FAST_DETOKENIZER`` is set:
          :class:`FastIncrementalDetokenizer` (tokenizers ``DecodeStream``).
        - Otherwise: :class:`SlowIncrementalDetokenizer` (Python fallback).
        """
        assert request.sampling_params is not None

        if tokenizer is None:
            # Without a tokenizer there is nothing to detokenize.
            return IncrementalDetokenizer()

        decode_stream_ok = USE_FAST_DETOKENIZER and isinstance(
            tokenizer, PreTrainedTokenizerFast
        )
        if decode_stream_ok:
            return FastIncrementalDetokenizer(tokenizer, request)

        # Python-based incremental detokenization.
        return SlowIncrementalDetokenizer(tokenizer, request)


class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
    """Shared text accumulation, stop-string and streaming logic.

    Subclasses implement :meth:`decode_next`, which turns one token id into
    the text it adds. This class appends that text to ``output_text``, applies
    stop strings (respecting ``min_tokens``) and hands out full or delta text.
    """

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        params = request.sampling_params
        assert params is not None

        # Stop-string settings copied from the sampling params.
        self.stop = params.stop
        self.min_tokens = params.min_tokens
        self.include_stop_str_in_output = params.include_stop_str_in_output

        # When stop strings are stripped from the output, streamed text holds
        # back (longest stop string length - 1) trailing characters, so a stop
        # string that is still being generated is not emitted early.
        strip_stops = bool(self.stop) and not self.include_stop_str_in_output
        self.stop_buffer_length = max(map(len, self.stop)) - 1 if strip_stops else 0
        self._last_output_text_offset: int = 0

        # Text decoded so far for this request.
        self.output_text = ""

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
        """Detokenize ``new_token_ids`` incrementally, then check stop strings.

        Args:
            new_token_ids: Newly generated token ids.
            stop_terminated: True when the final id in ``new_token_ids`` is the
                stop token that terminated generation; unless stop strings are
                to be included in the output, that id is recorded but not
                decoded.

        Returns:
            The matched stop string, or None when none matched.
        """
        if not new_token_ids:
            # No new ids, nothing to detokenize.
            return None

        ids_to_decode = new_token_ids
        withheld_stop_id = None
        if stop_terminated and not self.include_stop_str_in_output:
            withheld_stop_id = new_token_ids[-1]
            ids_to_decode = new_token_ids[:-1]

        # 1) Incremental detokenization, one token at a time.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        #
        # Stop strings are only searched for in text after ``search_from``.
        # While the output is still within min_tokens, the offset is pushed
        # past the text those tokens produced.
        # See https://github.com/vllm-project/vllm/pull/22014.
        search_from = len(self.output_text)
        for token_id in ids_to_decode:
            self.token_ids.append(token_id)
            self.output_text += self.decode_next(token_id)
            if self.min_tokens and len(self.output_token_ids) <= self.min_tokens:
                search_from = len(self.output_text)

        if withheld_stop_id is not None:
            # Track the stop token id even though its text was skipped.
            self.token_ids.append(withheld_stop_id)

        # 2) Stop-string evaluation, only once min_tokens has been exceeded.
        if not self.stop or len(self.output_token_ids) <= self.min_tokens:
            return None

        hit = check_stop_strings(
            output_text=self.output_text,
            new_char_count=len(self.output_text) - search_from,
            stop=self.stop,
            include_in_output=self.include_stop_str_in_output,
        )
        if hit is None:
            return None

        matched, cut_at = hit
        if cut_at != -1:
            self.output_text = self.output_text[:cut_at]
        return matched

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        """Return the text added to the output by ``next_token_id``."""
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return output text for the caller.

        While the request is unfinished, the last ``stop_buffer_length``
        characters are held back; once finished, all text is available. If
        ``delta`` is True, only text not returned by an earlier delta call is
        returned.
        """
        held_back = 0 if finished else self.stop_buffer_length

        if not delta:
            if held_back:
                return self.output_text[:-held_back]
            return self.output_text

        visible_end = len(self.output_text) - held_back
        start = self._last_output_text_offset
        if start >= visible_end:
            return ""
        self._last_output_text_offset = visible_end
        return self.output_text[start:visible_end]
160
161
162


class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer driven by ``tokenizers.decoders.DecodeStream``."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest):
        super().__init__(request)

        sampling_params = request.sampling_params
        assert sampling_params is not None

        self.request_id = request.request_id
        self.skip_special_tokens = sampling_params.skip_special_tokens
        self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)

        # The Rust-backed tokenizers.Tokenizer wrapped by the HF fast tokenizer.
        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Prime the stream from a "safe" starting point: the shortest prompt
        # suffix of 4..23 tokens whose decode contains no replacement char.
        # Prompts of at most 4 tokens, or with no such suffix, are fed whole.
        prompt_ids = request.prompt_token_ids or []
        priming_ids = prompt_ids
        if len(prompt_ids) > 4:
            for width in range(4, min(len(prompt_ids) + 1, 24)):
                candidate = prompt_ids[-width:]
                if "�" not in self.tokenizer.decode(candidate):
                    priming_ids = candidate
                    break
        for token_id in priming_ids:
            self._protected_step(token_id)

        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens
        )
        if self.spaces_between_special_tokens:
            return

        # Spaces between special tokens must be suppressed, so build (once per
        # Tokenizer object, cached as an attribute on it) a map from added
        # token id to its content.
        added_token_ids = getattr(self.tokenizer, "added_token_ids", None)
        if added_token_ids is None:
            added_token_ids = {
                token_id: added.content
                for token_id, added in self.tokenizer.get_added_tokens_decoder().items()
            }
            self.tokenizer.added_token_ids = added_token_ids

        if not added_token_ids:
            # No added tokens, so there is nothing to special-case.
            self.spaces_between_special_tokens = True
            return

        self.last_special = False
        self.added_token_ids = added_token_ids

    def decode_next(self, next_token_id: int) -> str:
        """Advance the decode stream by one token and return its new text."""
        text = self._protected_step(next_token_id)

        if not self.spaces_between_special_tokens:
            content = self.added_token_ids.get(next_token_id)
            # For back-to-back added tokens, emit the raw token content so no
            # separating space is prefixed.
            if content is not None and self.last_special:
                text = content
            self.last_special = content is not None

        return text or ""

    def _protected_step(self, next_token_id: int) -> str | None:
        """Feed one token to the DecodeStream, recovering from known failures.

        Returns whatever ``DecodeStream.step`` returns for the token, or None
        when the token id is rejected with OverflowError/TypeError. An
        "invalid prefix" error resets the stream and retries the token once;
        any other exception propagates.
        """
        try:
            return self.stream.step(self.tokenizer, next_token_id)
        except (OverflowError, TypeError):
            # Rare overflow observed in practice, still to be diagnosed.
            # See https://github.com/vllm-project/vllm/issues/21951.
            logger.exception("Encountered invalid token id: %r", next_token_id)
            return None
        except Exception as e:
            if not str(e).startswith(INVALID_PREFIX_ERR_MSG):
                raise e
            # The tokenizer can produce non-monotonic, invalid UTF-8 output,
            # which breaks the internal state of tokenizers' DecodeStream;
            # start over with a fresh stream.
            # See https://github.com/vllm-project/vllm/issues/17448.
            logger.warning(
                "Encountered invalid prefix detokenization error"
                " for request %s, resetting decode stream.",
                self.request_id,
            )
            self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
            return self.stream.step(self.tokenizer, next_token_id)

250
251
252
253
254
255

class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Python-based incremental detokenizer (fallback when DecodeStream is unused).

    ``token_ids`` holds the prompt ids followed by the generated ids, and
    ``tokens`` / ``prefix_offset`` / ``read_offset`` carry the state passed to
    and returned from ``detokenize_incrementally`` on every step.
    """

    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
        super().__init__(request)

        self.tokenizer = tokenizer

        params = request.sampling_params
        assert params is not None

        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )

        # Metadata for incremental detokenization.
        if request.prompt_token_ids is not None:
            self.tokens, self.prefix_offset, self.read_offset = (
                convert_prompt_ids_to_tokens(
                    tokenizer=tokenizer,
                    prompt_ids=request.prompt_token_ids,
                    skip_special_tokens=params.skip_special_tokens,
                )
            )
        else:
            # Prompt embedding requests cannot be detokenized, in general, so
            # every prompt position gets an empty placeholder token.
            self.tokens = [""] * self.prompt_len
            self.prefix_offset = 0
            # Must be ``read_offset`` (previously misspelled ``read_offest``):
            # decode_next reads self.read_offset and would otherwise raise
            # AttributeError for embedding-only prompts.
            self.read_offset = 0

        # Without prompt token ids, pad with id 0 so that token_ids still
        # starts with prompt_len prompt entries (output_token_ids slices them
        # off).
        self.token_ids.extend(request.prompt_token_ids or [0] * self.prompt_len)

        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = params.spaces_between_special_tokens

    @property
    def output_token_ids(self) -> list[int]:
        """Generated token ids only, i.e. ``token_ids`` without the prompt."""
        return (
            self.token_ids
            if not self.prompt_len
            else (self.token_ids[self.prompt_len :])
        )

    def decode_next(self, next_token_id: int) -> str:
        """Decode the newest token and return the text it adds.

        ``next_token_id`` has already been appended to ``token_ids`` by the
        caller; the detokenization state is advanced in place.
        """
        new_tokens, decoded_text, prefix_offset, read_offset = detokenize_incrementally(
            tokenizer=self.tokenizer,
            all_input_ids=self.token_ids,
            prev_tokens=self.tokens,
            prefix_offset=self.prefix_offset,
            read_offset=self.read_offset,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
        )

        self.tokens.extend(new_tokens)
        self.prefix_offset = prefix_offset
        self.read_offset = read_offset

        return decoded_text
307
308
309
310
311
312
313


def check_stop_strings(
    output_text: str,
    new_char_count: int,
    stop: list[str],
    include_in_output: bool,
314
) -> tuple[str, int] | None:
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
    """Check if any stop strings are matched and truncate sequence
    output text accordingly.

    Returns tuple (stop_string, offset) if matched or else None.

    Where stop_string is the matched stop string and offset is the
    length to which output_text should be truncated, or -1 for no
    truncation.
    """
    if not new_char_count or not stop:
        return None

    for stop_str in stop:
        stop_string_len = len(stop_str)
        # Avoid searching already-searched text.
330
        stop_index = output_text.find(stop_str, 1 - new_char_count - stop_string_len)
331
332
333
334
335
336
337
338
339
340
341
342
343
344
        if stop_index == -1:
            continue

        if include_in_output:
            # Truncate to end of stop string.
            stop_index += stop_string_len
            if stop_index >= len(output_text):
                # No truncation required.
                return stop_str, -1

        # Truncate the output text to either the beginning
        # or end of the stop string.
        return stop_str, stop_index
    return None