output_processor.py 31.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections import defaultdict, deque
6
from collections.abc import Iterable
7
from dataclasses import dataclass
8
from typing import Any, cast
9

10
import numpy as np
11
12
import torch

13
from vllm.lora.request import LoRARequest
14
from vllm.outputs import (
15
    STREAM_FINISHED,
16
17
18
19
20
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
21
from vllm.sampling_params import RequestOutputKind
22
from vllm.tokenizers import TokenizerLike
23
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
24
from vllm.utils import length_from_prompt_token_ids_or_embeds
25
26
27
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
28
from vllm.v1.engine.parallel_sampling import ParentRequest
29
30
31
32
33
34
from vllm.v1.metrics.stats import (
    IterationStats,
    LoRARequestStates,
    RequestStateStats,
    SchedulerStats,
)
35

36
37
38
# Shared zero-element CPU tensor used as a placeholder pooling output.
# abort_requests() passes it as `pooling_output` for requests that have no
# detokenizer, so that make_request_output() takes its pooling branch when
# building the final abort output.
EMPTY_CPU_TENSOR = torch.empty(0, device="cpu")

39

40
41
42
43
44
45
46
47
48
class RequestOutputCollector:
    """
    Single-slot mailbox holding the pending output of one request, for
    hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.
    """

    def __init__(self, output_kind: RequestOutputKind, request_id: str):
        self.request_id = request_id
        # Merge successive outputs only when the request streams deltas.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        # The one pending item: an output, an error to raise, or nothing.
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        # Set whenever `output` goes from empty to filled.
        self.ready = asyncio.Event()

        self._input_stream_task: asyncio.Task | None = None

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Non-blocking put operation."""
        pending = self.output

        # An empty slot takes anything; an exception always replaces
        # whatever is currently pending.
        if pending is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
            return

        if isinstance(pending, RequestOutput) and isinstance(output, RequestOutput):
            # This ensures that request outputs with different request indexes
            # (if n > 1) do not override each other.
            pending.add(output, aggregate=self.aggregate)
            return

        if isinstance(pending, PoolingRequestOutput) and isinstance(
            output, PoolingRequestOutput
        ):
            # The newest pooling output replaces the pending one.
            self.output = output

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Get operation blocks on put event."""
        while True:
            item = self.output
            if item is not None:
                break
            await self.ready.wait()

        # Empty the slot before handing the item over (or raising it).
        self.output = None
        self.ready.clear()
        if isinstance(item, Exception):
            raise item
        return item

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Non-blocking get operation."""
        item = self.output
        if item is None:
            return None

        self.output = None
        self.ready.clear()
        if isinstance(item, Exception):
            raise item
        return item

    def close(self):
        """Cancel the attached input-stream task, if any, and forget it."""
        task = self._input_stream_task
        if task is not None:
            task.cancel()
        self._input_stream_task = None

    def __del__(self):
        task = self._input_stream_task
        if task is None:
            return
        # Schedule the cancellation on the task's own event loop.
        task.get_loop().call_soon_threadsafe(task.cancel)
        self._input_stream_task = None

103

104
105
@dataclass
class OutputProcessorOutput:
    """Result of a single OutputProcessor.process_outputs() call."""

    # Outputs of requests that have no per-request queue; outputs of
    # queue-backed requests are put into their queue instead.
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # IDs of requests that finished on the output-processor side (detokenizer
    # stop string) while EngineCore had not yet marked them finished.
    reqs_to_abort: list[str]
108
109


110
111
112
113
114
115
116
117
118
119
120
121
122
123
@dataclass
class StreamingUpdate:
    """Streaming input update data for output processor.

    Contains the incremental prompt data to be applied to a request state
    when the current sub-request completes.
    """

    # Prompt text to append to the request's prompt (may be None).
    prompt: str | None
    # Prompt token ids to append to the request's prompt token ids.
    prompt_token_ids: list[int] | None
    # Replaces the arrival time in the request's stats when applied.
    arrival_time: float
    # True when no further input chunks follow this one; applying such an
    # update sets RequestState.streaming_input to False.
    final: bool = False


124
125
126
127
class RequestState:
    def __init__(
        self,
        request_id: str,
        external_req_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_request: LoRARequest | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
        stream_input: bool = False,
    ):
        """Store the per-request state tracked by the output processor.

        `arrival_time` is only retained (inside `stats`) when `log_stats` is
        true; `stream_input` enables the queue of pending streaming-input
        updates.
        """
        # Identity and parallel-sampling linkage.
        self.request_id = request_id
        self.external_req_id = external_req_id
        self.parent_req = parent_req
        self.request_index = request_index

        # LoRA adapter, with its name cached for stats bookkeeping.
        self.lora_request = lora_request
        self.lora_name = None
        if lora_request is not None:
            self.lora_name = lora_request.lora_name

        # Prompt in all of its possible forms, plus its length.
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            prompt_token_ids, prompt_embeds
        )

        # Output construction helpers (None for pooling requests).
        self.output_kind = output_kind
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.queue = queue

        # Sampling parameters kept for stats and tracing.
        self.max_tokens_param = max_tokens_param
        self.top_p = top_p
        self.n = n
        self.temperature = temperature

        # Progress tracking.
        self.is_prefilling = True
        self.num_cached_tokens = 0
        self.stats = None
        if log_stats:
            self.stats = RequestStateStats(arrival_time=arrival_time)

        # Stream Interval
        self.stream_interval = stream_interval
        self.sent_tokens_offset = 0  # Offset of sent tokens

        # Streaming input queue
        self.streaming_input = stream_input
        self.input_chunk_queue: deque[StreamingUpdate] | None = None
        if stream_input:
            self.input_chunk_queue = deque()

    def apply_streaming_update(self, update: StreamingUpdate) -> None:
        # Apply the update to the request state.
        self.streaming_input = not update.final
        # TODO also include relevant output tokens in new prompt here
        #     (match scheduler behavior).
        if update.prompt:
            self.prompt = (
                (self.prompt + update.prompt) if self.prompt else update.prompt
            )
        if self.prompt_token_ids:
            self.prompt_token_ids.extend(update.prompt_token_ids or ())
        else:
            self.prompt_token_ids = update.prompt_token_ids or []
        assert self.prompt_token_ids is not None
        self.prompt_len = len(self.prompt_token_ids)
        if self.stats is not None:
            self.stats.arrival_time = update.arrival_time
        self.is_prefilling = True

202
203
204
    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
    ) -> "RequestState":
        """Build the RequestState for a newly added EngineCoreRequest.

        Requests with sampling params get a logprobs processor and an
        incremental detokenizer (created without a tokenizer when
        detokenization is disabled). All other requests must carry pooling
        params and get neither helper.
        """
        # Defaults for requests without sampling params.
        logprobs_processor: LogprobsProcessor | None = None
        detokenizer: IncrementalDetokenizer | None = None
        max_tokens_param = None
        top_p = None
        n = None
        temperature = None

        params = request.sampling_params
        if params:
            if not params.detokenize:
                tokenizer = None
            output_kind = params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = params.max_tokens
            top_p = params.top_p
            n = params.n
            temperature = params.temperature
        else:
            pooling_params = request.pooling_params
            assert pooling_params is not None
            output_kind = pooling_params.output_kind

        assert request.external_req_id is not None
        return cls(
            request_id=request.request_id,
            external_req_id=request.external_req_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_request=request.lora_request,
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
            stream_interval=stream_interval,
            stream_input=request.resumable,
        )

264
265
266
    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
        routed_experts: np.ndarray | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        """Build the output to emit for this step, if any.

        Args:
            new_token_ids: Token ids produced in this step.
            pooling_output: Pooling result; when not None a
                PoolingRequestOutput is produced.
            finish_reason: Set when the request finished in this step.
            stop_reason: Stop token / stop string that ended the request.
            kv_transfer_params: Forwarded to the RequestOutput.
            routed_experts: Forwarded to the CompletionOutput.

        Returns:
            The output to emit, or None when nothing should be emitted yet:
            an unfinished FINAL_ONLY request, a step throttled by
            `stream_interval`, or a parallel-sampling child whose parent
            has no outputs to release.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        external_req_id = self.external_req_id

        # Handle pooling before the stream-interval logic: pooling requests
        # have no detokenizer, so the detokenizer assertion below would
        # otherwise fail for them whenever stream_interval > 1.
        if pooling_output is not None:
            return self._new_request_output(
                external_req_id,
                [self._new_pooling_output(pooling_output)],
                finished,
            )

        if self.stream_interval > 1:
            assert self.detokenizer is not None

            # Send output request only when
            # 1. It has finished, or
            # 2. It is the first token, or
            # 3. It has reached the stream interval number of tokens
            if not (
                finished
                or self.sent_tokens_offset == 0
                or len(self.detokenizer.output_token_ids) - self.sent_tokens_offset
                >= self.stream_interval
            ):
                return None

            if self.output_kind == RequestOutputKind.DELTA:
                # Send tokens from the offset in DELTA mode, otherwise all
                # tokens are sent.
                new_token_ids = self.detokenizer.output_token_ids[
                    self.sent_tokens_offset :
                ]
                self.sent_tokens_offset = len(self.detokenizer.output_token_ids)

        output = self._new_completion_output(
            new_token_ids, finish_reason, stop_reason, routed_experts
        )

        if self.parent_req is None:
            outputs = [output]
        else:
            # Parallel sampling: the parent decides which child outputs to
            # release and whether the overall request is finished.
            outputs, finished = self.parent_req.get_outputs(self.request_id, output)
            if not outputs:
                return None
            external_req_id = self.parent_req.external_req_id

        return self._new_request_output(
            external_req_id, outputs, finished, kv_transfer_params
        )
327
328
329

    def _new_request_output(
        self,
        external_req_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        """Wrap per-sequence outputs into the request-level output object.

        A leading PoolingOutput yields a PoolingRequestOutput; anything else
        yields a RequestOutput carrying prompt data, prompt logprobs and
        stats.
        """
        head = outputs[0]
        if isinstance(head, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=external_req_id,
                outputs=head,
                num_cached_tokens=self.num_cached_tokens,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )

        processor = self.logprobs_processor
        assert processor is not None
        if self.output_kind != RequestOutputKind.DELTA:
            prompt_logprobs = processor.prompt_logprobs
        else:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = processor.pop_prompt_logprobs()

        # If prompt embeds were used, put placeholder prompt token ids
        token_ids = self.prompt_token_ids
        if token_ids is None and self.prompt_embeds is not None:
            token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=external_req_id,  # request_id is what was provided externally
            lora_request=self.lora_request,
            prompt=self.prompt,
            prompt_token_ids=token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )
371
372
373
374

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        routed_experts: np.ndarray | None = None,
    ) -> CompletionOutput:
        """Build the CompletionOutput for this request's sequence.

        In DELTA mode the output carries only `token_ids` and the logprobs
        belonging to them; otherwise it carries all output tokens and all
        logprobs accumulated so far.

        Args:
            token_ids: Tokens to report in DELTA mode (ignored otherwise).
            finish_reason: Set when the request has finished.
            stop_reason: Reported only when the request has finished.
            routed_experts: Forwarded unchanged to the CompletionOutput.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            num_new_tokens = len(token_ids)
            # Guard the empty case explicitly: `logprobs[-0:]` is the whole
            # sequence, which would re-send every accumulated logprob with an
            # output that carries no new tokens (e.g. the abort output).
            logprobs = logprobs[-num_new_tokens:] if num_new_tokens else logprobs[:0]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            routed_experts=routed_experts,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )
404

405
    def _new_pooling_output(self, pooling_output: torch.Tensor) -> PoolingOutput:
        """Wrap a pooling tensor in a PoolingOutput."""
        wrapped = PoolingOutput(data=pooling_output)
        return wrapped

408
409
410
411

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

412
    def __init__(
        self,
        tokenizer: TokenizerLike | None,
        log_stats: bool,
        stream_interval: int = 1,
    ):
        """Create an output processor with no tracked requests.

        The "requests drained" event starts out set because nothing is in
        flight yet.
        """
        self.tokenizer = tokenizer
        self.log_stats = log_stats
        self.stream_interval = stream_interval
        self.tracer: Tracer | None = None
        self.lora_states = LoRARequestStates(log_stats)

        # Bookkeeping for in-flight requests: internal id -> state,
        # parent id -> parent, external id -> internal ids.
        self.request_states: dict[str, RequestState] = {}
        self.parent_requests: dict[str, ParentRequest] = {}
        self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)

        drained = asyncio.Event()
        drained.set()
        self._requests_drained = drained
428
429
430
431
432
433
434

    def get_num_unfinished_requests(self):
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        return len(self.request_states) > 0

435
436
437
438
439
    async def wait_for_requests_to_drain(self) -> None:
        if not self.request_states:
            return
        await self._requests_drained.wait()

440
441
442
443
444
445
446
    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""

        for _, state in self.request_states.items():
            assert state.queue is not None
            state.queue.put(e)

447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
    def abort_requests(self, request_ids: Iterable[str], internal: bool) -> list[str]:
        """Abort a list of requests.

        The request_ids may be either external request IDs (those passed to
        InputProcessor.process_inputs()) or internal request IDs (those randomly
        generated when creating the EngineCoreRequest).

        If an external request ID is provided, and that external request ID
        was used for multiple requests, all requests associated with that external
        request ID are aborted.

        In the case of parallel sampling, a request ID may be used to identify
        a parent request, in which case the associated child requests are aborted
        also.
        """
        internal_req_ids = []
463
        for request_id in request_ids:
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
            if internal:
                # Internal ID - this may be a parent request
                internal_req_ids.append(request_id)

                # Remove internal ID from the external->internal mapping
                if req_state := self.request_states.get(request_id):
                    external_req_id = req_state.external_req_id
                    internal_ids = self.external_req_ids[external_req_id]
                    internal_ids.remove(request_id)
                    if not internal_ids:
                        del self.external_req_ids[external_req_id]
            elif internal_ids := self.external_req_ids.pop(request_id, []):
                # External ID - abort all requests in the external->internal mapping
                internal_req_ids.extend(internal_ids)

        request_ids_to_abort = []
        for request_id in internal_req_ids:
481
482
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
483
                self.lora_states.request_finished(request_id, req_state.lora_name)
484
                request_ids_to_abort.append(request_id)
485
486
                # Produce final abort output.
                if req_state.queue is not None and (
487
488
489
490
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Set pooling_output is not None to
                        # correctly enter the abort pooling branch
491
                        pooling_output=EMPTY_CPU_TENSOR
492
493
494
495
496
497
498
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
499
500
501
502
503
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
504
                    child_reqs = self.abort_requests(child_reqs, internal=True)
505
506
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
507
508
        if not self.request_states:
            self._requests_drained.set()
509
        return request_ids_to_abort
510
511
512
513

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        """Start tracking a request, or feed more input to a tracked one.

        When the request ID is already tracked, the call is handled as a
        streaming-input update for the existing state. Otherwise a new
        RequestState is created and registered in every lookup table.
        """
        request_id = request.request_id

        existing = self.request_states.get(request_id)
        if existing is not None:
            self._update_streaming_request_state(existing, request, prompt)
            return

        new_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
            stream_interval=self.stream_interval,
        )

        # A request is in flight again, so the drained event must be unset
        # (clearing an already-unset event is a no-op).
        self._requests_drained.clear()
        self.request_states[request_id] = new_state
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

        # Track the external_req_id -> [internal_req_id, ...] mapping
        self.external_req_ids[new_state.external_req_id].append(request_id)

544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
    def _update_streaming_request_state(
        self, req_state: RequestState, request: EngineCoreRequest, prompt: str | None
    ) -> None:
        """Queue a streaming update instead of immediately applying it.

        Called by add_request() when `request.request_id` is already tracked.

        Args:
            req_state: The tracked state of the request being updated.
            request: The follow-up request; a non-resumable one marks the end
                of the input stream and its own tokens are not added.
            prompt: Prompt text of the new input chunk, if any.
        """
        if not request.resumable:
            # Final request - just mark completion, don't add its dummy tokens.
            if req_state.input_chunk_queue is None:
                # Engine already finished - emit final output and clean up.
                # (process_outputs() resets the queue to None when a finish
                # arrives while no further chunk is queued.)
                self._finish_request(req_state)
                if req_state.queue is not None:
                    # Emit a final output with finished=True
                    # to unblock the generate() loop.
                    req_state.queue.put(STREAM_FINISHED)
            elif req_state.input_chunk_queue:
                # Chunks are still pending: flag the last one as final so the
                # stream is closed once that chunk gets applied.
                req_state.input_chunk_queue[-1].final = True
            else:
                # Nothing is pending: stop expecting further streamed input.
                req_state.streaming_input = False
            return

        update = StreamingUpdate(
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            arrival_time=request.arrival_time,
        )

        # Apply request updates now if the last input already completed.
        if req_state.input_chunk_queue is None:
            req_state.apply_streaming_update(update)
            # Start a fresh queue so that later chunks get queued again.
            req_state.input_chunk_queue = deque()
        else:
            # Queue the streaming update otherwise.
            req_state.input_chunk_queue.append(update)

577
578
    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.

        Args:
            engine_core_outputs: Outputs of one engine step.
            engine_core_timestamp: Timestamp forwarded to the stats update;
                must be set whenever `iteration_stats` is given.
            iteration_stats: Stats accumulator for this iteration; stats are
                skipped when it is None.

        Returns:
            The outputs of requests without a queue, plus the IDs of requests
            that finished here but are not yet finished in EngineCore.
        """

        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            # This must run before `is_prefilling` is reset below, because
            # the stats update reads that flag.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            routed_experts = engine_core_output.routed_experts
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    # A stop string found by the detokenizer overrides the
                    # finish/stop reason reported by the engine.
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
                routed_experts,
            ):
                if req_state.streaming_input:
                    # More input may still arrive for this request, so the
                    # output is never reported as finished.
                    request_output.finished = False

                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                if req_state.streaming_input:
                    # The request stays tracked: continue with the next
                    # queued input chunk, or record (queue = None) that the
                    # last input completed with nothing further queued.
                    if req_state.input_chunk_queue:
                        update = req_state.input_chunk_queue.popleft()
                        req_state.apply_streaming_update(update)
                    else:
                        req_state.input_chunk_queue = None
                else:
                    self._finish_request(req_state)
                    if not engine_core_output.finished:
                        # If req not finished in EngineCore, but Detokenizer
                        # detected stop string, abort needed in EngineCore.
                        reqs_to_abort.append(req_id)

                    # Track per-request stats
                    self._update_stats_from_finished(
                        req_state, finish_reason, iteration_stats
                    )
                    if self.tracer:
                        self.do_tracing(engine_core_output, req_state, iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
    def _finish_request(self, req_state: RequestState) -> None:
        req_id = req_state.request_id
        self.request_states.pop(req_id)

        internal_ids = self.external_req_ids[req_state.external_req_id]
        internal_ids.remove(req_id)
        if not internal_ids:
            del self.external_req_ids[req_state.external_req_id]

        # Remove parent request if applicable.
        parent_req = req_state.parent_req
        if parent_req and not parent_req.child_requests:
            self.parent_requests.pop(parent_req.request_id, None)

        if not self.request_states:
            self._requests_drained.set()

706
707
708
    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        self.lora_states.update_scheduler_stats(scheduler_stats)

709
710
711
712
    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Emit an "llm_request" tracing span describing a finished request.

        The span starts at the request's arrival time and records latency
        figures, token usage and the request's sampling parameters.

        Args:
            engine_core_output: Final engine output; supplies the trace
                headers from which the parent trace context is extracted.
            req_state: State of the finished request; `stats` must be set.
            iteration_stats: Stats of the current iteration; must be set, as
                its timestamp is used for the end-to-end latency.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        metrics = req_state.stats
        arrival_time_nano_seconds = int(metrics.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        prompt_length = length_from_prompt_token_ids_or_embeds(
            req_state.prompt_token_ids, req_state.prompt_embeds
        )
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            span.set_attribute(
                SpanAttributes.GEN_AI_REQUEST_ID, req_state.external_req_id
            )
            # Compare against None rather than relying on truthiness, so that
            # valid falsy values (e.g. temperature=0.0) are still recorded.
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
                )
            if req_state.temperature is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
                )
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)
773

774
775
776
777
    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
778
779
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
780
    ):
781
782
783
784
785
        if iteration_stats is None:
            return

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
786
787
788
789
790
791
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
792
793
            self.lora_states,
            req_state.lora_name,
794
795
796
797
798
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
799
800
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
801
    ):
802
803
804
805
806
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
807
808
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
809
            num_prompt_tokens=req_state.prompt_len,
810
            max_tokens_param=req_state.max_tokens_param,
811
            req_stats=req_state.stats,
812
            num_cached_tokens=req_state.num_cached_tokens,
813
        )
814
        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)
815
816

        ParentRequest.observe_finished_request(
817
818
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )