# output_processor.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, cast
8

9
10
import torch

11
12
13
14
15
16
from vllm.outputs import (
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
17
from vllm.sampling_params import RequestOutputKind
18
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
19
from vllm.transformers_utils.tokenizer import AnyTokenizer
20
from vllm.utils import length_from_prompt_token_ids_or_embeds
21
22
23
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
24
from vllm.v1.engine.parallel_sampling import ParentRequest
25
from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats
26
27


28
29
30
31
32
33
34
35
36
37
38
class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.

    At most one pending item is held in ``self.output``; ``self.ready``
    is set while an item is pending and cleared when it is taken.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Passed through to ``add()`` when a new output is merged into a
        # pending one; only true for DELTA output kind.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        self.ready = asyncio.Event()

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Non-blocking put operation.

        * If nothing is pending, ``output`` becomes the pending item.
        * An ``Exception`` always replaces whatever is pending.
        * Otherwise ``output`` is merged into the pending request output.
          If the pending item is an ``Exception``, ``output`` is dropped.
        """
        if self.output is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
            # This ensures that request outputs with different request indexes
            # (if n > 1) do not override each other.
            self.output.add(output, aggregate=self.aggregate)

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Get operation blocks on put event.

        Raises:
            Exception: the pending item itself, if it is an exception.
        """
        while (output := self.output) is None:
            await self.ready.wait()
        # Take ownership of the pending item and reset for the next put().
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Non-blocking get operation.

        Returns the pending item, or ``None`` when nothing is pending.

        Raises:
            Exception: the pending item itself, if it is an exception.
        """
        output = self.output
        if output is not None:
            self.output = None
            self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output


73
74
@dataclass
class OutputProcessorOutput:
    """Value returned by ``OutputProcessor.process_outputs()``."""

    # Request outputs that were not handed to a per-request queue.
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # IDs of requests that finished here (detokenizer stop string) but are
    # not yet finished in EngineCore, so the caller must abort them there.
    reqs_to_abort: list[str]
77
78
79
80
81
82


class RequestState:
    """Output-side state kept for a single request.

    Holds the prompt, the optional detokenizer and logprobs processor
    (both ``None`` for pooling requests), the sampling parameters that are
    reported when tracing, and the optional collector queue that receives
    the request's outputs.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_name: str | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        # Flipped to False by OutputProcessor.process_outputs() once the
        # first engine output for this request has been handled.
        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        # Per-request stats are only tracked when stat logging is enabled.
        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
    ) -> "RequestState":
        """Build the state for a newly added request.

        Sampling requests get a logprobs processor and a detokenizer;
        requests without sampling params must carry pooling params and get
        neither.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                # Detokenization disabled: the helpers below are created
                # without a tokenizer.
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(
                request.lora_request.name if request.lora_request is not None else None
            ),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        """Create the output object for this step, if one should be emitted.

        Returns ``None`` when the request is unfinished in FINAL_ONLY mode,
        or when the parent request yields no outputs for this step.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)], finished
            )

        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            # The parent decides which request id, outputs and finished
            # flag are reported for this child's output.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output
            )
            if not outputs:
                return None

        return self._new_request_output(
            request_id, outputs, finished, kv_transfer_params
        )

    def _new_request_output(
        self,
        request_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        """Wrap ``outputs`` in a PoolingRequestOutput or RequestOutput."""
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

    @staticmethod
    def _tail_for_delta(logprobs, num_new_tokens: int):
        """Return the last ``num_new_tokens`` entries of ``logprobs``.

        A plain ``logprobs[-num_new_tokens:]`` returns the *whole* sequence
        when ``num_new_tokens`` is 0 (because ``-0 == 0``), so the start
        index is computed explicitly. Asking for more entries than exist
        returns everything, as the negative slice did.
        """
        start = max(len(logprobs) - num_new_tokens, 0)
        return logprobs[start:]

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step.

        In DELTA mode the text, token ids and logprobs cover only this
        step; otherwise the token ids are the detokenizer's full output
        token ids and the logprobs are everything accumulated so far.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            # Keep only the entries for this step's tokens. With no new
            # tokens (e.g. the final output produced on abort) this is
            # empty rather than the full history.
            logprobs = self._tail_for_delta(logprobs, len(token_ids))

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

297
298
299
300

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Live requests, keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        # Parent requests that still have live children, keyed by parent id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()
        # Tracing is skipped while this is None.
        self.tracer: Tracer | None = None

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests currently being tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return whether any request is currently being tracked."""
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception) -> None:
        """Propagate error to all generate() tasks."""

        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Stop tracking the given requests and emit their final outputs.

        A request id that is not a live request but matches a parent
        request aborts all of that parent's children instead.

        Returns:
            The ids of the requests that were actually removed from
            ``request_states`` (children included).
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Set pooling_output is not None to
                        # correctly enter the abort pooling branch
                        pooling_output=torch.randn(0, device="cpu")
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
        )
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Compute sample and prompt logprobs, if required
        4) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
            ):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)
        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Record an ``llm_request`` span for a finished request.

        Requires ``self.tracer``, ``req_state.stats`` and
        ``iteration_stats`` to all be set (asserted below).
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        # The span start time is given in nanoseconds; arrival_time is
        # scaled by 1e9 accordingly.
        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        prompt_length = length_from_prompt_token_ids_or_embeds(
            req_state.prompt_token_ids, req_state.prompt_embeds
        )
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
            # Compare against None rather than truthiness: a value of 0
            # (notably temperature == 0.0) is a real setting and must
            # still be recorded on the span.
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
                )
            if req_state.temperature is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
                )
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)

    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Fold one engine output into the iteration stats (no-op without them)."""
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            lora_stats,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Record a finished request in the stats (no-op without iteration stats)."""
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
                req_state.prompt_token_ids, req_state.prompt_embeds
            ),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
        )
        self.lora_states.finish_request(req_state)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )