detokenizer.py 12.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from typing import Optional
5

6
7
import tokenizers
from packaging import version
8
9
10
11
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast

12
13
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
14
15
16
17
    AnyTokenizer,
    convert_prompt_ids_to_tokens,
    detokenize_incrementally,
)
18
from vllm.utils import length_from_prompt_token_ids_or_embeds
19
from vllm.v1.engine import EngineCoreRequest
20
21
22

logger = init_logger(__name__)

# FastIncrementalDetokenizer relies on DecodeStream, which is only available
# in tokenizers >= 0.21.1; with older releases every tokenizer goes through
# SlowIncrementalDetokenizer instead.
USE_FAST_DETOKENIZER = version.parse("0.21.1") <= version.parse(tokenizers.__version__)

# Message prefix of the DecodeStream error that FastIncrementalDetokenizer
# recovers from by resetting its stream. Error string from
# https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"

30
31

class IncrementalDetokenizer:
    """Token-id accumulator that performs no detokenization.

    Used directly when no tokenizer is available; subclasses override
    ``update`` and ``get_next_output_text`` to actually produce text.
    """

    def __init__(self):
        # Every token id recorded for this request.
        self.token_ids: list[int] = []

    @property
    def output_token_ids(self) -> list[int]:
        """Token ids produced for the request (here: everything recorded)."""
        return self.token_ids

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]:
        """Record ``new_token_ids``; never reports a matched stop string."""
        self.token_ids += new_token_ids
        return None

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Without detokenization there is never any text to return."""
        return ""

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Build the detokenizer implementation suited to ``tokenizer``.

        - ``tokenizer is None``: a plain ``IncrementalDetokenizer`` that only
          tracks token ids (detokenization is skipped).
        - A ``PreTrainedTokenizerFast`` when ``USE_FAST_DETOKENIZER`` is set:
          ``FastIncrementalDetokenizer`` (tokenizers' ``DecodeStream``).
        - Otherwise: the Python-based ``SlowIncrementalDetokenizer``.
        """
        assert request.sampling_params is not None

        if tokenizer is None:
            return IncrementalDetokenizer()

        use_fast_path = USE_FAST_DETOKENIZER and isinstance(
            tokenizer, PreTrainedTokenizerFast
        )
        detokenizer_cls = (
            FastIncrementalDetokenizer if use_fast_path else SlowIncrementalDetokenizer
        )
        return detokenizer_cls(tokenizer, request)


