output_processor.py 12 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import asyncio
from dataclasses import dataclass
5
from typing import Optional, Union
6
7

from vllm.outputs import RequestOutput
8
9
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
10
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
11
12
13
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
14
15
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
                                   RequestStateStats)
16
17
18
19
20


@dataclass
class OutputProcessorOutput:
    """Result of a single ``OutputProcessor.process_outputs`` call."""

    # RequestOutputs collected for direct return to the caller. Only
    # requests without a per-request queue contribute to this list.
    request_outputs: list[RequestOutput]
    # IDs of requests that finished in the OutputProcessor while the
    # corresponding EngineCoreOutput was not marked as finished.
    reqs_to_abort: list[str]
23
24
25
26
27
28
29


class RequestState:
    """Per-request bookkeeping held by the ``OutputProcessor``.

    Bundles the request's identifying data, its prompt, the helper
    objects used to detokenize and to build logprobs, the optional output
    queue, and (when stats logging is enabled) its stats record.
    """

    def __init__(
        self,
        request_id: str,
        lora_name: Optional[str],
        output_kind: RequestOutputKind,
        prompt: Optional[str],
        prompt_token_ids: list[int],
        logprobs_processor: LogprobsProcessor,
        detokenizer: IncrementalDetokenizer,
        arrival_time: float,
        queue: Optional[asyncio.Queue[RequestOutput]],
        log_stats: bool,
    ):
        """Store the request's state.

        Args:
            request_id: Identifier of the request.
            lora_name: Name of the LoRA request, or None if there is none.
            output_kind: Output mode requested in the sampling params.
            prompt: Prompt text, if available.
            prompt_token_ids: Token ids of the prompt.
            logprobs_processor: Logprobs helper for this request.
            detokenizer: Incremental detokenizer for this request.
            arrival_time: Arrival time recorded in the stats object.
            queue: Optional queue that receives this request's outputs.
            log_stats: Whether to allocate a ``RequestStateStats``.
        """
        self.request_id = request_id
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_len = len(prompt_token_ids)
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        # A new request starts out in the prefill phase.
        self.is_prefilling = True
        self.queue = queue

        # Stats are only tracked when logging is enabled.
        self.stats = RequestStateStats(
            arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        queue: Optional[asyncio.Queue[RequestOutput]],
        log_stats: bool,
    ) -> "RequestState":
        """Build a ``RequestState`` from a newly added engine core request.

        Args:
            tokenizer: Tokenizer handed to the logprobs processor and the
                detokenizer created for this request.
            request: The request to track.
            queue: Optional queue that receives this request's outputs.
            log_stats: Whether to track stats for this request.

        Returns:
            The freshly initialized state for ``request``.
        """
        return cls(
            request_id=request.request_id,
            lora_name=(request.lora_request.name
                       if request.lora_request is not None else None),
            output_kind=request.sampling_params.output_kind,
            prompt=request.prompt,
            prompt_token_ids=request.prompt_token_ids,
            logprobs_processor=LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            detokenizer=IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )


class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: BaseTokenizerGroup,
        log_stats: bool,
    ):
        """Create an empty processor.

        Args:
            tokenizer: Tokenizer group used to look up the tokenizer for
                each added request.
            log_stats: Whether per-request stats should be tracked.
        """
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Active requests, keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        self.lora_states = LoRARequestStates()

    def is_request_active(self, request_id: str) -> bool:
        """Return True if ``request_id`` is currently being tracked."""
        return request_id in self.request_states

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests currently being tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if at least one request is being tracked."""
        return len(self.request_states) > 0

    def abort_requests(
        self,
        request_ids: list[str],
    ) -> None:
        """Stop tracking the given requests.

        Ids that are not currently tracked are silently ignored.
        """
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)

    def add_request(
        self,
        request: EngineCoreRequest,
        queue: Optional[asyncio.Queue[RequestOutput]] = None,
    ) -> None:
        """Start tracking a new request.

        Args:
            request: The request to track.
            queue: Optional queue that will receive the request's outputs.

        Raises:
            ValueError: If a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
            request=request,
            queue=queue,
            log_stats=self.log_stats)
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Compute sample and prompt logprobs
        4) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        ****************** NOTE FOR DEVELOPERS ******************

        VLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.

        **********************************************************
        """

        request_outputs: list[RequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration. This runs before
            #    `is_prefilling` is refreshed below, so the stats see the
            #    value from before this output was processed.
            self._update_stats_from_output(req_state, engine_core_output,
                                           engine_core_timestamp,
                                           iteration_stats)

            new_token_ids = engine_core_output.new_token_ids
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason

            # TODO(andy): prompt logprobs + chunked prefill can
            # result in engine core returning an output for a
            # partial prefill (in order to send back partial
            # prompt logprobs.) This breaks the invariant that
            # process_outputs is only operating on engine core
            # outputs associated with non-partial completions.
            # Currently this is handled by having `is_prefilling`
            # check for new decoded tokens, indicating that
            # the completion is not partial.
            #
            # Follow up will aggregate partial prompt logprobs
            # in the EngineCore.
            req_state.is_prefilling = not new_token_ids

            # 2) Detokenize the token ids into text and check for stop
            #    strings.
            stop_string = req_state.detokenizer.update(new_token_ids)
            if stop_string and finish_reason != FinishReason.STOP:
                finish_reason = FinishReason.STOP
                stop_reason = stop_string

            # 3) Compute sample and prompt logprobs for request,
            #    if required.
            req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := self._make_request_output(
                    req_state, new_token_ids, finish_reason, stop_reason):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put_nowait(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

                # Free completed requests.
                if request_output.finished:
                    self.request_states.pop(req_id)
                    if not engine_core_output.finished:
                        # If req not finished in EngineCore, but Detokenizer
                        # detected stop string, abort needed in EngineCore.
                        reqs_to_abort.append(req_id)

                    # Track per-request stats
                    self._update_stats_from_finished(req_state, request_output,
                                                     finish_reason,
                                                     iteration_stats)

        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def _update_stats_from_output(self, req_state: RequestState,
                                  engine_core_output: EngineCoreOutput,
                                  engine_core_timestamp: Optional[float],
                                  iteration_stats: Optional[IterationStats]):
        """Fold one engine core output into the iteration stats.

        Does nothing when ``iteration_stats`` is None. Otherwise a
        timestamp and a per-request stats object are both required.
        """
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(engine_core_output,
                                           engine_core_timestamp,
                                           req_state.is_prefilling,
                                           req_state.prompt_len,
                                           req_state.stats, lora_stats)

    def _update_stats_from_finished(self, req_state: RequestState,
                                    request_output: RequestOutput,
                                    finish_reason: Optional[FinishReason],
                                    iteration_stats: Optional[IterationStats]):
        """Record a finished request in the iteration and LoRA stats.

        Does nothing when ``iteration_stats`` is None.
        """
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(finish_reason,
                                                     request_output,
                                                     req_state.stats)
        self.lora_states.finish_request(req_state)

    @staticmethod
    def _make_request_output(
        request_state: RequestState,
        new_token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
    ) -> Optional[RequestOutput]:
        """Build the RequestOutput for one step of a request, if any.

        Args:
            request_state: State of the request being processed.
            new_token_ids: Token ids produced for the request this step.
            finish_reason: Why the request finished, or None if it is
                still running.
            stop_reason: Stop reason to attach to a finished output.

        Returns:
            A RequestOutput, or None when nothing should be emitted: the
            request is unfinished and is either still prefilling or in
            FINAL_ONLY mode.
        """

        finished = finish_reason is not None
        output_kind = request_state.output_kind
        # In follow up, we will switch to invariant where EngineCore
        # does not stream partial prefills.
        if not finished and (request_state.is_prefilling
                             or output_kind == RequestOutputKind.FINAL_ONLY):
            # Only the final output is required in FINAL_ONLY mode.
            return None

        detokenizer = request_state.detokenizer
        logprobs_processor = request_state.logprobs_processor

        delta = output_kind == RequestOutputKind.DELTA
        logprobs = logprobs_processor.logprobs
        if delta:
            if logprobs:
                # Keep only the entries for this step's tokens. With no new
                # tokens, `logprobs[-0:]` would return the *whole* list
                # (since -0 == 0), so that case needs an explicit empty
                # slice.
                num_new_tokens = len(new_token_ids)
                logprobs = (logprobs[-num_new_tokens:]
                            if num_new_tokens else logprobs[:0])
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = logprobs_processor.prompt_logprobs

        request_output = RequestOutput.new(
            request_id=request_state.request_id,
            prompt=request_state.prompt,
            prompt_token_ids=request_state.prompt_token_ids,
            text=detokenizer.get_next_output_text(finished, delta),
            token_ids=new_token_ids if delta else detokenizer.output_token_ids,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            cumulative_logprob=logprobs_processor.cumulative_logprob,
            finished=finished,
        )
        if finished:
            completion_output = request_output.outputs[0]
            completion_output.finish_reason = str(finish_reason)
            completion_output.stop_reason = stop_reason

        return request_output