"vllm/vscode:/vscode.git/clone" did not exist on "50376faa7b7397f82f9b67d7b6e0770ab189b6c1"
detokenizer.py 12.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from typing import Optional
5

6
7
import tokenizers
from packaging import version
8
9
10
11
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast

12
13
14
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
    AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
15
from vllm.utils import length_from_prompt_token_ids_or_embeds
16
from vllm.v1.engine import EngineCoreRequest
17
18
19

logger = init_logger(__name__)

# FastIncrementalDetokenizer relies on tokenizers' DecodeStream, which only
# exists from tokenizers 0.21.1 onwards; older versions use the slow path.
USE_FAST_DETOKENIZER = (version.parse(tokenizers.__version__)
                        >= version.parse("0.21.1"))

# Prefix of the error text that DecodeStream.step exceptions are matched
# against (see FastIncrementalDetokenizer._protected_step). Source:
# https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"

28
29
30

class IncrementalDetokenizer:
    """Detokenizer that only records generated token ids.

    Instances of this base class are used as-is when no tokenizer is
    available: token ids are tracked but no text is ever produced.
    Use :meth:`from_new_request` to obtain the implementation that suits a
    given tokenizer.
    """

    def __init__(self):
        self.token_ids: list[int] = []

    @property
    def output_token_ids(self) -> list[int]:
        """Generated token ids (this is the ``token_ids`` list itself)."""
        return self.token_ids

    def update(self, new_token_ids: list[int],
               stop_terminated: bool) -> Optional[str]:
        """Append ``new_token_ids``; no stop-string detection is done.

        Returns:
            Always ``None`` (no stop string can match without text).
        """
        self.token_ids.extend(new_token_ids)
        return None

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """No text is ever produced, so this always returns ``""``."""
        return ""

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Create the detokenizer matching ``tokenizer`` for ``request``.

        - ``tokenizer is None``: a plain :class:`IncrementalDetokenizer`
          (detokenization is skipped);
        - a ``PreTrainedTokenizerFast`` while ``USE_FAST_DETOKENIZER`` holds:
          :class:`FastIncrementalDetokenizer` (``DecodeStream`` based);
        - otherwise: the Python-based :class:`SlowIncrementalDetokenizer`.
        """
        assert request.sampling_params is not None

        if tokenizer is None:
            # Nothing to detokenize with.
            return IncrementalDetokenizer()

        detokenizer_cls = (FastIncrementalDetokenizer
                           if USE_FAST_DETOKENIZER
                           and isinstance(tokenizer, PreTrainedTokenizerFast)
                           else SlowIncrementalDetokenizer)
        return detokenizer_cls(tokenizer, request)


class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
    """Shared text accumulation, stop-string checks and stream buffering.

    Subclasses implement :meth:`decode_next`, which returns the text
    contributed by one token id. This class appends that text to
    ``output_text``, evaluates the request's stop strings against it, and
    withholds a short tail from streamed output when stop strings are to be
    excluded from the result.
    """

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        sampling = request.sampling_params
        assert sampling is not None

        # Stop-string configuration taken from the sampling params.
        self.stop = stop = sampling.stop
        self.min_tokens = sampling.min_tokens
        self.include_stop_str_in_output = sampling.include_stop_str_in_output

        # When matched stop strings are stripped from the output, hold back
        # (longest stop string - 1) chars of streamed text so a partially
        # generated stop string is never sent.
        hold_back = 0
        if stop and not self.include_stop_str_in_output:
            hold_back = max(len(s) for s in stop) - 1
        self.stop_buffer_length = hold_back
        self._last_output_text_offset: int = 0

        # Generation data
        self.output_text = ""

    def update(self, new_token_ids: list[int],
               stop_terminated: bool) -> Optional[str]:
        """Detokenize ``new_token_ids`` incrementally and check stop strings.

        Args:
            new_token_ids: Token ids generated since the previous call.
            stop_terminated: Whether generation ended on a stop token. If so,
                and stop strings are excluded from the output, the final
                token is recorded in ``token_ids`` but not detokenized.

        Returns:
            The matched stop string, or ``None`` if none matched. On a match
            ``output_text`` is truncated as ``check_stop_strings`` directs.
        """
        if not new_token_ids:
            # Nothing new: skip detokenization entirely.
            return None

        held_stop_token: Optional[int] = None
        to_decode = new_token_ids
        if stop_terminated and not self.include_stop_str_in_output:
            # Keep the terminating token out of the decoded text.
            held_stop_token = to_decode[-1]
            to_decode = to_decode[:-1]

        # 1) Incremental detokenization.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        check_from = len(self.output_text)
        for token_id in to_decode:
            self.token_ids.append(token_id)
            self.output_text += self.decode_next(token_id)
            # Text generated while still within min_tokens must not trigger
            # a stop, so push the search start past it.
            # See https://github.com/vllm-project/vllm/pull/22014
            if (self.min_tokens
                    and len(self.output_token_ids) <= self.min_tokens):
                check_from = len(self.output_text)

        if held_stop_token is not None:
            # Record the skipped terminating token after all.
            self.token_ids.append(held_stop_token)

        # 2) Stop-string evaluation.
        if not self.stop or len(self.output_token_ids) <= self.min_tokens:
            return None
        match = check_stop_strings(
            output_text=self.output_text,
            new_char_count=len(self.output_text) - check_from,
            stop=self.stop,
            include_in_output=self.include_stop_str_in_output,
        )
        if match is None:
            return None
        stop_string, truncate_to = match
        if truncate_to != -1:
            self.output_text = self.output_text[:truncate_to]
        return stop_string

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        """Return the text produced by ``next_token_id``."""
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return output text that is ready to be sent.

        While unfinished, the last ``stop_buffer_length`` characters are
        withheld; once finished, the full text is released.

        Args:
            finished: Whether the request has finished.
            delta: If True, return only text not already returned by an
                earlier delta call; otherwise return all releasable text.
        """
        withheld = self.stop_buffer_length if not finished else 0
        if delta:
            end = len(self.output_text) - withheld
            start = self._last_output_text_offset
            if start >= end:
                return ""
            self._last_output_text_offset = end
            return self.output_text[start:end]
        if withheld:
            return self.output_text[:-withheld]
        return self.output_text
