# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, cast
8

9
10
import torch

11
12
13
14
15
16
from vllm.outputs import (
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
17
from vllm.sampling_params import RequestOutputKind
18
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
19
from vllm.transformers_utils.tokenizer import AnyTokenizer
20
from vllm.utils import length_from_prompt_token_ids_or_embeds
21
22
23
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
24
from vllm.v1.engine.parallel_sampling import ParentRequest
25
from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats
26
27


28
29
30
31
32
33
34
35
36
37
38
class RequestOutputCollector:
    """
    Single-slot mailbox that carries the streamed outputs of one request
    from the producer to the asyncio generate task consuming them.

    If the producer publishes again before the consumer has picked up the
    pending item, the new output is folded into the pending one through its
    ``add()`` method (with ``aggregate=True`` when streaming deltas).
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Pending item, or None when there is nothing to hand over.
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        # Set while an item is pending; get() waits on it.
        self.ready = asyncio.Event()
        # Merged outputs are aggregated only for DELTA streaming.
        self.aggregate = output_kind == RequestOutputKind.DELTA

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Publish an output without blocking.

        An exception always replaces whatever is pending. A regular output
        fills an empty slot or is merged into a pending regular output; it
        is ignored when an exception is already pending.
        """
        pending = self.output
        if isinstance(output, Exception) or pending is None:
            self.output = output
            self.ready.set()
            return

        if isinstance(pending, (RequestOutput, PoolingRequestOutput)):
            # Merging instead of overwriting keeps request outputs with
            # different request indexes (if n > 1) from overriding each
            # other.
            pending.add(output, aggregate=self.aggregate)

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Wait for a pending item, take it out of the slot and return it.

        Raises:
            Exception: the pending item itself, when it is an exception.
        """
        item = self.output
        while item is None:
            await self.ready.wait()
            item = self.output

        # Empty the slot before handing the item over.
        self.output = None
        self.ready.clear()

        if isinstance(item, Exception):
            raise item
        return item

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Take the pending item without waiting; None if the slot is empty.

        Raises:
            Exception: the pending item itself, when it is an exception.
        """
        item = self.output
        if item is None:
            return None

        self.output = None
        self.ready.clear()

        if isinstance(item, Exception):
            raise item
        return item


73
74
@dataclass
class OutputProcessorOutput:
    """Result bundle returned by ``OutputProcessor.process_outputs``."""

    # Outputs of requests that have no per-request queue; requests with a
    # queue receive their outputs through that queue instead.
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # IDs of requests that finished during output processing although the
    # engine core had not marked them finished, so they still need an abort.
    reqs_to_abort: list[str]
77
78
79
80
81
82


class RequestState:
    """Per-request state used to build ``RequestOutput`` /
    ``PoolingRequestOutput`` objects from engine-core outputs."""

    def __init__(
        self,
        request_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_name: str | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
    ):
        """Store the request's identity, prompt, helpers and sampling values.

        ``logprobs_processor`` and ``detokenizer`` are None for pooling
        requests (see ``from_new_request``). Per-request stats are only
        created when ``log_stats`` is true.
        """
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        # True until the first engine-core output for this request has been
        # processed.
        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
    ) -> "RequestState":
        """Build the state for a newly added request.

        Sampling requests get a logprobs processor and a detokenizer; the
        tokenizer is withheld from both when ``detokenize`` is disabled.
        Pooling requests get neither and must carry ``pooling_params``.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(
                request.lora_request.name if request.lora_request is not None else None
            ),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        """Build the output for the current step, or None if nothing is due.

        Returns None when the request is unfinished in FINAL_ONLY mode, or
        when the parent request yields no outputs for this step. A non-None
        ``pooling_output`` selects the pooling path.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)], finished
            )

        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            # The parent decides which request id, outputs and finished flag
            # are reported for this child output.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output
            )
            if not outputs:
                return None

        return self._new_request_output(
            request_id, outputs, finished, kv_transfer_params
        )

    def _new_request_output(
        self,
        request_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        """Wrap ``outputs`` in a request-level output object.

        ``outputs`` must be non-empty; the type of its first element selects
        between ``PoolingRequestOutput`` and ``RequestOutput``.
        """
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                num_cached_tokens=self.num_cached_tokens,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
    ) -> CompletionOutput:
        """Build the ``CompletionOutput`` for this step.

        In DELTA mode the output carries only ``token_ids`` and the matching
        tail of the logprobs; otherwise it carries the detokenizer's full
        ``output_token_ids`` and all logprobs held by the processor.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            # Keep only the entries that belong to this delta's tokens.
            # ``logprobs[-len(token_ids):]`` alone is wrong for an empty
            # delta (e.g. the final output of an aborted request): ``-0`` is
            # ``0``, so it would return every entry instead of none.
            num_new_tokens = len(token_ids)
            logprobs = logprobs[-num_new_tokens:] if num_new_tokens else logprobs[:0]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a pooling tensor in a ``PoolingOutput``."""
        return PoolingOutput(data=pooling_output)

298
299
300
301

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Live (unfinished, non-aborted) requests keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        # Parent requests keyed by the parent's request id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()
        # Tracing is skipped while this is None.
        self.tracer: Tracer | None = None

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests currently tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True while at least one request is still tracked."""
        return bool(self.request_states)

    def propagate_error(self, e: Exception) -> None:
        """Propagate error to all generate() tasks."""

        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Stop tracking the given requests and emit their final outputs.

        An id that names a parent request aborts all of its children and
        then removes the parent. Unknown ids are ignored.

        Returns:
            The ids of the requests whose state was actually removed,
            including children aborted through their parent.
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Set pooling_output is not None to
                        # correctly enter the abort pooling branch.
                        # The tensor is only a placeholder, so an empty
                        # one is enough.
                        pooling_output=torch.empty(0, device="cpu")
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    # Copy first: the collection is iterated while the
                    # children are being aborted.
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
        )
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            # Must stay after the stats update above, which reads the flag.
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
            ):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)
        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Emit an ``llm_request`` span for a finished request.

        Requires a tracer, per-request stats and iteration stats to be
        available. The span starts at the request's arrival time and carries
        latency, token-usage and sampling-parameter attributes.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        prompt_length = length_from_prompt_token_ids_or_embeds(
            req_state.prompt_token_ids, req_state.prompt_embeds
        )
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
            # Compare against None rather than relying on truthiness so that
            # a legitimate zero (notably temperature == 0.0) is still
            # recorded on the span.
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
                )
            if req_state.temperature is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
                )
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)

    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Fold one engine-core output into the iteration stats.

        Does nothing when ``iteration_stats`` is None.
        """
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            lora_stats,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Record a finished request in the iteration, LoRA and parent stats.

        Does nothing when ``iteration_stats`` is None.
        """
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
                req_state.prompt_token_ids, req_state.prompt_embeds
            ),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
        )
        self.lora_states.finish_request(req_state)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )