# output_processor.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, cast
8

9
10
import torch

11
from vllm.lora.request import LoRARequest
12
13
14
15
16
17
from vllm.outputs import (
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
18
from vllm.sampling_params import RequestOutputKind
19
from vllm.tokenizers import TokenizerLike
20
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
21
from vllm.utils import length_from_prompt_token_ids_or_embeds
22
23
24
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
25
from vllm.v1.engine.parallel_sampling import ParentRequest
26
27
28
29
30
31
from vllm.v1.metrics.stats import (
    IterationStats,
    LoRARequestStates,
    RequestStateStats,
    SchedulerStats,
)
32
33


34
35
36
37
38
39
40
41
42
43
44
class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Successive RequestOutputs are added with aggregate=True only in
        # DELTA mode (see put()).
        self.aggregate = output_kind == RequestOutputKind.DELTA
        # The single pending item: an output waiting to be consumed, an
        # exception to re-raise in the consumer, or None when empty.
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        # Set while an item is pending; get() waits on it.
        self.ready = asyncio.Event()

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Non-blocking put operation.

        - If nothing is pending, or ``output`` is an exception, ``output``
          becomes the pending item and waiters are woken.
        - A RequestOutput is folded into a pending RequestOutput via
          ``RequestOutput.add``.
        - A PoolingRequestOutput replaces a pending PoolingRequestOutput.
        - Any other combination (e.g. a new output while an exception is
          pending) is dropped, so a pending exception is never lost.
        """
        if self.output is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(self.output, RequestOutput) and isinstance(
            output, RequestOutput
        ):
            # This ensures that request outputs with different request indexes
            # (if n > 1) do not override each other.
            self.output.add(output, aggregate=self.aggregate)
        elif isinstance(self.output, PoolingRequestOutput) and isinstance(
            output, PoolingRequestOutput
        ):
            self.output = output

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Get operation blocks on put event.

        Waits until an item is pending, then takes it and clears the
        collector.

        Raises:
            Exception: the pending item, if an exception was put.
        """
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Non-blocking get operation.

        Returns the pending output (clearing the collector), or None if
        nothing is pending.

        Raises:
            Exception: the pending item, if an exception was put.
        """
        output = self.output
        if output is not None:
            self.output = None
            self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output


85
86
@dataclass
class OutputProcessorOutput:
    """Result of ``OutputProcessor.process_outputs``."""

    # Outputs for requests without a collector queue (the LLMEngine path);
    # queued (AsyncLLM) requests receive their outputs via their queue instead.
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # Ids of requests finished here (e.g. a stop string detected by the
    # detokenizer) that EngineCore has not finished and must abort.
    reqs_to_abort: list[str]
89
90
91
92
93
94


class RequestState:
    """Per-request bookkeeping used by the output processor.

    Holds the prompt, the incremental detokenizer and logprobs processor
    (generation requests only), sampling metadata kept for tracing,
    optional stats, and the collector queue (AsyncLLM only). Each engine
    step is turned into a RequestOutput / PoolingRequestOutput by
    ``make_request_output``.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_request: LoRARequest | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_request = lora_request
        self.lora_name = lora_request.lora_name if lora_request is not None else None
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        # Sampling parameters kept only so they can be attached to traces.
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

        # Stream interval: emit generation outputs only every
        # `stream_interval` tokens (plus the first and final outputs).
        self.stream_interval = stream_interval
        # Number of output tokens already covered by emitted outputs.
        self.sent_tokens_offset = 0

    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
    ) -> "RequestState":
        """Build a RequestState from an EngineCoreRequest.

        Generation requests (``sampling_params`` set) get a logprobs
        processor and an incremental detokenizer; the tokenizer is dropped
        when ``detokenize`` is disabled. Pooling requests have neither and
        must carry ``pooling_params``.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_request=request.lora_request,
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
            stream_interval=stream_interval,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        """Build the output for this step, or None if nothing is emitted.

        Returns None when:
        - the request is unfinished in FINAL_ONLY mode;
        - stream-interval batching holds the output back;
        - a parallel-sampling parent has nothing to emit yet.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        # Stream-interval batching only applies to generation requests.
        # Pooling requests have no detokenizer, so they must skip this block
        # (previously the assert below fired for them).
        if pooling_output is None and self.stream_interval > 1:
            assert self.detokenizer is not None
            output_token_ids = self.detokenizer.output_token_ids
            num_output_tokens = len(output_token_ids)

            # Send output request only when
            # 1. It has finished, or
            # 2. It is the first token, or
            # 3. It has reached the stream interval number of tokens
            if not (
                finished
                or self.sent_tokens_offset == 0
                or num_output_tokens - self.sent_tokens_offset
                >= self.stream_interval
            ):
                return None

            if self.output_kind == RequestOutputKind.DELTA:
                # Send all tokens accumulated since the last emitted output;
                # other modes send the full token list anyway.
                new_token_ids = output_token_ids[self.sent_tokens_offset :]
            # Advance the offset in every mode. Otherwise condition 2 stays
            # true forever outside DELTA mode and nothing is throttled.
            self.sent_tokens_offset = num_output_tokens

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)], finished
            )

        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            # Parallel sampling: the parent decides what (if anything) to
            # emit and under which request id.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output
            )
            if not outputs:
                return None

        return self._new_request_output(
            request_id, outputs, finished, kv_transfer_params
        )

    def _new_request_output(
        self,
        request_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        """Wrap completion/pooling outputs in the matching request output."""
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                num_cached_tokens=self.num_cached_tokens,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=request_id,
            lora_request=self.lora_request,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
    ) -> CompletionOutput:
        """Build a CompletionOutput for this step.

        In DELTA mode, ``token_ids`` are the tokens being sent now and
        logprobs are trimmed to match. Otherwise the full output token list
        is used.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            # Keep only the logprobs of the tokens sent now. Use an explicit
            # start index: `logprobs[-0:]` would return the whole history
            # when no new tokens are sent (e.g. the abort path passes []).
            start = max(len(logprobs) - len(token_ids), 0)
            logprobs = logprobs[start:]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

340
341
342
343

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: TokenizerLike | None,
        log_stats: bool,
        stream_interval: int = 1,
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        self.stream_interval = stream_interval
        self.request_states: dict[str, RequestState] = {}
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates(log_stats)
        self.tracer: Tracer | None = None
        # Set whenever no request is in flight. It is cleared in add_request
        # and set again once request_states becomes empty.
        self._requests_drained = asyncio.Event()
        self._requests_drained.set()

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests still tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if any request is still tracked."""
        return len(self.request_states) > 0

    async def wait_for_requests_to_drain(self) -> None:
        """Wait until all tracked requests have finished or been aborted."""
        if not self.request_states:
            return
        await self._requests_drained.wait()

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""
        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Abort requests and return the ids that must be aborted in EngineCore.

        Each tracked request is removed, its LoRA state released, and a final
        ABORT output is put on its queue (if it has one). A parallel-sampling
        parent id aborts all its children recursively; the children's ids are
        included in the result.
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.request_finished(request_id, req_state.lora_name)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # A non-None pooling_output routes pooling requests
                        # (no detokenizer) into the pooling branch. Its
                        # contents are irrelevant, so an empty tensor is used.
                        pooling_output=torch.empty(0, device="cpu")
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent. Copy the
                # collection first: aborting a child may modify it.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        if not self.request_states:
            self._requests_drained.set()
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
            stream_interval=self.stream_interval,
        )
        if self._requests_drained.is_set():
            self._requests_drained.clear()
        self.request_states[request_id] = req_state
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
            ):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not self.request_states:
                    self._requests_drained.set()
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        """Forward scheduler stats to the LoRA request-state tracker."""
        self.lora_states.update_scheduler_stats(scheduler_stats)

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Emit an ``llm_request`` span for a finished request.

        The span starts at the request's arrival time. It carries latency
        and usage attributes derived from the request stats, plus the
        sampling parameters that were set.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, req_state.prompt_len
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
            # Compare against None rather than truthiness so legitimate zero
            # values (e.g. temperature=0.0 for greedy sampling) are recorded.
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
                )
            if req_state.temperature is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
                )
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)

    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
    ):
        """Record per-step stats for one request (no-op when stats are off)."""
        if iteration_stats is None:
            return

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            self.lora_states,
            req_state.lora_name,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
    ):
        """Record stats for a finished request (no-op when stats are off)."""
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=req_state.prompt_len,
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
            num_cached_tokens=req_state.num_cached_tokens,
        )
        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )