# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod

import tokenizers
from packaging import version
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast

from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.detokenizer_utils import (
    convert_prompt_ids_to_tokens,
    detokenize_incrementally,
)
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest

logger = init_logger(__name__)

# FastIncrementalDetokenizer relies on DecodeStream, which is only supported
# by tokenizers >= 0.21.1.
USE_FAST_DETOKENIZER = (
    version.parse("0.21.1") <= version.parse(tokenizers.__version__)
)

# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"


class IncrementalDetokenizer:
    """Pass-through detokenizer that records token ids but produces no text.

    This is what `from_new_request` returns when no tokenizer is available.
    Subclasses override `update` and `get_next_output_text` to perform real
    incremental detokenization.
    """

    def __init__(self):
        # All token ids recorded so far.
        self.token_ids: list[int] = []

    @property
    def output_token_ids(self) -> list[int]:
        """Return the recorded token ids (the live list, not a copy)."""
        return self.token_ids

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
        """Record `new_token_ids` without detokenizing them.

        `stop_terminated` is ignored by this base implementation.

        Returns:
            Always None, since no stop strings are evaluated here.
        """
        self.token_ids.extend(new_token_ids)
        return None

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return an empty string; this implementation never produces text."""
        return ""

    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Build the detokenizer implementation appropriate for `request`.

        Args:
            tokenizer: Tokenizer to detokenize with, or None to skip
                detokenization entirely.
            request: The request; its `sampling_params` must not be None.

        Returns:
            A plain `IncrementalDetokenizer` when `tokenizer` is None, a
            `FastIncrementalDetokenizer` when `USE_FAST_DETOKENIZER` is set
            and the tokenizer is a `PreTrainedTokenizerFast`, and a
            `SlowIncrementalDetokenizer` otherwise.
        """
        assert request.sampling_params is not None

        if tokenizer is None:
            # No tokenizer => skipping detokenization.
            return IncrementalDetokenizer()

        if USE_FAST_DETOKENIZER and isinstance(tokenizer, PreTrainedTokenizerFast):
            # Fast tokenizer => use tokenizers library DecodeStream.
            return FastIncrementalDetokenizer(tokenizer, request)

        # Fall back to slow python-based incremental detokenization.
        return SlowIncrementalDetokenizer(tokenizer, request)


class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
    """Shared text accumulation and stop-string handling.

    Concrete subclasses only implement `decode_next`, which turns a single
    new token id into the text it contributes.
    """

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        # Stop strings
        params = request.sampling_params
        assert params is not None

        # Normalize `params.stop` (None, a single string, or a list) to a list.
        stop_list: list[str]
        if params.stop is None:
            stop_list = []
        elif isinstance(params.stop, str):
            stop_list = [params.stop]
        else:
            stop_list = params.stop
        self.stop = stop_list
        self.min_tokens = params.min_tokens
        self.include_stop_str_in_output = params.include_stop_str_in_output

        # Number of chars to hold back when stop strings are to be excluded
        # from streamed output.
        if self.stop and not self.include_stop_str_in_output:
            self.stop_buffer_length = max(len(s) for s in self.stop) - 1
        else:
            self.stop_buffer_length = 0
        # End offset (into output_text) of the text already handed out by
        # get_next_output_text in delta mode.
        self._last_output_text_offset: int = 0

        # Generation data
        self.output_text = ""

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
        """Detokenize newly generated token ids and evaluate stop strings.

        1) Detokenize the new token ids incrementally.
        2) Evaluate stop criteria.

        Args:
            new_token_ids: Token ids generated since the previous call.
            stop_terminated: When true and `include_stop_str_in_output` is
                false, the last id in `new_token_ids` is recorded but not
                detokenized.

        Returns:
            The matched stop string, or None if no stop string matched.
        """
        if not new_token_ids:
            # Skip detokenization if no new token ids.
            return None

        if stop_terminated and not self.include_stop_str_in_output:
            # If stop-terminated, exclude last token from detokenization
            # based on include_stop_str_in_output parameter.
            skipped_stop_token_id = new_token_ids[-1]
            new_token_ids = new_token_ids[:-1]
        else:
            skipped_stop_token_id = None

        # 1) Detokenize the new token ids incrementally.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        stop_check_offset = len(self.output_text)
        for new_token_id in new_token_ids:
            self.token_ids.append(new_token_id)
            self.output_text += self.decode_next(new_token_id)
            # Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014
            # Text produced while still within min_tokens is excluded from
            # the stop-string search by advancing the search offset past it.
            if self.min_tokens and len(self.output_token_ids) <= self.min_tokens:
                stop_check_offset = len(self.output_text)

        if skipped_stop_token_id is not None:
            # Cleanup after skipping detokenization.
            self.token_ids.append(skipped_stop_token_id)

        # 2) Evaluate stop strings.
        stop_string = None
        if self.stop and len(self.output_token_ids) > self.min_tokens:
            stop = check_stop_strings(
                output_text=self.output_text,
                new_char_count=len(self.output_text) - stop_check_offset,
                stop=self.stop,
                include_in_output=self.include_stop_str_in_output,
            )
            if stop is not None:
                stop_string, truncate_to = stop
                # -1 means the text already ends at the stop string.
                if truncate_to != -1:
                    self.output_text = self.output_text[:truncate_to]

        return stop_string

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        """Return the text contributed by `next_token_id`."""
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return the output text that is ready to be emitted.

        Args:
            finished: When true, nothing is held back for stop-string
                matching; otherwise the last `stop_buffer_length` characters
                are withheld.
            delta: If True, only new text since the last call to this
                method is returned; otherwise the full text is returned.
        """
        # We return the full output text if the sequence is finished.
        buffer_length = 0 if finished else self.stop_buffer_length
        if not delta:
            if not buffer_length:
                return self.output_text
            return self.output_text[:-buffer_length]

        length = len(self.output_text) - buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""


class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer built on the tokenizers library DecodeStream."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest):
        super().__init__(request)

        sampling_params = request.sampling_params
        assert sampling_params is not None

        # request_id is only used in the warning logged by _protected_step.
        self.request_id = request.request_id
        self.skip_special_tokens = sampling_params.skip_special_tokens
        self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)

        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Find a safe place to start: the shortest prompt suffix of at least
        # 4 tokens (searching lengths 4..23) whose decoding contains no "�"
        # replacement character. If none qualifies, or the prompt has 4 or
        # fewer tokens, the whole prompt is used.
        prompt_token_ids = request.prompt_token_ids or []
        prompt_suffix = prompt_token_ids
        prompt_len = len(prompt_suffix)
        if prompt_len > 4:
            for i in range(4, min(prompt_len + 1, 24)):
                suffix = prompt_token_ids[-i:]
                if "�" not in self.tokenizer.decode(suffix):
                    prompt_suffix = suffix
                    break

        # Prime the stream.
        for tid in prompt_suffix:
            self._protected_step(tid)

        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens
        )

        if not self.spaces_between_special_tokens:
            # Store dict of added token ids so that we can suppress
            # the spaces between them. The dict is cached as an attribute on
            # the tokenizer object and reused when already present.
            added_token_ids = getattr(self.tokenizer, "added_token_ids", None)
            if added_token_ids is None:
                self.tokenizer.added_token_ids = added_token_ids = {
                    tid: tok.content
                    for tid, tok in self.tokenizer.get_added_tokens_decoder().items()
                }

            if added_token_ids:
                self.last_special = False
                self.added_token_ids = added_token_ids
            else:
                # No added tokens.
                self.spaces_between_special_tokens = True

    def decode_next(self, next_token_id: int) -> str:
        """Return the text for `next_token_id`, or "" if the stream yields none.

        When spaces between special tokens are suppressed and both this token
        and the previous one are added tokens, the raw added-token content is
        returned instead of the stream output.
        """
        token = self._protected_step(next_token_id)

        if not self.spaces_between_special_tokens:
            special_token = self.added_token_ids.get(next_token_id)
            is_special = special_token is not None
            if is_special and self.last_special:
                # Return raw token string without any prefixed spaces.
                token = special_token
            self.last_special = is_special

        return token or ""

    def _protected_step(self, next_token_id: int) -> str | None:
        """Advance the decode stream by one token, recovering from known errors.

        Returns:
            The stream's output for this token, or None when the token id was
            rejected with OverflowError/TypeError.

        Raises:
            Exception: Any stream error whose message does not start with
                `INVALID_PREFIX_ERR_MSG` is re-raised unchanged.
        """
        try:
            token = self.stream.step(self.tokenizer, next_token_id)
        except (OverflowError, TypeError):
            # Handle rare observed overflow, still to be diagnosed.
            # See https://github.com/vllm-project/vllm/issues/21951.
            logger.exception("Encountered invalid token id: %r", next_token_id)
            token = None
        except Exception as e:
            if not str(e).startswith(INVALID_PREFIX_ERR_MSG):
                raise
            # Recover from edge case where tokenizer can produce non-monotonic,
            # invalid UTF-8 output, which breaks the internal state of
            # tokenizers' DecodeStream.
            # See https://github.com/vllm-project/vllm/issues/17448.
            logger.warning(
                "Encountered invalid prefix detokenization error"
                " for request %s, resetting decode stream.",
                self.request_id,
            )
            # Start over with a fresh stream and retry this token once.
            self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
            token = self.stream.step(self.tokenizer, next_token_id)
        return token


class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Python-based incremental detokenizer using `detokenize_incrementally`.

    Unlike the base class, `token_ids` here holds the prompt token ids
    followed by the generated ones; `output_token_ids` strips the prompt part.
    """

    def __init__(self, tokenizer: TokenizerLike, request: EngineCoreRequest):
        super().__init__(request)

        self.tokenizer = tokenizer
        params = request.sampling_params
        assert params is not None

        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )

        # Metadata for incremental detokenization.
        if request.prompt_token_ids is not None:
            self.tokens, self.prefix_offset, self.read_offset = (
                convert_prompt_ids_to_tokens(
                    tokenizer=tokenizer,
                    prompt_ids=request.prompt_token_ids,
                    skip_special_tokens=params.skip_special_tokens,
                )
            )
        else:
            # Prompt embedding requests cannot be detokenized, in general.
            self.tokens = [""] * self.prompt_len
            self.prefix_offset = 0
            # Must be `read_offset`: decode_next reads this attribute, so a
            # misspelled name would make it fail with AttributeError.
            self.read_offset = 0

        # Without prompt token ids, pad with placeholder zeros so that
        # output_token_ids can still slice off the first prompt_len entries.
        self.token_ids.extend(request.prompt_token_ids or [0] * self.prompt_len)

        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = params.spaces_between_special_tokens

    @property
    def output_token_ids(self) -> list[int]:
        """Return only the generated token ids, excluding the prompt prefix."""
        return (
            self.token_ids
            if not self.prompt_len
            else (self.token_ids[self.prompt_len :])
        )

    def decode_next(self, next_token_id: int) -> str:
        """Return the text newly decoded after `next_token_id` was appended.

        The token id itself is not passed on: the caller has already appended
        it to `self.token_ids`, which is handed to `detokenize_incrementally`
        in full together with the running offsets.
        """
        new_tokens, decoded_text, prefix_offset, read_offset = detokenize_incrementally(
            tokenizer=self.tokenizer,
            all_input_ids=self.token_ids,
            prev_tokens=self.tokens,
            prefix_offset=self.prefix_offset,
            read_offset=self.read_offset,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
        )

        # Carry the incremental state forward for the next call.
        self.tokens.extend(new_tokens)
        self.prefix_offset = prefix_offset
        self.read_offset = read_offset

        return decoded_text


def check_stop_strings(
    output_text: str,
    new_char_count: int,
    stop: list[str],
    include_in_output: bool,
321
) -> tuple[str, int] | None:
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
    """Check if any stop strings are matched and truncate sequence
    output text accordingly.

    Returns tuple (stop_string, offset) if matched or else None.

    Where stop_string is the matched stop string and offset is the
    length to which output_text should be truncated, or -1 for no
    truncation.
    """
    if not new_char_count or not stop:
        return None

    for stop_str in stop:
        stop_string_len = len(stop_str)
        # Avoid searching already-searched text.
337
        stop_index = output_text.find(stop_str, 1 - new_char_count - stop_string_len)
338
339
340
341
342
343
344
345
346
347
348
349
350
351
        if stop_index == -1:
            continue

        if include_in_output:
            # Truncate to end of stop string.
            stop_index += stop_string_len
            if stop_index >= len(output_text):
                # No truncation required.
                return stop_str, -1

        # Truncate the output text to either the beginning
        # or end of the stop string.
        return stop_str, stop_index
    return None