output_processor.py 18 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, Optional, Union, cast
8

9
10
11
12
import torch

from vllm.outputs import (CompletionOutput, PoolingOutput,
                          PoolingRequestOutput, RequestOutput)
13
14
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
15
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
16
17
18
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
19
from vllm.v1.engine.parallel_sampling import ParentRequest
20
21
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
                                   RequestStateStats)
22
23


24
25
26
27
28
29
30
31
32
33
34
class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Passed as `aggregate=` to RequestOutput.add() when a new output
        # arrives before the consumer has read the pending one.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        # The single pending item (output or error) not yet consumed.
        self.output: Optional[Union[RequestOutput, PoolingRequestOutput,
                                    Exception]] = None
        # Set whenever `self.output` holds a pending item.
        self.ready = asyncio.Event()

    def put(self, output: Union[RequestOutput, PoolingRequestOutput,
                                Exception]) -> None:
        """Non-blocking put operation.

        - If nothing is pending, or `output` is an exception, `output`
          becomes the pending item and waiters are woken.
        - If an exception is already pending, later outputs are dropped so
          the consumer still sees the error.
        - Two RequestOutputs are merged via `RequestOutput.add`.
        - Any other combination (e.g. pooling outputs) keeps the newest.
        """
        if self.output is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(self.output, Exception):
            # A pending error takes precedence over later outputs.
            pass
        elif isinstance(self.output, RequestOutput) and isinstance(
                output, RequestOutput):
            # This ensures that request outputs with different request indexes
            # (if n > 1) do not override each other.
            self.output.add(output, aggregate=self.aggregate)
        else:
            # Pooling outputs are not merged: keep the most recent one.
            # `ready` is already set because an item was pending.
            self.output = output

    async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
        """Get operation blocks on put event.

        Raises:
            Exception: re-raises any exception that was `put()`.
        """
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def get_nowait(
            self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Non-blocking get operation.

        Returns the pending output, or None if nothing is pending.

        Raises:
            Exception: re-raises any exception that was `put()`.
        """
        output = self.output
        if output is not None:
            self.output = None
            self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output


72
73
74
@dataclass
class OutputProcessorOutput:
    """Return value of `OutputProcessor.process_outputs`."""

    # Outputs for requests without a per-request queue (LLMEngine path);
    # queued requests receive theirs via RequestOutputCollector instead.
    request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
    # Ids of requests finished by the output processor (e.g. a stop string
    # found during detokenization) that EngineCore has not finished yet.
    reqs_to_abort: list[str]
77
78
79
80
81
82
83


class RequestState:
    """Per-request bookkeeping used to turn EngineCoreOutputs into
    RequestOutput / PoolingRequestOutput objects."""

    def __init__(
        self,
        request_id: str,
        parent_req: Optional[ParentRequest],
        request_index: int,
        lora_name: Optional[str],
        output_kind: RequestOutputKind,
        prompt: Optional[str],
        prompt_token_ids: list[int],
        logprobs_processor: Optional[LogprobsProcessor],
        detokenizer: Optional[IncrementalDetokenizer],
        max_tokens_param: Optional[int],
        arrival_time: float,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ):
        self.request_id = request_id
        # Non-None when this request is one child of a parallel-sampling
        # parent; request_index becomes the CompletionOutput.index.
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_len = len(prompt_token_ids)
        # Both None for pooling requests (see from_new_request).
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        # Cleared by the output processor once the first output arrives.
        self.is_prefilling = True
        # Per-request hand-off collector (AsyncLLM); None for LLMEngine.
        self.queue = queue

        self.stats = RequestStateStats(
            arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest],
        request_index: int,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ) -> "RequestState":
        """Build a RequestState from an incoming EngineCoreRequest.

        Sampling requests get a logprobs processor and an incremental
        detokenizer (with no tokenizer when `detokenize` is disabled).
        Pooling requests get neither and take their output kind from
        `pooling_params`.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(request.lora_request.name
                       if request.lora_request is not None else None),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: Optional[torch.Tensor],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
        kv_transfer_params: Optional[dict[str, Any]] = None,
        num_cached_tokens: int = 0,
    ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Create the user-facing output for this step.

        Returns None when no output should be emitted yet. That happens for
        unfinished requests in FINAL_ONLY mode, or when the parent request
        of a parallel-sampling child has nothing to return.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)],
                finished)

        output = self._new_completion_output(new_token_ids, finish_reason,
                                             stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            # The parent decides the reported request id, which child
            # outputs to emit, and whether the whole request is finished.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output)
            if not outputs:
                return None

        return self._new_request_output(request_id, outputs, finished,
                                        kv_transfer_params, num_cached_tokens)

    def _new_request_output(
        self,
        request_id: str,
        outputs: Union[list[CompletionOutput], list[PoolingOutput]],
        finished: bool,
        kv_transfer_params: Optional[dict[str, Any]] = None,
        num_cached_tokens: int = 0,
    ) -> Union[RequestOutput, PoolingRequestOutput]:
        """Wrap per-sequence outputs into a RequestOutput or a
        PoolingRequestOutput; `outputs` must be non-empty."""
        if isinstance(outputs[0], PoolingOutput):
            assert len(outputs) == 1
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=outputs[0],
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=self.prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=num_cached_tokens,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step.

        In DELTA mode, text, token ids and logprobs cover only this step.
        Otherwise they are cumulative.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            # Keep only the entries for this step's new tokens. Guard the
            # empty case: `logprobs[-0:]` would return the whole history.
            logprobs = logprobs[-len(token_ids):] if token_ids else []

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None)

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a raw pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

273
274
275
276
277
278

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: TokenizerGroup,
        log_stats: bool,
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Unfinished requests, keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        # Parents of parallel-sampling child requests, keyed by parent id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests still being tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if any request is still being tracked."""
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks.

        Every tracked request must have a queue (AsyncLLM usage).
        """
        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Stop tracking the given requests.

        An id that is not a tracked request but is a known parent request
        aborts all of that parent's child requests instead.

        Returns:
            The request ids that should be aborted in EngineCore.
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
                request_ids_to_abort.append(request_id)
            else:
                parent = self.parent_requests.pop(request_id, None)
                if parent and parent.child_requests:
                    self.abort_requests(parent.child_requests)
                    request_ids_to_abort.extend(parent.child_requests)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest] = None,
        request_index: int = 0,
        queue: Optional[RequestOutputCollector] = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats)
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[Union[RequestOutput,
                                    PoolingRequestOutput]] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(req_state, engine_core_output,
                                           engine_core_timestamp,
                                           iteration_stats)

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            num_cached_tokens = engine_core_output.num_cached_tokens
            # Must follow the stats update above, which reads is_prefilling.
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP)
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(
                    engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                    new_token_ids, pooling_output, finish_reason, stop_reason,
                    kv_transfer_params, num_cached_tokens):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(req_state, finish_reason,
                                                 iteration_stats)

        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def _update_stats_from_output(self, req_state: RequestState,
                                  engine_core_output: EngineCoreOutput,
                                  engine_core_timestamp: Optional[float],
                                  iteration_stats: Optional[IterationStats]):
        """Record per-iteration stats for one output (no-op when stats are
        not being collected)."""
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(engine_core_output,
                                           engine_core_timestamp,
                                           req_state.is_prefilling,
                                           req_state.prompt_len,
                                           req_state.stats, lora_stats)

    def _update_stats_from_finished(self, req_state: RequestState,
                                    finish_reason: Optional[FinishReason],
                                    iteration_stats: Optional[IterationStats]):
        """Record stats for a request that just finished (no-op when stats
        are not being collected)."""
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=len(req_state.prompt_token_ids),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats)
        self.lora_states.finish_request(req_state)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats,
            req_state.stats.num_generation_tokens)