"vllm/model_executor/models/ernie45_vl.py" did not exist on "47c7126213163e454019efc7d913125c22df9d6e"
detokenizer.py 9.33 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
from abc import ABC, abstractmethod
3
from typing import Optional
4

5
6
import tokenizers
from packaging import version
7
8
9
10
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from transformers import PreTrainedTokenizerFast

11
12
13
14
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
    AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
15
from vllm.v1.engine import EngineCoreRequest
16
17
18
19
20
21

# Module-level logger for this file, named after the module via init_logger.
logger = init_logger(__name__)


class IncrementalDetokenizer:
    """No-op incremental detokenizer that only records token ids.

    Returned directly when no tokenizer is available (detokenization is
    skipped), and serves as the base of the text-producing detokenizers.
    Use :meth:`from_new_request` to pick the implementation for a request.
    """

    # Whether the installed ``tokenizers`` package is new enough (>= 0.21.1)
    # for FastIncrementalDetokenizer. Evaluated once when the class is
    # defined instead of re-parsing the version string for every request.
    _FAST_DETOKENIZER_SUPPORTED: bool = version.parse(
        tokenizers.__version__) >= version.parse("0.21.1")

    def __init__(self):
        # Token ids recorded so far via update().
        self.token_ids: list[int] = []

    @property
    def output_token_ids(self) -> list[int]:
        """Return the recorded token ids."""
        return self.token_ids

    def update(self, new_token_ids: list[int],
               stop_terminated: bool) -> Optional[str]:
        """Record ``new_token_ids``; no text is produced.

        Args:
            new_token_ids: Newly generated token ids to append.
            stop_terminated: Ignored here; accepted so the signature matches
                the text-producing subclasses.

        Returns:
            Always ``None`` (no stop string can match without text).
        """
        self.token_ids.extend(new_token_ids)
        return None

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return output text, which is always empty for this class."""
        return ""

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Create the detokenizer appropriate for ``tokenizer``.

        Returns:
            A plain ``IncrementalDetokenizer`` when ``tokenizer`` is None;
            a ``FastIncrementalDetokenizer`` for a
            ``PreTrainedTokenizerFast`` when tokenizers >= 0.21.1 is
            installed; otherwise a ``SlowIncrementalDetokenizer``.
        """
        if tokenizer is None:
            # No tokenizer => skipping detokenization. Builds the base class
            # rather than ``cls()``: the subclasses require constructor
            # arguments that are not available here.
            return IncrementalDetokenizer()

        if (isinstance(tokenizer, PreTrainedTokenizerFast)
                and IncrementalDetokenizer._FAST_DETOKENIZER_SUPPORTED):
            # Fast tokenizer => use tokenizers library DecodeStream.
            # Only tokenizers >= 0.21.1 supports the fast detokenizer.
            return FastIncrementalDetokenizer(tokenizer, request)

        # Fall back to slow python-based incremental detokenization.
        return SlowIncrementalDetokenizer(tokenizer, request)


class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
    """Shared logic for detokenizers that produce output text.

    Accumulates decoded text, evaluates stop strings, and holds back a tail
    of the text while streaming when stop strings must be excluded from the
    output. Subclasses implement :meth:`decode_next`.
    """

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        # Stop strings
        params = request.sampling_params
        self.stop = stop = params.stop
        self.include_stop_str_in_output = params.include_stop_str_in_output

        # Number of chars to hold back when stop strings are to be excluded
        # from streamed output (the longest possible incomplete prefix of a
        # stop string is one char shorter than the longest stop string).
        if stop and not self.include_stop_str_in_output:
            self.stop_buffer_length = max(len(s) for s in stop) - 1
        else:
            self.stop_buffer_length = 0
        # End offset into output_text of the text already handed out by
        # get_next_output_text(delta=True).
        self._last_output_text_offset: int = 0

        # Generation data
        self.output_text = ""

    def update(self, new_token_ids: list[int],
               stop_terminated: bool) -> Optional[str]:
        """Detokenize new token ids and evaluate stop strings.

        Steps:
            1) Detokenize the new token ids incrementally.
            2) Evaluate stop criteria.

        Args:
            new_token_ids: Newly generated token ids.
            stop_terminated: True if generation stopped on a stop token; in
                that case stop strings are not checked.

        Returns:
            The matched stop string, or None.
        """
        if not new_token_ids:
            # Skip detokenization if no new token ids.
            return None

        if stop_terminated and not self.include_stop_str_in_output:
            # If stop-terminated, exclude last token from detokenization
            # based on include_stop_str_in_output parameter.
            skipped_stop_token_id = new_token_ids[-1]
            new_token_ids = new_token_ids[:-1]
        else:
            skipped_stop_token_id = None

        # 1) Detokenize the new token ids incrementally.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        offset_before = len(self.output_text)
        pieces: list[str] = []
        for new_token_id in new_token_ids:
            # The id is appended before decoding because
            # SlowIncrementalDetokenizer.decode_next reads self.token_ids.
            self.token_ids.append(new_token_id)
            pieces.append(self.decode_next(new_token_id))
        # Append once: ``self.output_text += piece`` per token would copy the
        # whole (growing) output string for every token.
        self.output_text += "".join(pieces)

        if stop_terminated:
            if skipped_stop_token_id is not None:
                # Cleanup after skipping detokenization.
                self.token_ids.append(skipped_stop_token_id)
            # Stop token triggered; skip stop string check.
            return None

        # 2) Evaluate stop strings.
        stop_string = None
        if self.stop:
            stop = StopChecker.check_stop_strings(
                output_text=self.output_text,
                new_char_count=len(self.output_text) - offset_before,
                stop=self.stop,
                include_in_output=self.include_stop_str_in_output,
            )
            if stop is not None:
                stop_string, truncate_to = stop
                # -1 appears to mean "no truncation needed" - confirm
                # against StopChecker.check_stop_strings.
                if truncate_to != -1:
                    self.output_text = self.output_text[:truncate_to]

        return stop_string

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        """Return the text increment produced by ``next_token_id``.

        Called by :meth:`update` after ``next_token_id`` has been appended to
        ``self.token_ids``.
        """
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return output text for the caller.

        If ``delta`` is True, only text not returned by an earlier delta call
        is returned. Unless ``finished``, the last ``stop_buffer_length``
        chars are held back so a partial (to-be-excluded) stop string is
        never emitted.
        """

        # We return the full output text if the sequence is finished.
        buffer_length = 0 if finished else self.stop_buffer_length
        if not delta:
            return self.output_text[:-buffer_length] if buffer_length else (
                self.output_text)
        length = len(self.output_text) - buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""


class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer backed by ``tokenizers``' ``DecodeStream``.

    Selected for ``PreTrainedTokenizerFast`` tokenizers by
    ``IncrementalDetokenizer.from_new_request``.
    """

    # Message of the error DecodeStream.step raises when its internal prefix
    # state has become inconsistent; _protected_step recovers from it.
    _INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"

    def __init__(self, tokenizer: PreTrainedTokenizerFast,
                 request: EngineCoreRequest):
        super().__init__(request)

        sampling_params = request.sampling_params
        # Kept so _protected_step can rebuild an equivalent stream.
        self.skip_special_tokens = sampling_params.skip_special_tokens
        self.stream = DecodeStream(
            skip_special_tokens=self.skip_special_tokens)

        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Find a safe place to start: the shortest prompt suffix of 4..23
        # tokens whose decoding contains no U+FFFD replacement character
        # (presumably one that does not begin mid-way through a multi-byte
        # character). Falls back to the whole prompt if none qualifies.
        prompt_suffix = request.prompt_token_ids
        prompt_len = len(prompt_suffix)
        if prompt_len > 4:
            for i in range(4, min(prompt_len + 1, 24)):
                suffix = request.prompt_token_ids[-i:]
                if '�' not in self.tokenizer.decode(suffix):
                    prompt_suffix = suffix
                    break

        # Prime the stream.
        for tid in prompt_suffix:
            self._protected_step(tid)

        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens)

        if not self.spaces_between_special_tokens:
            # Store dict of added token ids so that we can suppress
            # the spaces between them. The dict is cached on the shared
            # Tokenizer object so it is built only once per tokenizer.
            if (added_token_ids := getattr(self.tokenizer, "added_token_ids",
                                           None)) is None:
                self.tokenizer.added_token_ids = added_token_ids = {
                    tid: tok.content
                    for tid, tok in
                    self.tokenizer.get_added_tokens_decoder().items()
                }

            if added_token_ids:
                self.last_special = False
                self.added_token_ids = added_token_ids
            else:
                # No added tokens.
                self.spaces_between_special_tokens = True

    def decode_next(self, next_token_id: int) -> str:
        """Return the text increment for ``next_token_id`` ("" if none)."""
        token = self._protected_step(next_token_id)

        if not self.spaces_between_special_tokens:
            special_token = self.added_token_ids.get(next_token_id)
            is_special = special_token is not None
            if is_special and self.last_special:
                # Return raw token string without any prefixed spaces.
                token = special_token
            self.last_special = is_special

        return token or ""

    def _protected_step(self, next_token_id: int) -> Optional[str]:
        """Advance the decode stream by one token.

        Returns whatever ``DecodeStream.step`` returns (possibly None). If
        the stream raises its "invalid prefix" error, the stream is replaced
        with a fresh one and the token is stepped again; any other error is
        re-raised.
        """
        try:
            return self.stream.step(self.tokenizer, next_token_id)
        except Exception as e:
            if str(e) != self._INVALID_PREFIX_ERR_MSG:
                raise
            # The stream's internal state is unusable; restart it from this
            # token instead of failing the request. Prior context is lost,
            # so spacing around this token may differ slightly.
            logger.warning("Encountered invalid prefix detokenization error,"
                           " resetting decode stream.")
            self.stream = DecodeStream(
                skip_special_tokens=self.skip_special_tokens)
            return self.stream.step(self.tokenizer, next_token_id)


