output_processor.py 24.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, cast
8

9
10
import torch

11
12
13
14
15
16
from vllm.outputs import (
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
17
from vllm.sampling_params import RequestOutputKind
18
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
19
from vllm.transformers_utils.tokenizer import AnyTokenizer
20
from vllm.utils import length_from_prompt_token_ids_or_embeds
21
22
23
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
24
from vllm.v1.engine.parallel_sampling import ParentRequest
25
26
27
28
29
30
from vllm.v1.metrics.stats import (
    IterationStats,
    LoRARequestStates,
    RequestStateStats,
    SchedulerStats,
)
31
32


33
34
35
36
37
38
39
40
41
42
43
class RequestOutputCollector:
    """
    Single-slot mailbox between the output processor (producer) and the
    asyncio task that consumes a request's outputs.

    At most one item is pending at a time. If the producer runs ahead of
    the consumer:
      * a pending RequestOutput is combined with the newer one through
        ``RequestOutput.add`` (aggregating only in DELTA mode);
      * a pending PoolingRequestOutput is replaced by the newer one;
      * an Exception always overwrites whatever is pending.
    """

    def __init__(self, output_kind: RequestOutputKind):
        self.ready = asyncio.Event()
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        # Deltas must be accumulated rather than overwritten when merged.
        self.aggregate = RequestOutputKind.DELTA == output_kind

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Hand ``output`` to the consumer without blocking."""
        pending = self.output
        if pending is None or isinstance(output, Exception):
            # Empty slot, or an error that must pre-empt any pending data.
            self.output = output
            self.ready.set()
        elif isinstance(output, RequestOutput) and isinstance(pending, RequestOutput):
            # Merge so that outputs with different request indexes
            # (if n > 1) do not override each other.
            pending.add(output, aggregate=self.aggregate)
        elif isinstance(output, PoolingRequestOutput) and isinstance(
            pending, PoolingRequestOutput
        ):
            self.output = output

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Wait until an item is available, take it, and raise if it is an error."""
        while True:
            item = self.output
            if item is not None:
                break
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(item, Exception):
            raise item
        return item

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Take the pending item if there is one; return None otherwise.

        Raises the pending item if it is an Exception.
        """
        item = self.output
        if item is None:
            return None
        self.output = None
        self.ready.clear()
        if isinstance(item, Exception):
            raise item
        return item


84
85
@dataclass
class OutputProcessorOutput:
    """Result of a single ``OutputProcessor.process_outputs`` call."""

    # Outputs returned directly to the caller; only requests without a
    # per-request queue (the LLMEngine path) land here.
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # Requests finished by the output processor (e.g. on a detected stop
    # string) that EngineCore has not finished yet and must abort.
    reqs_to_abort: list[str]
88
89
90
91
92
93


class RequestState:
    """Mutable per-request state used to build request outputs.

    Holds the detokenizer and logprobs processor for generation requests
    (both are None for pooling requests), stream-interval bookkeeping,
    optional per-request stats, and sampling parameters used for tracing.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_name: str | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        # Sampling parameters; in this file they are only read for tracing.
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

        # Stream interval: emit at most one output per `stream_interval` new
        # tokens. The first and the final outputs are always emitted.
        self.stream_interval = stream_interval
        # Number of output tokens covered by the last emitted output.
        self.sent_tokens_offset = 0

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
    ) -> "RequestState":
        """Create the state for a newly added EngineCoreRequest.

        Requests with ``sampling_params`` (generation) get a logprobs
        processor and an incremental detokenizer. The tokenizer is dropped
        when ``detokenize`` is disabled. Pooling requests get neither and
        take their output kind from ``pooling_params``.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(
                request.lora_request.name if request.lora_request is not None else None
            ),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
            stream_interval=stream_interval,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        """Build the output to emit for this step, or None to emit nothing.

        Returns None when:
          * the request is unfinished and the output kind is FINAL_ONLY;
          * stream-interval throttling holds this step back; or
          * the parent request (parallel sampling) has nothing to emit yet.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        # Throttling only applies to generation outputs. Pooling requests have
        # no detokenizer, so asserting on it for them would crash any pooling
        # request (including the abort path) whenever stream_interval > 1.
        if self.stream_interval > 1 and pooling_output is None:
            assert self.detokenizer is not None
            num_output_tokens = len(self.detokenizer.output_token_ids)

            # Send output request only when
            # 1. It has finished, or
            # 2. It is the first token, or
            # 3. It has reached the stream interval number of tokens
            if not (
                finished
                or self.sent_tokens_offset == 0
                or num_output_tokens - self.sent_tokens_offset >= self.stream_interval
            ):
                return None

            if self.output_kind == RequestOutputKind.DELTA:
                # Send tokens from the offset in DELTA mode; in other modes
                # _new_completion_output sends all tokens anyway.
                new_token_ids = self.detokenizer.output_token_ids[
                    self.sent_tokens_offset :
                ]
            # Advance the offset in every mode. If only DELTA advanced it,
            # CUMULATIVE would keep the offset at 0, so every step would look
            # like the "first token" and throttling would never engage.
            self.sent_tokens_offset = num_output_tokens

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)], finished
            )

        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output
            )
            if not outputs:
                return None

        return self._new_request_output(
            request_id, outputs, finished, kv_transfer_params
        )

    def _new_request_output(
        self,
        request_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        """Wrap completion/pooling outputs in the matching request-output type."""
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                num_cached_tokens=self.num_cached_tokens,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step.

        In DELTA mode the text, token ids and logprobs cover only this step.
        Otherwise they cover everything generated so far.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            # Keep only the entries for this delta's tokens. An empty delta
            # (e.g. the abort output) must yield no logprobs:
            # logprobs[-0:] would return the whole cumulative list.
            logprobs = logprobs[-len(token_ids) :] if token_ids else logprobs[:0]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a raw pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

339
340
341
342

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self, tokenizer: AnyTokenizer, log_stats: bool, stream_interval: int = 1
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        self.stream_interval = stream_interval
        self.request_states: dict[str, RequestState] = {}
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates(log_stats)
        self.tracer: Tracer | None = None
        # Set while no request is tracked; cleared when one is added and set
        # again once the last tracked request finishes or is aborted.
        self._requests_drained = asyncio.Event()
        self._requests_drained.set()

    def get_num_unfinished_requests(self) -> int:
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        return len(self.request_states) > 0

    async def wait_for_requests_to_drain(self) -> None:
        """Wait until every tracked request has finished or been aborted."""
        if not self.request_states:
            return
        await self._requests_drained.wait()

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""
        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Stop tracking the given requests and return the ids to abort.

        A tracked request is removed, and when it has a queue a final ABORT
        output is put on that queue. An id that names a parent request
        aborts its children recursively, then drops the parent.
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.request_finished(request_id, req_state.lora_name)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Set pooling_output is not None to correctly enter
                        # the abort pooling branch. An empty tensor is
                        # enough; unlike randn it does not touch the RNG.
                        pooling_output=(
                            torch.empty(0, device="cpu")
                            if req_state.detokenizer is None
                            else None
                        ),
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        if not self.request_states:
            self._requests_drained.set()
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
            stream_interval=self.stream_interval,
        )
        if self._requests_drained.is_set():
            self._requests_drained.clear()
        self.request_states[request_id] = req_state
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration. This reads is_prefilling,
            # so it must run before is_prefilling is cleared below.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
            ):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not self.request_states:
                    self._requests_drained.set()
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        self.lora_states.update_scheduler_stats(scheduler_stats)

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Emit an "llm_request" span with latency, usage and request metadata.

        Requires stats logging (req_state.stats and iteration_stats) and a tracer.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, req_state.prompt_len
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            # Use explicit None checks: a truthiness test would silently drop
            # legitimate zero values such as temperature == 0.0 (greedy).
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
                )
            if req_state.temperature is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
                )
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)

    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
    ):
        """Fold one EngineCoreOutput into the iteration stats (if enabled)."""
        if iteration_stats is None:
            return

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            self.lora_states,
            req_state.lora_name,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
    ):
        """Record a finished request in the iteration and LoRA stats (if enabled)."""
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=req_state.prompt_len,
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
        )
        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )