logprobs.py 8.48 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import itertools
5
from collections.abc import Iterable
6
7
8
from dataclasses import dataclass

from vllm.logger import init_logger
9
10
11
12
13
14
15
from vllm.logprobs import (
    PromptLogprobs,
    SampleLogprobs,
    append_logprobs_for_next_position,
    create_prompt_logprobs,
    create_sample_logprobs,
)
16
from vllm.tokenizers.detokenizer_utils import (
17
    TokenizerLike,
18
19
    convert_ids_list_to_tokens,
)
20
21
22
23
24
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.outputs import LogprobsLists, LogprobsTensors

logger = init_logger(__name__)

# Endless stream of ``None`` placeholders. When detokenization is disabled
# (no tokenizer) it is handed to ``append_logprobs_for_next_position`` in
# place of the decoded-token strings.
NONES: Iterable[None] = itertools.repeat(None)

27
28
29

@dataclass
class LogprobsProcessor:
    """Accumulates sample and prompt logprobs for a single request.

    Built with :meth:`from_new_request` and fed every ``EngineCoreOutput``
    through :meth:`update_from_output`. When a tokenizer is available, each
    candidate token id is detokenized on its own. A decoded string that ends
    in U+FFFD (an unfinished UTF-8 byte sequence) is repaired by decoding the
    candidate after the token actually chosen at the preceding position.
    """

    # Unicode replacement character. A decoded token ending with it is
    # treated as a potential unfinished byte sequence (byte fallback).
    # Deliberately unannotated so @dataclass does not make it a field.
    _REPLACEMENT_CHAR = "\ufffd"

    # Tokenizer for this request,
    # None if detokenization is disabled.
    tokenizer: TokenizerLike | None

    # Logprobs for this request
    logprobs: SampleLogprobs | None
    prompt_logprobs: PromptLogprobs | None
    cumulative_logprob: float | None
    num_logprobs: int | None
    num_prompt_logprobs: int | None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
    ) -> "LogprobsProcessor":
        """Create a processor configured from ``request.sampling_params``.

        The sample-logprob container and the running cumulative logprob are
        only allocated when ``sampling_params.logprobs`` is not None.
        Likewise, the prompt-logprob container is only allocated when
        ``sampling_params.prompt_logprobs`` is not None.
        """
        sampling_params = request.sampling_params
        assert sampling_params is not None
        num_logprobs = sampling_params.logprobs
        num_prompt_logprobs = sampling_params.prompt_logprobs

        return cls(
            tokenizer=tokenizer,
            cumulative_logprob=(None if num_logprobs is None else 0.0),
            logprobs=(
                None
                if num_logprobs is None
                else create_sample_logprobs(sampling_params.flat_logprobs)
            ),
            prompt_logprobs=(
                None
                if num_prompt_logprobs is None
                else create_prompt_logprobs(sampling_params.flat_logprobs)
            ),
            num_prompt_logprobs=num_prompt_logprobs,
            num_logprobs=num_logprobs,
        )

    def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
        """Update with sample logprobs from EngineCore.

        Outer lists are only of len > 1 if EngineCore made
        >1 tokens in prior step (e.g. in spec decoding).

        Args:
          logprobs_lists: the lists of logprob tokens, logprobs, and ranks.

        """

        assert self.num_logprobs is not None
        assert self.logprobs is not None
        assert self.cumulative_logprob is not None

        token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists

        for rank_np, logprobs_np, token_ids_np in zip(
            ranks_lst, logprobs_lst, token_ids_lst
        ):
            rank = rank_np.tolist()
            logprobs = logprobs_np.tolist()
            token_ids = token_ids_np.tolist()

            # Detokenize (non-incrementally). Partial UTF-8 tokens are repaired
            # against the previously sampled token; because each position is
            # appended below before the next iteration, this also works for
            # several positions per step (spec decoding).
            decoded_tokens: list[str] | Iterable[None]
            if self.tokenizer is None:
                decoded_tokens = NONES
            else:
                decoded_tokens = self._verify_tokens(
                    decoded_tokens_list=convert_ids_list_to_tokens(
                        self.tokenizer, token_ids
                    ),
                    tokens=token_ids,
                    context=self.logprobs,
                )

            # Sampler puts the sampled logprob in first.
            sampled_token_logprob = logprobs[0]
            self.cumulative_logprob += sampled_token_logprob

            # Update with the Logprob container for this pos.
            append_logprobs_for_next_position(
                self.logprobs,
                token_ids,
                logprobs,
                decoded_tokens,
                rank,
                self.num_logprobs,
            )

    def _update_prompt_logprobs(
        self,
        prompt_logprobs_tensors: LogprobsTensors,
    ) -> None:
        """Update with prompt logprobs from EngineCore.

        Args:
          prompt_logprobs_tensors: tuple containing the prompt logprobs
                                   tensors.

        """

        # Prompt logprobs are enabled.
        assert self.num_prompt_logprobs is not None
        assert self.prompt_logprobs is not None

        token_ids, logprobs, ranks, _ = prompt_logprobs_tensors

        # Recover shapes.
        num_prompt_tokens, num_lps_per_pos = logprobs.shape

        # Detokenize non-incrementally.
        # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
        all_decoded_tokens: list[str] | None = (
            None
            if self.tokenizer is None
            else convert_ids_list_to_tokens(
                self.tokenizer, token_ids.flatten().tolist()
            )
        )

        # Pythonize the torch tensors.
        prompt_token_ranks = ranks.tolist()
        prompt_logprobs = logprobs.tolist()
        token_ids_list = token_ids.tolist()

        # Make Logprob for each position.
        for pos in range(num_prompt_tokens):
            decoded_tokens_for_pos: list[str] | Iterable[None]
            if all_decoded_tokens is None:
                decoded_tokens_for_pos = NONES
            else:
                # Slice this position's decoded tokens out of the flat list.
                offset = pos * num_lps_per_pos
                decoded_slice = all_decoded_tokens[offset : offset + num_lps_per_pos]
                # Repair partial UTF-8 tokens against the previous *prompt*
                # position. The sample logprobs must not be used here: they
                # belong to generated tokens, and update_from_output applies
                # them before prompt logprobs when an output carries both.
                decoded_tokens_for_pos = self._verify_tokens(
                    decoded_tokens_list=decoded_slice,
                    tokens=token_ids_list[pos],
                    context=self.prompt_logprobs,
                )

            # Update with the Logprob container for this pos.
            append_logprobs_for_next_position(
                self.prompt_logprobs,
                token_ids_list[pos],
                prompt_logprobs[pos],
                decoded_tokens_for_pos,
                prompt_token_ranks[pos],
                self.num_prompt_logprobs,
            )

    def pop_prompt_logprobs(self) -> PromptLogprobs | None:
        """Pop and return all request prompt logprobs

        The logprobs processor aggregates prompt chunk logprobs
        over one or more prefill chunks. This method returns
        all prompt logprobs at once and then forgets them.
        Ensures correct RequestOutputKind.DELTA semantics
        wherein all prompt logprobs are returned at once at
        the end of prefill.

        Returns:
          None if prompt logprobs are disabled for this request.
          List of all prompt logprobs, otherwise.
        """
        plp = self.prompt_logprobs
        if plp:
            self.prompt_logprobs = []
        return plp

    def _correct_decoded_token(
        self,
        idx: int,
        tokens: list[int],
        context: SampleLogprobs | PromptLogprobs | None = None,
    ) -> str:
        """Try to repair the decoded text of candidate ``tokens[idx]``.

        ``tokens`` holds the candidate ids for ONE position: the chosen token
        first, followed by the top-k alternatives. Those candidates are
        mutually exclusive, so they are never decoded together. Instead, the
        candidate is decoded right after the token chosen at the previous
        position, which is the first key of ``context[-1]``.

        Args:
          idx: index of the candidate within ``tokens``.
          tokens: candidate token ids for the current position.
          context: logprobs container whose last entry is the previous
            position; defaults to ``self.logprobs``.

        Returns:
          The decoded text of ``[previous token, candidate]`` if it no longer
          ends in U+FFFD (note that it includes the previous token's text),
          otherwise an empty string.
        """
        assert self.tokenizer is not None, "self.tokenizer should not be None"

        if context is None:
            context = self.logprobs

        # The last entry may be None or empty (e.g. the first prompt token has
        # no logprobs), in which case there is no usable preceding token.
        prev_position = context[-1] if context else None
        if prev_position:
            # First key = token actually chosen at the previous position
            # (sampled token first; presumably the prompt token itself for
            # prompt logprobs - NOTE(review): confirm tensor layout).
            prev_token_id = next(iter(prev_position))
            candidate = self.tokenizer.decode([prev_token_id, tokens[idx]])
            if not candidate.endswith(self._REPLACEMENT_CHAR):
                return candidate

        # by default return empty string
        return ""

    def _verify_tokens(
        self,
        decoded_tokens_list: list[str],
        tokens: list[int],
        context: SampleLogprobs | PromptLogprobs | None = None,
    ) -> list[str]:
        """Repair decoded tokens that end in U+FFFD.

        A trailing replacement char means a potentially unfinished byte
        sequence from byte-fallback tokenization. Such entries are replaced
        in place via :meth:`_correct_decoded_token`. Corrections only read
        ``tokens`` and ``context``, never the decoded list, so one pass is
        enough.

        Returns:
          ``decoded_tokens_list`` itself, mutated in place.
        """
        for idx, text in enumerate(decoded_tokens_list):
            if text.endswith(self._REPLACEMENT_CHAR):
                decoded_tokens_list[idx] = self._correct_decoded_token(
                    idx, tokens, context
                )
        return decoded_tokens_list

    def update_from_output(self, output: EngineCoreOutput) -> None:
        """Fold one engine-core output into the accumulated logprobs.

        Sample logprobs are applied before prompt logprobs when both are
        present.
        """
        if output.new_logprobs is not None:
            self._update_sample_logprobs(output.new_logprobs)
        if output.new_prompt_logprobs_tensors is not None:
            self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)