163
164
165
166
167
168
169
170
171


class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer backed by ``tokenizers``' ``DecodeStream``.

    Wraps the Rust ``Tokenizer`` held by a ``PreTrainedTokenizerFast`` and
    steps a ``DecodeStream`` one token at a time.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerFast,
                 request: EngineCoreRequest):
        super().__init__(request)

        sampling_params = request.sampling_params
        assert sampling_params is not None

        self.request_id = request.request_id
        self.skip_special_tokens = sampling_params.skip_special_tokens
        self.stream = DecodeStream(
            skip_special_tokens=self.skip_special_tokens)

        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Find a safe place to start: the shortest prompt suffix of 4..23
        # tokens whose decoding contains no U+FFFD replacement char (i.e. no
        # split multi-byte sequence). Fall back to the whole prompt.
        prompt_token_ids = request.prompt_token_ids or []
        prompt_suffix = prompt_token_ids
        prompt_len = len(prompt_suffix)
        if prompt_len > 4:
            for i in range(4, min(prompt_len + 1, 24)):
                suffix = prompt_token_ids[-i:]
                if '�' not in self.tokenizer.decode(suffix):
                    prompt_suffix = suffix
                    break

        # Prime the stream so the first generated token decodes in context.
        for tid in prompt_suffix:
            self._protected_step(tid)

        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens)

        if not self.spaces_between_special_tokens:
            # Store dict of added token ids so that we can suppress
            # the spaces between them. The dict is cached on the Rust
            # tokenizer object so it is only built once per tokenizer.
            if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
                                           None)) is None:
                self.tokenizer.added_token_ids = added_token_ids = {
                    tid: tok.content
                    for tid, tok in
                    self.tokenizer.get_added_tokens_decoder().items()
                }

            if added_token_ids:
                self.last_special = False
                self.added_token_ids = added_token_ids
            else:
                # No added tokens.
                self.spaces_between_special_tokens = True

    def decode_next(self, next_token_id: int) -> str:
        """Return the text produced by ``next_token_id`` ("" if none).

        When spaces between special tokens are disabled, a special token
        that directly follows another special token is emitted as its raw
        content, without any space the decoder may have prefixed.
        """
        token = self._protected_step(next_token_id)

        if not self.spaces_between_special_tokens:
            special_token = self.added_token_ids.get(next_token_id)
            is_special = special_token is not None
            if is_special and self.last_special:
                # Return raw token string without any prefixed spaces.
                token = special_token
            self.last_special = is_special

        return token or ""

    def _protected_step(self, next_token_id: int) -> Optional[str]:
        """Step the decode stream, recovering from known tokenizers errors.

        Returns:
            Whatever ``DecodeStream.step`` returns (may be ``None``), or
            ``None`` after an ``OverflowError``.

        Raises:
            Any exception from ``DecodeStream.step`` other than
            ``OverflowError`` or the invalid-prefix error.
        """
        try:
            token = self.stream.step(self.tokenizer, next_token_id)
        except OverflowError:
            # Handle rare observed overflow, still to be diagnosed.
            # See https://github.com/vllm-project/vllm/issues/21951.
            logger.exception("Encountered invalid token id: %d", next_token_id)
            token = None
        except Exception as e:
            if not str(e).startswith(INVALID_PREFIX_ERR_MSG):
                # Bare raise re-raises the active exception unchanged.
                raise
            # Recover from edge case where tokenizer can produce non-monotonic,
            # invalid UTF-8 output, which breaks the internal state of
            # tokenizers' DecodeStream.
            # See https://github.com/vllm-project/vllm/issues/17448.
            logger.warning(
                "Encountered invalid prefix detokenization error"
                " for request %s, resetting decode stream.", self.request_id)
            self.stream = DecodeStream(
                skip_special_tokens=self.skip_special_tokens)
            token = self.stream.step(self.tokenizer, next_token_id)
        return token

254
255
256
257
258
259
260

class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Python-based incremental detokenizer usable with any tokenizer.

    Prompt token ids (or zero placeholders for embedding-only prompts) are
    kept at the front of ``token_ids``; ``output_token_ids`` excludes them.
    """

    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
        super().__init__(request)

        self.tokenizer = tokenizer
        params = request.sampling_params
        assert params is not None

        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds)

        # Metadata for incremental detokenization.
        if request.prompt_token_ids is not None:
            self.tokens, self.prefix_offset, self.read_offset = (
                convert_prompt_ids_to_tokens(
                    tokenizer=tokenizer,
                    prompt_ids=request.prompt_token_ids,
                    skip_special_tokens=params.skip_special_tokens,
                ))
        else:
            # Prompt embedding requests cannot be detokenized, in general.
            # Use empty placeholder tokens, one per prompt position.
            self.tokens = [""] * self.prompt_len
            self.prefix_offset = 0
            # decode_next reads self.read_offset, so it must be set here too.
            self.read_offset = 0

        # Placeholder ids (0) stand in for embedding-only prompts so that
        # output_token_ids can always slice off the first prompt_len ids.
        self.token_ids.extend(request.prompt_token_ids
                              or [0] * self.prompt_len)

        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = (
            params.spaces_between_special_tokens)

    @property
    def output_token_ids(self) -> list[int]:
        """Generated token ids, excluding the prompt prefix."""
        return self.token_ids if not self.prompt_len else (
            self.token_ids[self.prompt_len:])

    def decode_next(self, next_token_id: int) -> str:
        """Return the text added by the newest token.

        Decodes from ``self.token_ids`` as a whole, so ``next_token_id`` is
        expected to already be its last element (BaseIncrementalDetokenizer
        .update appends before calling this). Updates the token cache and
        prefix/read offsets.
        """
        new_tokens, decoded_text, prefix_offset, read_offset = (
            detokenize_incrementally(
                tokenizer=self.tokenizer,
                all_input_ids=self.token_ids,
                prev_tokens=self.tokens,
                prefix_offset=self.prefix_offset,
                read_offset=self.read_offset,
                skip_special_tokens=self.skip_special_tokens,
                spaces_between_special_tokens=self.
                spaces_between_special_tokens,
            ))

        self.tokens.extend(new_tokens)
        self.prefix_offset = prefix_offset
        self.read_offset = read_offset

        return decoded_text
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349


def check_stop_strings(
    output_text: str,
    new_char_count: int,
    stop: list[str],
    include_in_output: bool,
) -> Optional[tuple[str, int]]:
    """Look for a stop string overlapping the newly generated text.

    Only the tail of ``output_text`` where a match could overlap the last
    ``new_char_count`` characters is searched. Text that was already checked
    on earlier calls is therefore not rescanned.

    Args:
        output_text: Full detokenized output so far.
        new_char_count: Number of characters at the end of ``output_text``
            that are new since the previous check.
        stop: Stop strings, tried in list order; the first one found wins.
        include_in_output: Whether the matched stop string stays in the
            output text.

    Returns:
        ``None`` if nothing matched; otherwise ``(stop_string, offset)``.
        ``offset`` is the length ``output_text`` should be truncated to, or
        ``-1`` when no truncation is needed.
    """
    if not (new_char_count and stop):
        return None

    total_len = len(output_text)
    for candidate in stop:
        width = len(candidate)
        # A match overlapping the new chars starts at most width - 1 chars
        # before them; the negative start index counts from the end.
        found_at = output_text.find(candidate, 1 - new_char_count - width)
        if found_at < 0:
            continue
        if not include_in_output:
            # Cut off the stop string and everything after it.
            return candidate, found_at
        # Keep the stop string; only trim what follows it (if anything).
        end = found_at + width
        return candidate, (-1 if end >= total_len else end)
    return None