output_processor.py 22.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, Optional, Union, cast
8

9
10
import torch

11
12
13
14
15
16
from vllm.outputs import (
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
17
from vllm.sampling_params import RequestOutputKind
18
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
19
from vllm.transformers_utils.tokenizer import AnyTokenizer
20
from vllm.utils import length_from_prompt_token_ids_or_embeds
21
22
23
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
24
from vllm.v1.engine.parallel_sampling import ParentRequest
25
from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats
26
27


28
29
30
31
32
33
34
35
36
37
38
class RequestOutputCollector:
    """
    Single-slot hand-off buffer between the output producer and the
    asyncio generate task that consumes one request's outputs.

    If the producer publishes again before the consumer has drained the
    slot, the new output is folded into the pending one via ``add()``;
    the ``aggregate`` flag passed to ``add()`` is true only when the
    request streams deltas.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Signalled whenever a fresh value (or an exception) lands in the slot.
        self.ready = asyncio.Event()
        # The pending item; ``None`` means the slot is empty.
        self.output: Optional[Union[RequestOutput, PoolingRequestOutput, Exception]] = (
            None
        )
        self.aggregate = output_kind == RequestOutputKind.DELTA

    def put(
        self, output: Union[RequestOutput, PoolingRequestOutput, Exception]
    ) -> None:
        """Store ``output`` without blocking.

        An exception always replaces whatever is pending, and an empty slot
        is simply filled; both cases wake the consumer. Otherwise the new
        output is merged into the pending request output.
        """
        if isinstance(output, Exception) or self.output is None:
            self.output = output
            self.ready.set()
            return

        pending = self.output
        if isinstance(pending, (RequestOutput, PoolingRequestOutput)):
            # Merging (rather than overwriting) keeps outputs that belong to
            # different request indexes (n > 1) from clobbering each other.
            pending.add(output, aggregate=self.aggregate)

    async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
        """Wait until an item is available, then take it out of the slot.

        Raises:
            Exception: the pending item itself, if it is an exception.
        """
        item = self.output
        while item is None:
            await self.ready.wait()
            item = self.output

        # Empty the slot and re-arm the event before handing the item over.
        self.output = None
        self.ready.clear()

        if isinstance(item, Exception):
            raise item
        return item

    def get_nowait(self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Take the pending item if there is one; return ``None`` otherwise.

        Raises:
            Exception: the pending item itself, if it is an exception.
        """
        item = self.output
        if item is None:
            return None

        self.output = None
        self.ready.clear()

        if isinstance(item, Exception):
            raise item
        return item


77
78
@dataclass
class OutputProcessorOutput:
    """Result bundle returned by ``OutputProcessor.process_outputs``."""

    # Outputs produced for requests that have no per-request queue; requests
    # with a queue receive their outputs through that queue instead.
    request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
    # Ids of requests that were finished on the output-processing side
    # (a stop string was detected) while the engine core had not yet marked
    # them finished.
    reqs_to_abort: list[str]
81
82
83
84
85
86


class RequestState:
    """State kept for one in-flight request while its outputs are produced.

    Holds the prompt data, the output mode, the optional detokenizer and
    logprobs processor (both ``None`` for pooling requests), the optional
    hand-off queue, and - when stats logging is enabled - per-request stats.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: Optional[ParentRequest],
        request_index: int,
        lora_name: Optional[str],
        output_kind: RequestOutputKind,
        prompt: Optional[str],
        prompt_token_ids: Optional[list[int]],
        prompt_embeds: Optional[torch.Tensor],
        logprobs_processor: Optional[LogprobsProcessor],
        detokenizer: Optional[IncrementalDetokenizer],
        max_tokens_param: Optional[int],
        arrival_time: float,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        temperature: Optional[float] = None,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        # Computed once from whichever of token ids / embeds is present.
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        # Flipped to False by the output processor once an engine output
        # for this request has been handled.
        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        # Per-request stats are only tracked when stats logging is enabled.
        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest],
        request_index: int,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ) -> "RequestState":
        """Build the state for a newly added engine-core request.

        For a sampling request a logprobs processor and an incremental
        detokenizer are created; the tokenizer is withheld from both when
        ``sampling_params.detokenize`` is false. A request without sampling
        params must carry pooling params and gets neither helper.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(
                request.lora_request.name if request.lora_request is not None else None
            ),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: Optional[torch.Tensor],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
        kv_transfer_params: Optional[dict[str, Any]] = None,
    ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Create the output object for the latest step of this request.

        Returns:
            ``None`` when nothing should be emitted: either the request is
            unfinished in FINAL_ONLY mode, or the parent request reports no
            outputs for this step. Otherwise a ``PoolingRequestOutput``
            (when ``pooling_output`` is given) or a ``RequestOutput``.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)], finished
            )

        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            # The parent request decides which request id, which completion
            # outputs and which finished flag are reported for this step.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output
            )
            if not outputs:
                return None

        return self._new_request_output(
            request_id, outputs, finished, kv_transfer_params
        )

    def _new_request_output(
        self,
        request_id: str,
        outputs: Union[list[CompletionOutput], list[PoolingOutput]],
        finished: bool,
        kv_transfer_params: Optional[dict[str, Any]] = None,
    ) -> Union[RequestOutput, PoolingRequestOutput]:
        """Wrap per-sequence outputs into a request-level output object.

        ``outputs`` must be non-empty; its first element selects between the
        pooling and the completion form of the result.
        """
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
    ) -> CompletionOutput:
        """Build the ``CompletionOutput`` for this step.

        In DELTA mode the output carries only the newly produced token ids
        and the logprobs belonging to them; in the other modes it carries
        the detokenizer's full list of output token ids and all logprobs.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            # Keep only the entries for the tokens in this delta. The start
            # index is computed explicitly because ``logprobs[-0:]`` would
            # return the full history when the delta carries no tokens
            # (e.g. the final output of an aborted request).
            logprobs = logprobs[max(len(logprobs) - len(token_ids), 0) :]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a pooling tensor in a ``PoolingOutput``."""
        return PoolingOutput(data=pooling_output)

301
302
303
304

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Live requests, keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        # Parent requests registered through add_request, keyed by parent id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()
        # No tracing happens while this is None.
        self.tracer: Optional[Tracer] = None

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests currently being tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return whether any request is currently being tracked."""
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception) -> None:
        """Propagate error to all generate() tasks."""

        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Stop tracking the given requests and emit their final outputs.

        An id that names a tracked request is removed and, if it has a
        queue, a final output with ``FinishReason.ABORT`` is put on it.
        An id that names a parent request aborts that parent's children
        and then drops the parent. Unknown ids are ignored.

        Returns:
            The ids of the tracked requests that were actually removed.
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Set pooling_output is not None to
                        # correctly enter the abort pooling branch
                        pooling_output=torch.randn(0, device="cpu")
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest] = None,
        request_index: int = 0,
        queue: Optional[RequestOutputCollector] = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
        )
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
            ):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)
        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: Optional[IterationStats],
    ) -> None:
        """Emit an ``llm_request`` span describing a finished request.

        Requires a configured tracer, per-request stats and iteration
        stats; each of these is asserted to be present.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        # The span start time is passed in nanoseconds.
        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        prompt_length = length_from_prompt_token_ids_or_embeds(
            req_state.prompt_token_ids, req_state.prompt_embeds
        )
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
            # Optional sampling settings are compared against None rather
            # than tested for truthiness, so that a legitimate zero value
            # (notably temperature == 0.0) is still recorded on the span.
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
                )
            if req_state.temperature is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
                )
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)

    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: Optional[float],
        iteration_stats: Optional[IterationStats],
    ) -> None:
        """Fold one engine output into the iteration stats.

        Does nothing when ``iteration_stats`` is None.
        """
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            lora_stats,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: Optional[FinishReason],
        iteration_stats: Optional[IterationStats],
    ) -> None:
        """Record a finished request in the iteration, LoRA and parent stats.

        Does nothing when ``iteration_stats`` is None.
        """
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
                req_state.prompt_token_ids, req_state.prompt_embeds
            ),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
        )
        self.lora_states.finish_request(req_state)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )