detokenizer.py 6.37 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from dataclasses import dataclass, field
4
from typing import Optional
5
6
7
8
9

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
    AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
10
from vllm.v1.engine import EngineCoreRequest
11
12
13
14
15
16
17
18

# Module-level logger, created via vLLM's init_logger with this module's name.
logger = init_logger(__name__)


@dataclass
class IncrementalDetokenizer:
    """Per-request state for incremental detokenization.

    Keeps the request's token ids and decoded text, detokenizes newly
    generated tokens one at a time, checks the decoded text against the
    request's stop strings, and hands out the output text either in full or
    as deltas since the previous call.
    """

    # Generation data
    # Prompt token ids followed by generated token ids; the first
    # `prompt_len` entries belong to the prompt (see `output_token_ids`).
    token_ids: list[int]
    # Text decoded so far from the generated tokens.
    output_text: str = ""
    # Token strings accumulated across calls to detokenize_incrementally.
    tokens: list[str] = field(default_factory=list)
    prompt_len: int = 0

    # Stop strings
    stop: list[str] = field(default_factory=list)
    include_stop_str_in_output: bool = False

    # Metadata for incremental detokenization
    # (round-tripped through detokenize_incrementally on every token).
    prefix_offset: int = 0
    read_offset: int = 0

    # Parameters for detokenization
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True

    # Tokenizer for this request,
    # None if detokenization is disabled.
    tokenizer: Optional["AnyTokenizer"] = None

    # Accounting for stop string buffering
    # Number of trailing characters withheld from get_next_output_text()
    # until the request is finished.
    stop_buffer_length: int = 0
    # End offset in output_text of the text already returned in delta mode.
    _last_output_text_offset: int = 0

    @property
    def output_token_ids(self) -> list[int]:
        """Return the generated token ids, i.e. `token_ids` minus the prompt.

        When `prompt_len` is 0 the underlying list itself is returned (no
        copy); otherwise a new list sliced past the prompt is returned.
        """
        if not self.prompt_len:
            return self.token_ids
        return self.token_ids[self.prompt_len:]

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional["AnyTokenizer"],
        request: "EngineCoreRequest",
    ) -> "IncrementalDetokenizer":
        """Build the detokenizer state for a new request.

        Args:
            tokenizer: Tokenizer for the request, or None to disable
                detokenization.
            request: The request; its prompt token ids and sampling
                parameters seed the state. Not accessed when `tokenizer`
                is None.

        Returns:
            A new instance. With no tokenizer it starts with empty
            `token_ids` and default settings.
        """
        if tokenizer is None:
            return cls(token_ids=[])

        sampling_params = request.sampling_params

        tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
            tokenizer=tokenizer,
            prompt_ids=request.prompt_token_ids,
            skip_special_tokens=sampling_params.skip_special_tokens,
        )

        stops = sampling_params.stop
        # Number of chars to hold back when stop strings are to be excluded
        # from streamed output: one less than the longest stop string.
        if stops and not sampling_params.include_stop_str_in_output:
            stop_buffer_length = max(len(s) for s in stops) - 1
        else:
            stop_buffer_length = 0

        return cls(
            tokens=tokens,
            # Detokenizer mutates this list, so need a unique copy.
            # NOTE(Nick): could we take ownership of it though?
            token_ids=request.prompt_token_ids.copy(),
            stop=stops,
            include_stop_str_in_output=sampling_params.
            include_stop_str_in_output,
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=sampling_params.skip_special_tokens,
            spaces_between_special_tokens=sampling_params.
            spaces_between_special_tokens,
            prompt_len=len(request.prompt_token_ids),
            tokenizer=tokenizer,
            stop_buffer_length=stop_buffer_length,
        )

    def update(self, new_token_ids: list[int],
               stop_terminated: bool) -> Optional[str]:
        """Record newly generated token ids and extend the output text.

        Steps:
            1) Detokenize the new token ids incrementally.
            2) Evaluate stop strings against the updated output text.

        Args:
            new_token_ids: Token ids generated since the previous call.
            stop_terminated: Whether the request was terminated by a stop
                token (the last id in `new_token_ids`). In that case the
                stop string check is skipped, and the last token is left
                out of detokenization unless `include_stop_str_in_output`
                is set.

        Returns:
            The matched stop string, or None if no stop string matched
            (or the check was skipped).
        """
        if not new_token_ids:
            # Skip detokenization if no new token ids
            return None

        if self.tokenizer is None:
            # Skip detokenization if no tokenizer
            self.token_ids.extend(new_token_ids)
            return None

        if stop_terminated and not self.include_stop_str_in_output:
            # If stop-terminated, exclude last token from detokenization
            # based on include_stop_str_in_output parameter.
            skipped_stop_token_id = new_token_ids[-1]
            new_token_ids = new_token_ids[:-1]
        else:
            skipped_stop_token_id = None

        # 1) Detokenize the new token ids incrementally.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        decoded_pieces: list[str] = []
        for new_token_id in new_token_ids:
            self.token_ids.append(new_token_id)
            (new_tokens, new_decoded_token_text, prefix_offset,
             read_offset) = detokenize_incrementally(
                 tokenizer=self.tokenizer,
                 all_input_ids=self.token_ids,
                 prev_tokens=self.tokens,
                 prefix_offset=self.prefix_offset,
                 read_offset=self.read_offset,
                 skip_special_tokens=self.skip_special_tokens,
                 spaces_between_special_tokens=self.
                 spaces_between_special_tokens,
             )

            self.tokens.extend(new_tokens)
            self.prefix_offset = prefix_offset
            self.read_offset = read_offset

            decoded_pieces.append(new_decoded_token_text)

        decoded_text = "".join(decoded_pieces)
        self.output_text += decoded_text

        if stop_terminated:
            if skipped_stop_token_id is not None:
                # Cleanup after skipping detokenization: the token is still
                # recorded in token_ids even though it was not decoded.
                self.token_ids.append(skipped_stop_token_id)
            # Stop token triggered; skip stop string check
            return None

        # 2) Evaluate stop strings.
        stop_string = None
        if self.stop:
            stop = StopChecker.check_stop_strings(
                output_text=self.output_text,
                new_char_count=len(decoded_text),
                stop=self.stop,
                include_in_output=self.include_stop_str_in_output,
            )
            if stop is not None:
                stop_string, truncate_to = stop
                # -1 means the output text is kept as is.
                if truncate_to != -1:
                    self.output_text = self.output_text[:truncate_to]

        return stop_string

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return the output text that may be released to the caller.

        While the request is not finished, the last `stop_buffer_length`
        characters are withheld; once finished, nothing is withheld.

        Args:
            finished: Whether the request has finished.
            delta: If True, only new text since the last call to this
                method is returned; otherwise the whole releasable text.

        Returns:
            The releasable text (possibly empty).
        """
        # We return the full output text if the sequence is finished.
        buffer_length = 0 if finished else self.stop_buffer_length
        if not delta:
            return self.output_text[:-buffer_length] if buffer_length else (
                self.output_text)

        length = len(self.output_text) - buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""