class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
    """Shared detokenize-then-check-stop-strings logic.

    Subclasses implement ``decode_next`` to turn a single token id into
    text. This class accumulates ``output_text``, evaluates stop strings
    (respecting ``min_tokens`` and ``include_stop_str_in_output``) and hands
    out streamed text while holding back a tail that could still turn into
    a stop string.
    """

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        sampling_params = request.sampling_params
        assert sampling_params is not None

        # Stop-string configuration.
        self.stop = sampling_params.stop
        self.min_tokens = sampling_params.min_tokens
        self.include_stop_str_in_output = sampling_params.include_stop_str_in_output

        # Characters to hold back from streamed output when stop strings are
        # excluded from it: one fewer than the longest stop string.
        hold_back = 0
        if self.stop and not self.include_stop_str_in_output:
            hold_back = max(map(len, self.stop)) - 1
        self.stop_buffer_length = hold_back
        self._last_output_text_offset: int = 0

        # Detokenized text generated so far.
        self.output_text = ""

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]:
        """Detokenize ``new_token_ids`` incrementally, then check stop strings.

        Each id is appended to ``token_ids`` and its decoded text to
        ``output_text``. When ``stop_terminated`` is set and stop strings are
        excluded from the output, the last id is recorded but not decoded.
        On a stop-string match ``output_text`` may be truncated.

        Returns the matched stop string, or None.
        """
        if not new_token_ids:
            # Nothing new to detokenize.
            return None

        # A stop-terminated sequence leaves its final token out of the
        # decoded text unless the stop string should appear in the output.
        held_back_id: Optional[int] = None
        ids_to_decode = new_token_ids
        if stop_terminated and not self.include_stop_str_in_output:
            held_back_id = new_token_ids[-1]
            ids_to_decode = new_token_ids[:-1]

        # 1) Detokenize the new token ids incrementally.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        search_start = len(self.output_text)
        for token_id in ids_to_decode:
            self.token_ids.append(token_id)
            self.output_text += self.decode_next(token_id)
            # Text generated while still within min_tokens must not match a
            # stop string, so push the search start past it.
            # See https://github.com/vllm-project/vllm/pull/22014
            if self.min_tokens and len(self.output_token_ids) <= self.min_tokens:
                search_start = len(self.output_text)

        if held_back_id is not None:
            # Still record the undecoded stop token id.
            self.token_ids.append(held_back_id)

        # 2) Evaluate stop strings.
        if not self.stop or len(self.output_token_ids) <= self.min_tokens:
            return None
        stop_hit = check_stop_strings(
            output_text=self.output_text,
            new_char_count=len(self.output_text) - search_start,
            stop=self.stop,
            include_in_output=self.include_stop_str_in_output,
        )
        if stop_hit is None:
            return None
        stop_string, truncate_to = stop_hit
        if truncate_to != -1:
            self.output_text = self.output_text[:truncate_to]
        return stop_string

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        """Return the text newly produced by ``next_token_id``.

        ``update`` calls this right after appending the id to ``token_ids``.
        """
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return the output text that is safe to emit.

        Until the sequence is finished, the last ``stop_buffer_length``
        characters are withheld. With ``delta`` set, only text not returned
        by an earlier delta call is returned; otherwise the whole (possibly
        withheld-tail) text is returned.
        """
        # Once finished, nothing needs to be held back any more.
        hold_back = self.stop_buffer_length if not finished else 0
        text = self.output_text
        if not delta:
            return text[:-hold_back] if hold_back else text

        end = len(text) - hold_back
        start = self._last_output_text_offset
        if start >= end:
            return ""
        self._last_output_text_offset = end
        return text[start:end]


class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer driven by the tokenizers library's ``DecodeStream``."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest):
        super().__init__(request)

        sampling_params = request.sampling_params
        assert sampling_params is not None

        self.request_id = request.request_id
        self.skip_special_tokens = sampling_params.skip_special_tokens
        self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)

        # Operate on the underlying tokenizers.Tokenizer directly.
        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Prime the stream with a tail of the prompt so the first generated
        # token is decoded in context.
        for prompt_token_id in self._priming_suffix(request.prompt_token_ids or []):
            self._protected_step(prompt_token_id)

        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens
        )
        if self.spaces_between_special_tokens:
            return

        # Spaces between consecutive added tokens must be suppressed, which
        # needs an id -> content map of the added tokens. It is cached on the
        # tokenizer object, so it is built once per tokenizer.
        added_token_ids = getattr(self.tokenizer, "added_token_ids", None)
        if added_token_ids is None:
            added_token_ids = {
                token_id: added.content
                for token_id, added in self.tokenizer.get_added_tokens_decoder().items()
            }
            self.tokenizer.added_token_ids = added_token_ids

        if not added_token_ids:
            # No added tokens, so there is nothing to suppress.
            self.spaces_between_special_tokens = True
            return
        self.last_special = False
        self.added_token_ids = added_token_ids

    def _priming_suffix(self, prompt_token_ids: list[int]) -> list[int]:
        """Pick the prompt tail used to prime the decode stream.

        For prompts longer than 4 tokens, tails of 4 up to 23 tokens (bounded
        by the prompt length) are tried, shortest first. The first tail that
        decodes without a "\ufffd" replacement character is returned. The whole
        prompt is returned when it is short or no such tail exists.
        """
        prompt_len = len(prompt_token_ids)
        if prompt_len > 4:
            for tail_len in range(4, min(prompt_len + 1, 24)):
                tail = prompt_token_ids[-tail_len:]
                if "�" not in self.tokenizer.decode(tail):
                    return tail
        return prompt_token_ids

    def decode_next(self, next_token_id: int) -> str:
        """Decode one token id through the stream ("" if no text is ready).

        When spaces between special tokens are disabled, an added token that
        directly follows another added token is returned as its raw content,
        without any prefixed space.
        """
        text = self._protected_step(next_token_id)
        if self.spaces_between_special_tokens:
            return text or ""

        added_content = self.added_token_ids.get(next_token_id)
        current_is_added = added_content is not None
        if current_is_added and self.last_special:
            # Return raw token string without any prefixed spaces.
            text = added_content
        self.last_special = current_is_added
        return text or ""

    def _protected_step(self, next_token_id: int) -> Optional[str]:
        """Feed one token id into the decode stream, tolerating known failures.

        - ``OverflowError``: logged; None is returned.
        - Errors whose message starts with ``INVALID_PREFIX_ERR_MSG``: the
          stream is recreated and the token is stepped once more.
        - Any other exception is re-raised.
        """
        try:
            return self.stream.step(self.tokenizer, next_token_id)
        except OverflowError:
            # Handle rare observed overflow, still to be diagnosed.
            # See https://github.com/vllm-project/vllm/issues/21951.
            logger.exception("Encountered invalid token id: %d", next_token_id)
            return None
        except Exception as exc:
            if not str(exc).startswith(INVALID_PREFIX_ERR_MSG):
                raise exc
            # Recover from edge case where tokenizer can produce non-monotonic,
            # invalid UTF-8 output, which breaks the internal state of
            # tokenizers' DecodeStream.
            # See https://github.com/vllm-project/vllm/issues/17448.
            logger.warning(
                "Encountered invalid prefix detokenization error"
                " for request %s, resetting decode stream.",
                self.request_id,
            )
            self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
            return self.stream.step(self.tokenizer, next_token_id)


class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer built on the Python ``detokenize_incrementally``."""

    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
        super().__init__(request)

        self.tokenizer = tokenizer

        params = request.sampling_params
        assert params is not None

        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )

        # Metadata for incremental detokenization.
        if request.prompt_token_ids is not None:
            self.tokens, self.prefix_offset, self.read_offset = (
                convert_prompt_ids_to_tokens(
                    tokenizer=tokenizer,
                    prompt_ids=request.prompt_token_ids,
                    skip_special_tokens=params.skip_special_tokens,
                )
            )
        else:
            # Prompt embedding requests cannot be detokenized, in general:
            # use empty placeholder tokens for the prompt positions.
            self.tokens = [""] * self.prompt_len
            self.prefix_offset = 0
            # decode_next reads self.read_offset, so it must be set on this
            # path too (it was misspelled, causing AttributeError on the first
            # decoded token of every prompt-embeds request).
            self.read_offset = 0

        # token_ids holds the prompt ids followed by generated ids. Prompt-
        # embeds requests get placeholder 0 ids so that slicing at prompt_len
        # (see output_token_ids) still yields only generated ids.
        self.token_ids.extend(request.prompt_token_ids or [0] * self.prompt_len)

        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = params.spaces_between_special_tokens

    @property
    def output_token_ids(self) -> list[int]:
        """Generated token ids, i.e. ``token_ids`` without the prompt prefix."""
        return (
            self.token_ids
            if not self.prompt_len
            else (self.token_ids[self.prompt_len :])
        )

    def decode_next(self, next_token_id: int) -> str:
        """Decode the newest id (already appended to ``token_ids``) and return
        its new text, advancing the token list and prefix/read offsets."""
        new_tokens, decoded_text, prefix_offset, read_offset = detokenize_incrementally(
            tokenizer=self.tokenizer,
            all_input_ids=self.token_ids,
            prev_tokens=self.tokens,
            prefix_offset=self.prefix_offset,
            read_offset=self.read_offset,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
        )

        self.tokens.extend(new_tokens)
        self.prefix_offset = prefix_offset
        self.read_offset = read_offset

        return decoded_text


def check_stop_strings(
    output_text: str,
    new_char_count: int,
    stop: list[str],
    include_in_output: bool,
) -> Optional[tuple[str, int]]:
    """Look for a stop string in the newly generated tail of ``output_text``.

    Only the last ``new_char_count`` characters (plus enough preceding
    characters for a stop string to straddle the old/new boundary) are
    searched. Stop strings are tried in list order and the first one found
    wins.

    Returns ``(stop_string, offset)`` on a match, else None. ``offset`` is
    the length ``output_text`` should be truncated to: the start of the stop
    string, or its end when ``include_in_output`` is set. ``-1`` means no
    truncation is needed because the stop string already ends the text.
    """
    if new_char_count and stop:
        text_len = len(output_text)
        for candidate in stop:
            candidate_len = len(candidate)
            # Negative start: skip text that earlier calls already searched.
            found_at = output_text.find(candidate, 1 - new_char_count - candidate_len)
            if found_at < 0:
                continue
            if not include_in_output:
                # Cut the stop string (and anything after it) off.
                return candidate, found_at
            # Keep the stop string; cut only what follows it, if anything.
            match_end = found_at + candidate_len
            return candidate, (-1 if match_end >= text_len else match_end)
    return None