class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Python-based incremental detokenizer usable with any tokenizer.

    ``token_ids`` holds the prompt ids followed by generated ids, because
    ``detokenize_incrementally`` is handed the full id sequence.
    """

    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
        super().__init__(request)

        params = request.sampling_params
        prompt_ids = request.prompt_token_ids

        self.tokenizer = tokenizer

        # Incremental-detokenization state seeded from the prompt: its
        # tokens plus the prefix/read offsets into them.
        seed_state = convert_prompt_ids_to_tokens(
            tokenizer=tokenizer,
            prompt_ids=prompt_ids,
            skip_special_tokens=params.skip_special_tokens,
        )
        self.tokens, self.prefix_offset, self.read_offset = seed_state

        self.token_ids.extend(prompt_ids)
        self.prompt_len = len(prompt_ids)

        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = (
            params.spaces_between_special_tokens)

    @property
    def output_token_ids(self) -> list[int]:
        """Generated token ids only, with the prompt ids stripped."""
        if self.prompt_len:
            return self.token_ids[self.prompt_len:]
        return self.token_ids

    def decode_next(self, next_token_id: int) -> str:
        """Decode the id most recently appended to ``self.token_ids``.

        ``next_token_id`` is not read directly; the helper works from the
        full ``self.token_ids`` plus the stored tokens and offsets, which
        are updated in place.
        """
        outcome = detokenize_incrementally(
            tokenizer=self.tokenizer,
            all_input_ids=self.token_ids,
            prev_tokens=self.tokens,
            prefix_offset=self.prefix_offset,
            read_offset=self.read_offset,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
        )
        fresh_tokens, text, self.prefix_offset, self.read_offset = outcome
        self.tokens.extend(fresh_tokens)
        return text