output_processor.py 24 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, cast
8

9
10
import torch

11
12
13
14
15
16
from vllm.outputs import (
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
17
from vllm.sampling_params import RequestOutputKind
18
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
19
from vllm.transformers_utils.tokenizer import AnyTokenizer
20
from vllm.utils import length_from_prompt_token_ids_or_embeds
21
22
23
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
24
from vllm.v1.engine.parallel_sampling import ParentRequest
25
26
27
28
29
30
from vllm.v1.metrics.stats import (
    IterationStats,
    LoRARequestStates,
    RequestStateStats,
    SchedulerStats,
)
31
32


33
34
35
36
37
38
39
40
41
42
43
class RequestOutputCollector:
    """
    Single-slot mailbox that hands one request's streamed outputs from the
    producer over to the asyncio generate task that consumes them.

    At most one pending item is held. When the producer puts again before
    the consumer has taken the pending item, RequestOutputs are combined
    through ``RequestOutput.add`` (with ``aggregate=True`` for DELTA output
    kind), while a pending PoolingRequestOutput is replaced by the newer one.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # The pending item: an output, an exception to re-raise in the
        # consumer, or None when nothing is waiting.
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        # Set while an item is pending; cleared when it is taken.
        self.ready = asyncio.Event()
        # Forwarded to RequestOutput.add() when merging pending outputs.
        self.aggregate = output_kind == RequestOutputKind.DELTA

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Store ``output`` without blocking, merging with a pending item."""
        pending = self.output

        # Empty slot, or an exception (which always overrides what is pending).
        if pending is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
            return

        both_generation = isinstance(pending, RequestOutput) and isinstance(
            output, RequestOutput
        )
        if both_generation:
            # Merge rather than overwrite so that outputs carrying different
            # request indexes (n > 1) do not clobber each other.
            pending.add(output, aggregate=self.aggregate)
            return

        both_pooling = isinstance(pending, PoolingRequestOutput) and isinstance(
            output, PoolingRequestOutput
        )
        if both_pooling:
            self.output = output

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Wait until an item is pending, take it, and return or raise it."""
        while self.output is None:
            await self.ready.wait()

        taken = self.output
        self.output = None
        self.ready.clear()

        if isinstance(taken, Exception):
            raise taken
        return taken

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Take the pending item if there is one; return None otherwise."""
        taken = self.output
        if taken is None:
            return None

        self.output = None
        self.ready.clear()

        if isinstance(taken, Exception):
            raise taken
        return taken


84
85
@dataclass
class OutputProcessorOutput:
    """Result bundle returned by ``OutputProcessor.process_outputs``."""

    # Outputs handed straight back to the caller, i.e. those produced for
    # requests that have no per-request queue.
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # IDs of requests that finished during output processing while the
    # engine core had not marked them finished, so they still need an abort.
    reqs_to_abort: list[str]
88
89
90
91
92
93


class RequestState:
    """
    Output-side state for a single in-flight request.

    Holds the prompt, the detokenizer and logprobs processor (both ``None``
    for pooling requests), the sampling metadata reported in traces, and the
    bookkeeping needed to build RequestOutput / PoolingRequestOutput objects.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_name: str | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        # Computed once here from whichever prompt representation was given.
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        # True until OutputProcessor.process_outputs handles the first
        # engine-core output for this request.
        self.is_prefilling = True
        self.queue = queue
        # Overwritten from each engine-core output in process_outputs.
        self.num_cached_tokens = 0

        # Per-request stats are only tracked when stats logging is enabled.
        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

        # Stream Interval
        self.stream_interval = stream_interval
        self.sent_tokens_offset = 0  # Offset of sent tokens

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
    ) -> "RequestState":
        """
        Build the state for a newly added request.

        Requests with sampling params get a logprobs processor and a
        detokenizer; otherwise the request must carry pooling params and
        both helpers are left as ``None``.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                # Detokenization disabled: the helpers are built without
                # a tokenizer.
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(
                request.lora_request.name if request.lora_request is not None else None
            ),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
            stream_interval=stream_interval,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        """
        Build the output to emit for this step, or ``None`` to emit nothing.

        ``None`` is returned when the request is unfinished in FINAL_ONLY
        mode, when the stream interval has not been reached yet, or when the
        parent request has no outputs to report for this step.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        request_id = self.request_id

        # Pooling outputs are handled before the stream-interval gating:
        # pooling requests have no detokenizer, so the gating below (which
        # requires one) must never run for them.
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)], finished
            )

        if self.stream_interval > 1:
            assert self.detokenizer is not None
            num_output_tokens = len(self.detokenizer.output_token_ids)

            # Send output request only when
            # 1. It has finished, or
            # 2. It is the first token, or
            # 3. It has reached the stream interval number of tokens
            if not (
                finished
                or self.sent_tokens_offset == 0
                or num_output_tokens - self.sent_tokens_offset >= self.stream_interval
            ):
                return None

            if self.output_kind == RequestOutputKind.DELTA:
                # Send tokens from the offset in DELTA mode, otherwise all
                # tokens are sent.
                new_token_ids = self.detokenizer.output_token_ids[
                    self.sent_tokens_offset :
                ]
            # Advance the offset on every emission, not only in DELTA mode;
            # otherwise the "first token" condition above stays true forever
            # and the interval never throttles non-DELTA streaming.
            self.sent_tokens_offset = num_output_tokens

        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            # The parent decides which child outputs to report and under
            # which request id, and whether the overall request is finished.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output
            )
            if not outputs:
                return None

        return self._new_request_output(
            request_id, outputs, finished, kv_transfer_params
        )

    def _new_request_output(
        self,
        request_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        """
        Wrap ``outputs`` in the request-level output object.

        ``outputs`` must be non-empty; its first element decides whether a
        PoolingRequestOutput or a RequestOutput is built.
        """
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                num_cached_tokens=self.num_cached_tokens,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
    ) -> CompletionOutput:
        """
        Build the CompletionOutput for this step.

        In DELTA mode the output carries only ``token_ids`` and the logprobs
        that belong to them; in the other modes it carries every output
        token held by the detokenizer together with all logprobs.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            # A bare ``logprobs[-len(token_ids):]`` is wrong for an empty
            # delta: ``-0`` slices from the start and would return every
            # logprob instead of none.
            num_new_tokens = len(token_ids)
            logprobs = logprobs[-num_new_tokens:] if num_new_tokens else logprobs[:0]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

339
340
341
342

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self, tokenizer: AnyTokenizer, log_stats: bool, stream_interval: int = 1
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        self.stream_interval = stream_interval
        # In-flight requests, keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        # Parent requests registered by add_request, keyed by parent id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates(log_stats)
        # Tracing is disabled until a tracer is assigned.
        self.tracer: Tracer | None = None

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests that still have state here."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if at least one request still has state here."""
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""

        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """
        Drop the state of the given requests and emit a final abort output.

        An id that names a parent request aborts all of its children first
        and then removes the parent. Unknown ids are ignored.

        Returns:
            The ids of the requests whose state was actually removed
            (children included when a parent id was given).
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.request_finished(request_id, req_state.lora_name)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Pass a non-None (empty) pooling_output for requests
                        # without a detokenizer so that the pooling branch
                        # is taken. An empty placeholder is enough; no
                        # random values are needed.
                        pooling_output=torch.empty(0, device="cpu")
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        """
        Register a new request and create its RequestState.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
            stream_interval=self.stream_interval,
        )
        self.request_states[request_id] = req_state
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            # Must stay after the stats update above, which reads the flag.
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
            ):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        """Forward scheduler stats to the LoRA request state tracker."""
        self.lora_states.update_scheduler_stats(scheduler_stats)

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        """
        Record one "llm_request" span for a finished request.

        The span starts at the request's arrival time and carries latency,
        token-usage and request-parameter attributes. Requires per-request
        stats, iteration stats and a tracer to be present.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        prompt_length = length_from_prompt_token_ids_or_embeds(
            req_state.prompt_token_ids, req_state.prompt_embeds
        )
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
            # Compare against None rather than relying on truthiness, so a
            # legitimate zero (e.g. temperature=0.0) is still recorded.
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
                )
            if req_state.temperature is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
                )
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)

    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
    ):
        """Fold one engine-core output into the iteration stats.

        Does nothing when ``iteration_stats`` is None.
        """
        if iteration_stats is None:
            return

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            self.lora_states,
            req_state.lora_name,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
    ):
        """Record a finished request in the iteration, LoRA and parent stats.

        Does nothing when ``iteration_stats`` is None.
        """
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
                req_state.prompt_token_ids, req_state.prompt_embeds
            ),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
        )
        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )