# detokenizer.py
# SPDX-License-Identifier: Apache-2.0

3
from dataclasses import dataclass
4
from typing import List, Optional, Union
5
6
7
8
9
10

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.detokenizer_utils import (
    AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
11
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
12
13
14
15

# Module-level logger, named after this module via vLLM's init_logger.
logger = init_logger(__name__)


@dataclass
class DetokenizerOutput:
    """Result of a single detokenizer update for one request.

    Produced by ``IncrementalDetokenizer.update_from_output``; its contents
    depend on the request's ``RequestOutputKind``.
    """

    # Text to emit: only the newly released text for DELTA output, otherwise
    # the full output text (minus any stop-string hold-back while unfinished).
    output_text: str
    # The step's new token ids for DELTA output, otherwise all generated
    # (non-prompt) token ids.
    token_ids: List[int]
    # True once the request has a finish reason.
    finished: bool
    # Why the request finished (FinishReason.STOP on a stop-string match, or
    # the reason reported by the engine); None while still running.
    finish_reason: Optional[FinishReason] = None
    # The matched stop string, or the engine-reported stop reason
    # (presumably a stop token id when it is an int — TODO confirm).
    stop_reason: Union[int, str, None] = None


@dataclass
class IncrementalDetokenizer:
    """Per-request state for turning engine token ids into text incrementally.

    Create one instance per request with :meth:`from_new_request`. Call
    :meth:`update_from_output` for every engine step of that request. It
    decodes the new token ids, applies stop-string checks and returns a
    :class:`DetokenizerOutput` shaped according to ``output_kind``.
    """

    # Generation data
    # Full decoded text generated so far (prompt text is not included).
    output_text: str
    # Token strings: seeded from the prompt by convert_prompt_ids_to_tokens,
    # then extended with each newly decoded token.
    tokens: List[str]
    # Prompt token ids followed by every generated token id.
    token_ids: List[int]
    # Number of leading prompt ids in ``token_ids``.
    prompt_len: int

    # Stop strings
    stop: List[str]
    include_stop_str_in_output: bool

    # Metadata for incremental detokenization
    prefix_offset: int
    read_offset: int

    # Parameters for detokenization
    skip_special_tokens: bool
    spaces_between_special_tokens: bool
    output_kind: RequestOutputKind

    # Tokenizer for this request
    tokenizer: AnyTokenizer

    # Accounting for stop string buffering
    # Trailing characters of output_text withheld while the request is
    # unfinished, so a partially generated stop string is never streamed.
    stop_buffer_length: int
    # Position in output_text up to which DELTA text has been returned.
    _last_output_text_offset: int = 0

    @property
    def output_token_ids(self) -> List[int]:
        """Generated token ids, i.e. ``token_ids`` without the prompt."""
        return self.token_ids[self.prompt_len:]

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Build fresh detokenizer state for a newly arrived request.

        Args:
            tokenizer: Tokenizer used to decode this request's tokens.
            request: The new request. Its prompt token ids and sampling
                params seed the detokenizer state.

        Returns:
            An IncrementalDetokenizer with empty ``output_text``. Its
            ``tokens`` and offsets are primed from the prompt.
        """
        params = request.sampling_params

        tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
            tokenizer=tokenizer,
            prompt_ids=request.prompt_token_ids,
            skip_special_tokens=params.skip_special_tokens,
        )

        stops = params.stop
        # Number of chars to hold back when stop strings are to be excluded
        # from streamed output. A not-yet-complete stop string is at most one
        # character shorter than the longest stop string. Clamp at 0 so an
        # empty stop string can never produce a negative hold-back, which
        # would corrupt the slicing in _get_next_output_text.
        if stops and not params.include_stop_str_in_output:
            stop_buffer_length = max(0, max(len(s) for s in stops) - 1)
        else:
            stop_buffer_length = 0

        return cls(
            output_text="",
            tokens=tokens,
            # Detokenizer mutates this list, so need a unique copy.
            # NOTE(Nick): could we take ownership of it though?
            token_ids=request.prompt_token_ids.copy(),
            stop=stops,
            include_stop_str_in_output=params.include_stop_str_in_output,
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=params.skip_special_tokens,
            spaces_between_special_tokens=params.
            spaces_between_special_tokens,
            output_kind=params.output_kind,
            prompt_len=len(request.prompt_token_ids),
            tokenizer=tokenizer,
            stop_buffer_length=stop_buffer_length,
        )

    def update_from_output(
        self,
        output: EngineCoreOutput,
    ) -> Optional[DetokenizerOutput]:
        """Incorporate one engine step's output for this request.

        Steps:
            1) Detokenize the new token ids incrementally. The ids are
               appended to ``token_ids``, the token strings to ``tokens`` and
               the decoded text to ``output_text``.
            2) Check stop strings. On a match, truncate ``output_text`` as
               the stop checker reports and finish with ``FinishReason.STOP``.
            3) Build the DetokenizerOutput for ``output_kind``.

        Args:
            output: Engine output carrying ``new_token_ids``,
                ``finish_reason`` and ``stop_reason`` for this request.

        Returns:
            A DetokenizerOutput. Returns None when ``output_kind`` is
            FINAL_ONLY and the request has not finished yet.
        """
        new_token_ids = output.new_token_ids
        finish_reason = output.finish_reason
        stop_reason = output.stop_reason

        # 1) Detokenize the new token ids incrementally.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        # Only the number of newly decoded characters is needed (for the stop
        # check), so count them instead of building a throwaway string.
        new_char_count = 0
        for new_token_id in new_token_ids:
            self.token_ids.append(new_token_id)
            (new_tokens, new_decoded_token_text, prefix_offset,
             read_offset) = detokenize_incrementally(
                 tokenizer=self.tokenizer,
                 all_input_ids=self.token_ids,
                 prev_tokens=self.tokens,
                 prefix_offset=self.prefix_offset,
                 read_offset=self.read_offset,
                 skip_special_tokens=self.skip_special_tokens,
                 spaces_between_special_tokens=self.
                 spaces_between_special_tokens,
             )

            self.tokens.extend(new_tokens)
            self.prefix_offset = prefix_offset
            self.read_offset = read_offset
            self.output_text += new_decoded_token_text
            new_char_count += len(new_decoded_token_text)

        # 2) Evaluate stop criteria.
        # NOTE(review): when several tokens arrive in one step and a stop
        # string matches early among them, only output_text is truncated.
        # token_ids still contains every token of the step.
        if self.stop:
            stop = StopChecker.check_stop_strings(
                output_text=self.output_text,
                new_char_count=new_char_count,
                stop=self.stop,
                include_in_output=self.include_stop_str_in_output,
            )
            if stop is not None:
                stop_str, truncate_to = stop
                if truncate_to != -1:
                    self.output_text = self.output_text[:truncate_to]
                finish_reason = FinishReason.STOP
                stop_reason = stop_str

        # TODO: handle stop_token_ids here too?

        # 3) Update the RequestOutput object with the new text.
        finished = finish_reason is not None
        if self.output_kind == RequestOutputKind.FINAL_ONLY \
            and not finished:
            return None

        delta = self.output_kind == RequestOutputKind.DELTA
        output_text = self._get_next_output_text(finished, delta)
        token_ids = new_token_ids if delta else self.output_token_ids

        return DetokenizerOutput(output_text, token_ids, finished,
                                 finish_reason, stop_reason)

    def _get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return the text to emit for the current step.

        While the request is unfinished, the last ``stop_buffer_length``
        characters of ``output_text`` are withheld. Once finished, the full
        text is released.

        Args:
            finished: Whether the request has finished.
            delta: If True, return only text not returned by an earlier delta
                call and advance the internal offset. Otherwise return all
                non-withheld text.

        Returns:
            The text to emit, possibly "" if nothing new can be released.
        """
        buffer_length = 0 if finished else self.stop_buffer_length
        if not delta:
            # Guard the zero case: ``text[:-0]`` would be the empty string.
            if buffer_length:
                return self.output_text[:-buffer_length]
            return self.output_text
        length = len(self.output_text) - buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""