output_processor.py 22.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, cast
8

9
10
import torch

11
12
13
14
15
16
from vllm.outputs import (
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
17
from vllm.sampling_params import RequestOutputKind
18
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
19
from vllm.transformers_utils.tokenizer import AnyTokenizer
20
from vllm.utils import length_from_prompt_token_ids_or_embeds
21
22
23
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
24
from vllm.v1.engine.parallel_sampling import ParentRequest
25
from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats
26
27


28
29
30
31
32
33
34
35
36
37
38
class RequestOutputCollector:
    """
    Single-slot, per-request mailbox that hands RequestOutputs from the
    output processor to the asyncio task consuming them in generate().

    When streaming deltas and the producer gets ahead of the consumer,
    newly arriving RequestOutputs are merged into the pending one instead
    of overwriting it.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Only DELTA streaming aggregates pending outputs when merging.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        # The pending item waiting for the consumer, or None when empty.
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        # Set when an item is stored; cleared when the consumer takes it.
        self.ready = asyncio.Event()

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Store ``output`` without blocking.

        - Empty slot, or ``output`` is an Exception: store it and wake the
          consumer (an Exception always replaces whatever is pending).
        - Pending RequestOutput + new RequestOutput: merge via ``add`` so
          outputs with different request indexes (n > 1) are all kept.
        - Pending PoolingRequestOutput + new PoolingRequestOutput: the new
          one replaces the pending one.
        - Any other combination is dropped.
        """
        current = self.output
        if current is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(current, RequestOutput) and isinstance(output, RequestOutput):
            current.add(output, aggregate=self.aggregate)
        elif isinstance(current, PoolingRequestOutput) and isinstance(
            output, PoolingRequestOutput
        ):
            self.output = output

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Wait for a pending item, remove it from the slot and return it.

        Raises the pending item instead of returning it when it is an
        Exception.
        """
        while self.output is None:
            await self.ready.wait()
        item, self.output = self.output, None
        self.ready.clear()
        if isinstance(item, Exception):
            raise item
        return item

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Take the pending item without waiting; return None if empty.

        Raises the pending item instead of returning it when it is an
        Exception.
        """
        item = self.output
        if item is None:
            return None
        self.output = None
        self.ready.clear()
        if isinstance(item, Exception):
            raise item
        return item


79
80
@dataclass
class OutputProcessorOutput:
    """Result of one ``OutputProcessor.process_outputs()`` call."""

    # Outputs for requests that have no output queue (the LLMEngine path);
    # queued requests receive theirs through their RequestOutputCollector.
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # Ids of requests finished by the output processor (e.g. a detected
    # stop string) that EngineCore has not finished and must now abort.
    reqs_to_abort: list[str]
83
84
85
86
87
88


class RequestState:
    """Per-request state tracked by OutputProcessor.

    Holds the detokenizer, logprobs processor, optional output queue,
    sampling metadata and stats for one request. It also builds the
    RequestOutput / PoolingRequestOutput objects emitted for that request.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_name: str | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        # Index of this request among its parent's children (CompletionOutput.index).
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        # Prompt length, taken from the token ids or the embeddings.
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        # Both are None for pooling requests (see from_new_request).
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        # Sampling parameters recorded as tracing span attributes.
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        # True until the first EngineCoreOutput for this request is processed.
        self.is_prefilling = True
        # AsyncLLM output queue; None when outputs are returned directly.
        self.queue = queue
        self.num_cached_tokens = 0

        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
    ) -> "RequestState":
        """Build a RequestState from an incoming EngineCoreRequest.

        Generation requests (with sampling_params) get a logprobs processor
        and a detokenizer. The tokenizer is dropped when
        ``detokenize`` is False. Pooling requests get neither and take
        their output kind from ``pooling_params``.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(
                request.lora_request.name if request.lora_request is not None else None
            ),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        """Build the output to emit for this step, or None if nothing is due.

        Returns None when the request is unfinished and ``output_kind`` is
        FINAL_ONLY. It also returns None when the parent request's
        ``get_outputs`` yields no outputs yet. A non-None ``pooling_output``
        produces a PoolingRequestOutput. Otherwise a RequestOutput is
        built from the detokenizer and logprobs state.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)], finished
            )

        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            # The parent may replace the request id / finished flag and may
            # hold outputs back until it has something to emit.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output
            )
            if not outputs:
                return None

        return self._new_request_output(
            request_id, outputs, finished, kv_transfer_params
        )

    def _new_request_output(
        self,
        request_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        """Wrap ``outputs`` (non-empty) in a RequestOutput or PoolingRequestOutput."""
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                num_cached_tokens=self.num_cached_tokens,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step.

        In DELTA mode, text, token ids and logprobs cover only this step's
        new tokens. In other modes, they cover everything generated so far.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            # Keep only the entries for this step's new tokens. Compute the
            # start index explicitly: `logprobs[-len(token_ids):]` would be
            # `logprobs[-0:]`, i.e. the WHOLE list, when token_ids is empty
            # (e.g. the final abort output), re-sending all prior logprobs.
            logprobs = logprobs[max(0, len(logprobs) - len(token_ids)) :]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a raw pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

304
305
306
307

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Tracked (not yet finished or aborted) requests, keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        # Parent requests of child requests, keyed by parent request id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()
        # Tracing is skipped while this is None; presumably assigned by the
        # owning engine when tracing is enabled (not visible here).
        self.tracer: Tracer | None = None

    def get_num_unfinished_requests(self) -> int:
        """Return the number of tracked (unfinished) requests."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if any request is still tracked."""
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""

        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Abort the given requests and return the ids to abort in EngineCore.

        A tracked request id is removed from tracking. If the request has an
        output queue, it also receives a final output with finish_reason
        ABORT. An id that names a parent request aborts all of that parent's
        children (recursively) and then drops the parent. Unknown ids are
        ignored.
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Pooling requests have no detokenizer. Pass a
                        # non-None (empty) tensor so that make_request_output
                        # takes its pooling branch.
                        pooling_output=torch.empty(0, device="cpu")
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
        )
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
            ):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)
        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Emit an "llm_request" server span for a finished request.

        The span starts at the request's arrival time. It carries latency
        breakdowns, token counts and the sampling parameters that were set.
        Requires stats logging, ``iteration_stats`` and a tracer.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, req_state.prompt_len
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
            # Compare against None, not truthiness. temperature == 0.0
            # (greedy sampling) is a valid value and must still be recorded.
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
                )
            if req_state.temperature is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
                )
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)

    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
    ):
        """Fold one EngineCoreOutput into the iteration stats (no-op if disabled)."""
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            lora_stats,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
    ):
        """Record a finished request in the iteration, LoRA and parent stats.

        This is a no-op when stats are disabled.
        """
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=req_state.prompt_len,
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
        )
        self.lora_states.finish_request(req_state)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )