output_processor.py 9.89 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Optional

from vllm.outputs import RequestOutput
8
9
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
10
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
11
12
13
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
14
from vllm.v1.metrics.stats import IterationStats, RequestStateStats
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29


@dataclass
class OutputProcessorOutput:
    """Result bundle returned by ``OutputProcessor.process_outputs``."""

    # RequestOutputs collected for requests that have no per-request queue.
    request_outputs: List[RequestOutput]
    # Ids of requests that finished in the OutputProcessor while the
    # corresponding EngineCoreOutput was not marked finished.
    reqs_to_abort: List[str]
    # Stats object updated while processing the batch.
    iteration_stats: IterationStats


class RequestState:
    """Per-request state tracked while a request is being processed.

    Holds the request's prompt, its detokenizer and logprobs processor,
    its stats object and the optional queue that outputs are pushed to.
    """

    def __init__(
        self,
        request_id: str,
        output_kind: RequestOutputKind,
        prompt: Optional[str],
        prompt_token_ids: List[int],
        logprobs_processor: LogprobsProcessor,
        detokenizer: IncrementalDetokenizer,
        arrival_time: float,
        queue: Optional[asyncio.Queue[RequestOutput]],
    ):
        """Store the per-request components.

        Args:
            request_id: Id of the request this state belongs to.
            output_kind: Output mode requested for this request.
            prompt: Prompt text, if available.
            prompt_token_ids: Token ids of the prompt.
            logprobs_processor: Logprobs processor for this request.
            detokenizer: Incremental detokenizer for this request.
            arrival_time: Used as the initial ``last_token_time`` of
                the request stats.
            queue: Optional queue that RequestOutputs are put into;
                ``None`` when outputs are returned to the caller instead.
        """
        self.request_id = request_id
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        # Cached so the prompt length is not recomputed for every output.
        self.prompt_len = len(prompt_token_ids)
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        # A request starts out prefilling; OutputProcessor.process_outputs
        # updates this flag for each output it handles.
        self.is_prefilling = True
        self.queue = queue

        self.stats = RequestStateStats(last_token_time=arrival_time)

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        queue: Optional[asyncio.Queue[RequestOutput]] = None,
    ) -> "RequestState":
        """Build the state for a newly added request.

        The logprobs processor and the detokenizer are created from the
        same ``tokenizer`` and ``request``.

        Args:
            tokenizer: Tokenizer passed to the logprobs processor and
                the detokenizer.
            request: The new request.
            queue: Optional per-request output queue.

        Returns:
            A RequestState populated from ``request``.
        """
        return cls(
            request_id=request.request_id,
            output_kind=request.sampling_params.output_kind,
            prompt=request.prompt,
            prompt_token_ids=request.prompt_token_ids,
            logprobs_processor=LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            detokenizer=IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            arrival_time=request.arrival_time,
            queue=queue,
        )


class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: BaseTokenizerGroup,
        log_stats: bool,
    ):
        """Create an empty processor.

        Args:
            tokenizer: Tokenizer group; a per-request tokenizer is
                obtained from it in ``add_request``.
            log_stats: Passed to ``IterationStats`` when
                ``process_outputs`` has to create one.
        """
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # State of every request that was added and has neither finished
        # nor been aborted, keyed by request id.
        self.request_states: Dict[str, RequestState] = {}

    def is_request_active(self, request_id: str) -> bool:
        """Return True if ``request_id`` is currently tracked."""
        return request_id in self.request_states

    def get_num_unfinished_requests(self) -> int:
        """Return the number of currently tracked requests."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if at least one request is currently tracked."""
        return len(self.request_states) > 0

    def abort_requests(
        self,
        request_ids: List[str],
    ) -> None:
        """Stop tracking the given requests; unknown ids are ignored."""
        for request_id in request_ids:
            self.request_states.pop(request_id, None)

    def add_request(
        self,
        request: EngineCoreRequest,
        queue: Optional[asyncio.Queue[RequestOutput]] = None,
    ) -> None:
        """Start tracking a new request.

        Args:
            request: The request to track.
            queue: Optional queue that this request's outputs are put
                into instead of being returned by ``process_outputs``.

        Raises:
            ValueError: If a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        self.request_states[request_id] = RequestState.from_new_request(
            tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
            request=request,
            queue=queue)

    def process_outputs(
        self,
        engine_core_outputs: List[EngineCoreOutput],
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        ****************** NOTE FOR DEVELOPERS ******************

        VLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.

        **********************************************************
        """

        request_outputs: List[RequestOutput] = []
        reqs_to_abort: List[str] = []
        # Explicit None check: only a missing stats object is replaced.
        if iteration_stats is None:
            iteration_stats = IterationStats(self.log_stats)

        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            iteration_stats.update_from_output(engine_core_output,
                                               req_state.is_prefilling,
                                               req_state.prompt_len,
                                               req_state.stats)

            new_token_ids = engine_core_output.new_token_ids
            finish_reason = engine_core_output.finish_reason

            # TODO(andy): prompt logprobs + chunked prefill can
            # result in engine core returning an output for a
            # partial prefill (in order to send back partial
            # prompt logprobs.) This breaks the invariant that
            # process_outputs is only operating on engine core
            # outputs associated with non-partial completions.
            # Currently this is handled by having `is_prefilling`
            # check for new decoded tokens, indicating that
            # the completion is not partial.
            #
            # Follow up will aggregate partial prompt logprobs
            # in the EngineCore.
            req_state.is_prefilling = not new_token_ids

            # 2) Detokenize the token ids into text and check for stop
            #    strings.
            stop_reason = req_state.detokenizer.update(new_token_ids)
            if stop_reason:
                finish_reason = FinishReason.STOP

            # 3) Compute sample and prompt logprobs for request,
            #    if required.
            req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := self._make_request_output(
                    req_state, new_token_ids, finish_reason, stop_reason):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put_nowait(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

                # Free completed requests.
                if request_output.finished:
                    self.request_states.pop(req_id)
                    if not engine_core_output.finished:
                        # If req not finished in EngineCore, but Detokenizer
                        # detected stop string, abort needed in EngineCore.
                        reqs_to_abort.append(req_id)

                    # Track per-request stats. A finished output is only
                    # built when finish_reason is set.
                    assert finish_reason is not None
                    iteration_stats.update_from_finished_request(
                        finish_reason, request_output, req_state.stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
            iteration_stats=iteration_stats,
        )

    @staticmethod
    def _make_request_output(
        request_state: RequestState,
        new_token_ids: List[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Optional[str],
    ) -> Optional[RequestOutput]:
        """Build the RequestOutput for one step of one request.

        Args:
            request_state: State of the request being processed.
            new_token_ids: Token ids produced for the request this step.
            finish_reason: Why the request finished, or None if it has
                not finished.
            stop_reason: Stop reason reported by the detokenizer, if any.

        Returns:
            The RequestOutput, or None when the request is not finished
            and is either still prefilling or in FINAL_ONLY mode.
        """

        finished = finish_reason is not None
        output_kind = request_state.output_kind
        # In follow up, we will switch to invariant where EngineCore
        # does not stream partial prefills.
        if not finished and (request_state.is_prefilling
                             or output_kind == RequestOutputKind.FINAL_ONLY):
            # Only the final output is required in FINAL_ONLY mode.
            return None

        detokenizer = request_state.detokenizer
        logprobs_processor = request_state.logprobs_processor

        delta = output_kind == RequestOutputKind.DELTA
        logprobs = logprobs_processor.logprobs
        if delta:
            if logprobs:
                # Keep only the entries for this step's tokens. The empty
                # case is handled explicitly because `logprobs[-0:]` is
                # the whole list rather than an empty one.
                num_new = len(new_token_ids)
                logprobs = logprobs[-num_new:] if num_new else logprobs[:0]
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = logprobs_processor.prompt_logprobs

        request_output = RequestOutput.new(
            request_id=request_state.request_id,
            prompt=request_state.prompt,
            prompt_token_ids=request_state.prompt_token_ids,
            text=detokenizer.get_next_output_text(finished, delta),
            token_ids=new_token_ids if delta else detokenizer.output_token_ids,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            cumulative_logprob=logprobs_processor.cumulative_logprob,
            finished=finished,
        )
        if finished:
            # Finish/stop reasons are only set on a finished output.
            completion_output = request_output.outputs[0]
            completion_output.finish_reason = str(finish_reason)
            completion_output.stop_reason = stop_reason

        return request_output