output_processor.py 23.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, Optional, Union, cast
8

9
10
11
12
import torch

from vllm.outputs import (CompletionOutput, PoolingOutput,
                          PoolingRequestOutput, RequestOutput)
13
from vllm.sampling_params import RequestOutputKind
14
15
from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
                          extract_trace_context)
16
from vllm.transformers_utils.tokenizer import AnyTokenizer
17
from vllm.utils import length_from_prompt_token_ids_or_embeds
18
19
20
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
21
from vllm.v1.engine.parallel_sampling import ParentRequest
22
23
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
                                   RequestStateStats)
24
25


26
27
28
29
30
31
32
33
34
35
36
class RequestOutputCollector:
    """
    Single-slot mailbox holding the streamed output(s) of one request until
    the consuming asyncio generate task picks them up.

    If the producer runs ahead of the consumer, newer outputs are folded
    into the pending one (concatenated as deltas when the output kind is
    DELTA) instead of being queued. A pending Exception always wins: it
    replaces whatever is pending and later non-exception outputs are
    dropped until the consumer reads it.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Deltas must be concatenated when merging; otherwise the newer
        # cumulative output simply supersedes the older one.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        # The pending item (output or error); None means "nothing pending".
        self.output: Optional[Union[RequestOutput, PoolingRequestOutput,
                                    Exception]] = None
        # Set when an item becomes pending, cleared when it is consumed.
        self.ready = asyncio.Event()

    def put(self, output: Union[RequestOutput, PoolingRequestOutput,
                                Exception]) -> None:
        """Non-blocking put operation."""
        pending = self.output
        if pending is not None and not isinstance(output, Exception):
            if isinstance(pending, (RequestOutput, PoolingRequestOutput)):
                # Merge into the pending output so that outputs carrying
                # different request indexes (if n > 1) do not override
                # each other.
                pending.add(output, aggregate=self.aggregate)
            # A pending Exception is kept as-is; the new output is dropped.
            return
        self.output = output
        self.ready.set()

    async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
        """Get operation blocks on put event."""
        while self.output is None:
            await self.ready.wait()
        result = self.output
        self.output = None
        self.ready.clear()
        if isinstance(result, Exception):
            raise result
        return result

    def get_nowait(
            self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Non-blocking get operation."""
        result = self.output
        if result is None:
            return None
        self.output = None
        self.ready.clear()
        if isinstance(result, Exception):
            raise result
        return result

@dataclass
class OutputProcessorOutput:
    """Result of a single ``OutputProcessor.process_outputs()`` call.

    Attributes:
        request_outputs: Outputs handed straight back to the caller. Only
            requests without a ``RequestOutputCollector`` queue end up here;
            queued requests receive their outputs through the queue.
        reqs_to_abort: IDs of requests that the output processor finished
            (e.g. on a detected stop string) while the engine core had not
            reported them finished, so they must be aborted there too.
    """

    # Outputs returned directly (no per-request queue).
    request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
    # Request IDs to abort in the engine core.
    reqs_to_abort: list[str]

class RequestState:
    """Bookkeeping for one in-flight request inside the OutputProcessor.

    Holds the request's detokenizer, logprobs processor, sampling metadata
    and (optional) stats, and converts engine-core results into
    ``RequestOutput`` / ``PoolingRequestOutput`` objects.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: Optional[ParentRequest],
        request_index: int,
        lora_name: Optional[str],
        output_kind: RequestOutputKind,
        prompt: Optional[str],
        prompt_token_ids: Optional[list[int]],
        prompt_embeds: Optional[torch.Tensor],
        logprobs_processor: Optional[LogprobsProcessor],
        detokenizer: Optional[IncrementalDetokenizer],
        max_tokens_param: Optional[int],
        arrival_time: float,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        temperature: Optional[float] = None,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        # Prompt length, from token ids or (when those are absent) embeds.
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds)
        # Both are None for pooling requests.
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        # Sampling params kept only for tracing; None for pooling requests.
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        # Per-request stats exist only when stats logging is enabled.
        self.stats = RequestStateStats(
            arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest],
        request_index: int,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ) -> "RequestState":
        """Build a RequestState from a new EngineCoreRequest.

        Sampling requests get a logprobs processor and an incremental
        detokenizer (built without a tokenizer when ``detokenize`` is off).
        Pooling requests (no ``sampling_params``) get neither and take their
        output kind from ``pooling_params``.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(request.lora_request.name
                       if request.lora_request is not None else None),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: Optional[torch.Tensor],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
        kv_transfer_params: Optional[dict[str, Any]] = None,
    ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Build the output object for this step, if one should be emitted.

        Returns None when the output kind is FINAL_ONLY and the request has
        not finished yet, or when this request belongs to a parent request
        whose ``get_outputs`` has nothing to emit yet. A non-None
        ``pooling_output`` selects the pooling branch.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)],
                finished)

        output = self._new_completion_output(new_token_ids, finish_reason,
                                             stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            # Child of a parallel-sampling request: the parent decides the
            # request id, the outputs to emit and the finished flag.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output)
            if not outputs:
                return None

        return self._new_request_output(request_id, outputs, finished,
                                        kv_transfer_params)

    def _new_request_output(
        self,
        request_id: str,
        outputs: Union[list[CompletionOutput], list[PoolingOutput]],
        finished: bool,
        kv_transfer_params: Optional[dict[str, Any]] = None,
    ) -> Union[RequestOutput, PoolingRequestOutput]:
        """Wrap completion or pooling outputs in a request-level output."""
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling
            # requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(request_id=request_id,
                             prompt=self.prompt,
                             prompt_token_ids=prompt_token_ids,
                             prompt_logprobs=prompt_logprobs,
                             outputs=cast(list[CompletionOutput], outputs),
                             finished=finished,
                             kv_transfer_params=kv_transfer_params,
                             num_cached_tokens=self.num_cached_tokens,
                             metrics=self.stats)

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step.

        In DELTA mode ``token_ids`` (this step's new tokens) and the
        matching tail of the logprobs are used; otherwise the full
        detokenizer token ids and full logprobs list are used.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode. In delta mode only the
        # entries for this step's new tokens are sent. The empty case must
        # be handled explicitly: `logprobs[-0:]` is the WHOLE list, which
        # would resend every earlier logprob (e.g. on the empty final ABORT
        # output). `logprobs[:0]` keeps the container type while empty.
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            logprobs = (logprobs[-len(token_ids):]
                        if token_ids else logprobs[:0])

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None)

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a raw pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Live per-request state, keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        # Parent requests (parallel sampling) keyed by parent request id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()
        # Never assigned in this class; process_outputs() only traces
        # finished requests when a tracer has been set.
        self.tracer: Optional[Tracer] = None

    def get_num_unfinished_requests(self) -> int:
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""
        for state in self.request_states.values():
            # Every tracked request is expected to have a queue here.
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Drop local state for the given requests.

        An id naming a parent request aborts all of its remaining child
        requests before the parent is removed. Requests that have a queue
        receive a final ABORT output.

        Returns:
            The ids of the individual requests whose state was removed.
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
                request_ids_to_abort.append(request_id)
                # Produce final abort output. Pooling requests have no
                # detokenizer, so pass an empty placeholder tensor to make
                # make_request_output() take the pooling branch.
                if req_state.queue is not None and (
                        request_output := req_state.make_request_output(
                            new_token_ids=[],
                            pooling_output=torch.empty(0, device="cpu")
                            if req_state.detokenizer is None else None,
                            finish_reason=FinishReason.ABORT,
                            stop_reason=None,
                            kv_transfer_params=None)):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest] = None,
        request_index: int = 0,
        queue: Optional[RequestOutputCollector] = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(tokenizer=self.tokenizer,
                                                  request=request,
                                                  prompt=prompt,
                                                  parent_req=parent_req,
                                                  request_index=request_index,
                                                  queue=queue,
                                                  log_stats=self.log_stats)
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[Union[RequestOutput,
                                    PoolingRequestOutput]] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration. Must run before
            # is_prefilling is cleared below.
            self._update_stats_from_output(req_state, engine_core_output,
                                           engine_core_timestamp,
                                           iteration_stats)

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop
                # checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP)
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(
                    engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                    new_token_ids, pooling_output, finish_reason, stop_reason,
                    kv_transfer_params):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(req_state, finish_reason,
                                                 iteration_stats)
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state,
                                    iteration_stats)
        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def do_tracing(self, engine_core_output: EngineCoreOutput,
                   req_state: RequestState,
                   iteration_stats: Optional[IterationStats]) -> None:
        """Emit an ``llm_request`` span for a finished request.

        Requires stats logging (``req_state.stats``), ``iteration_stats``
        and a tracer to be present.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        prompt_length = req_state.prompt_len
        with (self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as span):
            metrics = req_state.stats
            e2e_time = (iteration_stats.iteration_timestamp -
                        metrics.arrival_time)
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                               queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                               prompt_length)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                               metrics.num_generation_tokens)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
                prefill_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
                decode_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
                inference_time)

            # meta
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                               req_state.request_id)
            # Compare against None rather than truthiness so valid falsy
            # values (notably temperature=0.0, i.e. greedy sampling) are
            # still recorded. These are None for pooling requests.
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                                   req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                                   req_state.max_tokens_param)
            if req_state.temperature is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                                   req_state.temperature)
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                                   req_state.n)

    def _update_stats_from_output(self, req_state: RequestState,
                                  engine_core_output: EngineCoreOutput,
                                  engine_core_timestamp: Optional[float],
                                  iteration_stats: Optional[IterationStats]):
        """Fold one engine-core output into the iteration stats, if any."""
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(engine_core_output,
                                           engine_core_timestamp,
                                           req_state.is_prefilling,
                                           req_state.prompt_len,
                                           req_state.stats, lora_stats)

    def _update_stats_from_finished(self, req_state: RequestState,
                                    finish_reason: Optional[FinishReason],
                                    iteration_stats: Optional[IterationStats]):
        """Record a finished request in the iteration stats, if any."""
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=req_state.prompt_len,
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats)
        self.lora_states.finish_request(req_state)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats,
            req_state.stats.num_generation_tokens)