output_processor.py 22.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, cast
8

9
10
import torch

11
12
13
14
15
16
from vllm.outputs import (
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
17
from vllm.sampling_params import RequestOutputKind
18
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
19
from vllm.transformers_utils.tokenizer import AnyTokenizer
20
from vllm.utils import length_from_prompt_token_ids_or_embeds
21
22
23
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
24
from vllm.v1.engine.parallel_sampling import ParentRequest
25
26
27
28
29
30
from vllm.v1.metrics.stats import (
    IterationStats,
    LoRARequestStates,
    RequestStateStats,
    SchedulerStats,
)
31
32


33
34
35
36
37
38
39
40
41
42
43
class RequestOutputCollector:
    """
    Single-slot mailbox for the outputs of one request.

    A producer deposits outputs with `put()`; the consuming asyncio
    generate task retrieves them with `get()` / `get_nowait()`.

    If the producer runs ahead of the consumer, a pending RequestOutput
    absorbs the newer one via `RequestOutput.add()` (with `aggregate`
    enabled in DELTA mode), while a pending PoolingRequestOutput is
    simply replaced by the newer one.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Only streamed deltas are aggregated when merging outputs.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        # The pending item, or None when nothing is waiting.
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        # Set whenever an item has been stored into an empty slot.
        self.ready = asyncio.Event()

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Store `output` without blocking, merging with any pending item."""
        pending = self.output

        # Empty slot, or an exception (which always overrides what is pending).
        if pending is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
            return

        if isinstance(pending, RequestOutput) and isinstance(output, RequestOutput):
            # Merging (rather than replacing) keeps outputs with different
            # request indexes (n > 1) from overriding each other.
            pending.add(output, aggregate=self.aggregate)
            return

        both_pooling = isinstance(pending, PoolingRequestOutput) and isinstance(
            output, PoolingRequestOutput
        )
        if both_pooling:
            self.output = output

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Wait until an item is available, then take it.

        Raises the stored exception if the pending item is an Exception.
        """
        while True:
            pending = self.output
            if pending is not None:
                break
            await self.ready.wait()

        # Empty the slot before handing the item over.
        self.output = None
        self.ready.clear()

        if isinstance(pending, Exception):
            raise pending
        return pending

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Take the pending item if there is one, else return None.

        Raises the stored exception if the pending item is an Exception.
        """
        pending = self.output
        if pending is None:
            return None

        self.output = None
        self.ready.clear()

        if isinstance(pending, Exception):
            raise pending
        return pending


84
85
@dataclass
class OutputProcessorOutput:
    """Result bundle returned by `OutputProcessor.process_outputs()`."""

    # Outputs returned directly to the caller; filled only for requests
    # that have no per-request queue.
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # Ids of requests that finished here (stop string detected) while the
    # engine core had not finished them, so they still need to be aborted.
    reqs_to_abort: list[str]
88
89
90
91
92
93


class RequestState:
    """Per-request state kept while a request is in flight.

    Holds the prompt, the detokenizer and logprobs processor (both None for
    pooling requests), the sampling parameters recorded for tracing, and the
    optional collector used to hand outputs over to a consumer.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_name: str | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        # Prompt length derived from whichever prompt representation is set.
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        # Cleared once the first engine-core output has been processed.
        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        # Per-request stats are only tracked when stats logging is enabled.
        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
    ) -> "RequestState":
        """Build the state for a newly added request.

        Sampling requests get a logprobs processor and a detokenizer;
        requests without sampling params must carry pooling params and get
        neither.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                # Detokenization disabled: build the helpers without a tokenizer.
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(
                request.lora_request.name if request.lora_request is not None else None
            ),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        """Create the output object for the current step, if one is due.

        Returns None when the request is unfinished in FINAL_ONLY mode, or
        when the parent request yields no outputs for this step.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)], finished
            )

        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            # The parent decides the request id, outputs and finished flag.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output
            )
            if not outputs:
                return None

        return self._new_request_output(
            request_id, outputs, finished, kv_transfer_params
        )

    def _new_request_output(
        self,
        request_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        """Wrap `outputs` into a PoolingRequestOutput or a RequestOutput.

        The type of the first element of `outputs` selects the result type.
        """
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                num_cached_tokens=self.num_cached_tokens,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

    @staticmethod
    def _latest_logprobs(logprobs, num_new_tokens: int):
        """Return the trailing `num_new_tokens` entries of `logprobs`.

        A plain `logprobs[-num_new_tokens:]` is wrong for zero new tokens:
        `-0 == 0`, so the slice would return the whole sequence instead of
        an empty one.
        """
        if num_new_tokens <= 0:
            return logprobs[:0]
        return logprobs[-num_new_tokens:]

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step.

        In DELTA mode the text, token ids and logprobs cover only the new
        tokens; otherwise they cover everything generated so far.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            # Keep only the entries for the new tokens; none when there are
            # no new tokens (e.g. the final output of an aborted request).
            logprobs = self._latest_logprobs(logprobs, len(token_ids))

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a raw pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

309
310
311
312

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # In-flight requests, keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        # Parent requests, keyed by the parent's request id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates(log_stats)
        # Tracing only happens once a tracer has been assigned.
        self.tracer: Tracer | None = None

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests currently tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if at least one request is still tracked."""
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""

        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Stop tracking the given requests and emit their final outputs.

        An id that names a parent request aborts all of its children first.

        Returns:
            The ids of the requests that were actually being tracked and
            have been removed (children included).
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.request_finished(request_id, req_state.lora_name)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Set pooling_output is not None to
                        # correctly enter the abort pooling branch
                        pooling_output=torch.randn(0, device="cpu")
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
        )
        self.request_states[request_id] = req_state
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
            ):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        """Forward the scheduler stats to the LoRA request states."""
        self.lora_states.update_scheduler_stats(scheduler_stats)

    @staticmethod
    def _set_optional_span_attribute(span: Any, key: str, value: Any) -> None:
        """Set `key` on `span` unless `value` is None.

        The check is `is not None` rather than truthiness so that valid
        falsy values (e.g. a temperature of 0.0) are still recorded.
        """
        if value is not None:
            span.set_attribute(key, value)

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Emit an "llm_request" span describing a finished request.

        Requires request stats, iteration stats and a tracer to be present.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        # The span start time is given in nanoseconds.
        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        prompt_length = length_from_prompt_token_ids_or_embeds(
            req_state.prompt_token_ids, req_state.prompt_embeds
        )
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
            # Sampling parameters are None for requests without sampling params.
            self._set_optional_span_attribute(
                span, SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p
            )
            self._set_optional_span_attribute(
                span, SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
            )
            self._set_optional_span_attribute(
                span, SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
            )
            self._set_optional_span_attribute(
                span, SpanAttributes.GEN_AI_REQUEST_N, req_state.n
            )

    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
    ):
        """Fold one engine-core output into the iteration stats.

        Does nothing when no iteration stats are being collected.
        """
        if iteration_stats is None:
            return

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            self.lora_states,
            req_state.lora_name,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
    ):
        """Record a finished request in the iteration, LoRA and parent stats.

        Does nothing when no iteration stats are being collected.
        """
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
                req_state.prompt_token_ids, req_state.prompt_embeds
            ),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
        )
        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )