output_processor.py 30.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections import defaultdict, deque
6
from collections.abc import Iterable
7
from dataclasses import dataclass
8
from typing import Any, cast
9

10
import numpy as np
11
12
import torch

13
from vllm.lora.request import LoRARequest
14
from vllm.outputs import (
15
    STREAM_FINISHED,
16
17
18
19
20
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
21
from vllm.sampling_params import RequestOutputKind
22
from vllm.tokenizers import TokenizerLike
23
24
25
26
27
28
from vllm.tracing import (
    SpanAttributes,
    SpanKind,
    extract_trace_context,
    instrument_manual,
)
29
from vllm.utils import length_from_prompt_token_ids_or_embeds
30
31
32
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
33
from vllm.v1.engine.parallel_sampling import ParentRequest
34
35
36
37
38
39
from vllm.v1.metrics.stats import (
    IterationStats,
    LoRARequestStates,
    RequestStateStats,
    SchedulerStats,
)
40

41
42
43
# Zero-element CPU tensor that is shared (never copied) wherever a placeholder
# pooling output is needed, e.g. to steer aborted pooling requests into the
# pooling branch when building their final output.
EMPTY_CPU_TENSOR = torch.empty((0,), device=torch.device("cpu"))

44

45
46
47
48
49
50
51
52
53
class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.
    """

    def __init__(self, output_kind: RequestOutputKind, request_id: str):
        # Forwarded to RequestOutput.add() when merging; True only in DELTA mode.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.request_id = request_id
        # The single pending item for the consumer (None = nothing pending).
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        self.ready = asyncio.Event()

        self._input_stream_task: asyncio.Task | None = None

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Non-blocking put operation.

        - If nothing is pending, or ``output`` is an exception, it becomes the
          pending item and the ready event is set.
        - A RequestOutput is merged into a pending RequestOutput.
        - A PoolingRequestOutput replaces a pending PoolingRequestOutput.
        - Any other combination (e.g. an output arriving while an exception
          is pending) is dropped.
        """
        if self.output is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(self.output, RequestOutput) and isinstance(
            output, RequestOutput
        ):
            # This ensures that request outputs with different request indexes
            # (if n > 1) do not override each other.
            self.output.add(output, aggregate=self.aggregate)
        elif isinstance(self.output, PoolingRequestOutput) and isinstance(
            output, PoolingRequestOutput
        ):
            self.output = output

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Get operation blocks on put event.

        Waits until an item is pending, takes it, and re-raises it if it is
        an exception.
        """
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Non-blocking get operation.

        Returns the pending item (or None if there is none); a pending
        exception is re-raised instead of returned.
        """
        output = self.output
        if output is not None:
            self.output = None
            self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def close(self):
        """Cancel the associated input-stream task, if any."""
        if self._input_stream_task is not None:
            self._input_stream_task.cancel()
        self._input_stream_task = None

    def __del__(self):
        # getattr: __del__ may run on an instance whose __init__ did not
        # complete, in which case the attribute was never set.
        task = getattr(self, "_input_stream_task", None)
        if task is None:
            return
        self._input_stream_task = None
        try:
            # The finalizer may run on a different thread than the task's loop.
            task.get_loop().call_soon_threadsafe(task.cancel)
        except RuntimeError:
            # The owning event loop is already closed, so there is nothing left
            # to schedule the cancellation on. Best-effort: ignore.
            pass

108

109
110
@dataclass
class OutputProcessorOutput:
    """Result of a single ``OutputProcessor.process_outputs()`` call."""

    # Outputs returned directly to the caller; requests that have a queue get
    # their outputs delivered through that queue instead of this list.
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # Internal request IDs that still need to be aborted in EngineCore
    # (e.g. the detokenizer detected a stop string first).
    reqs_to_abort: list[str]
113
114


115
116
117
118
119
120
121
122
123
124
125
126
127
128
@dataclass
class StreamingUpdate:
    """Streaming input update data for output processor.

    Contains the incremental prompt data to be applied to a request state
    when the current sub-request completes.
    """

    prompt: str | None
    prompt_token_ids: list[int] | None
    arrival_time: float
    final: bool = False


129
130
131
132
class RequestState:
    """Per-request bookkeeping for the output processor.

    Holds prompt information, the detokenizer and logprobs processor (for
    generation requests), stream-interval progress, optional stats, and, for
    resumable requests, the queue of pending streaming-input updates. Turns
    engine results into RequestOutput / PoolingRequestOutput objects.
    """

    def __init__(
        self,
        request_id: str,
        external_req_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_request: LoRARequest | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
        stream_input: bool = False,
    ):
        self.request_id = request_id
        self.external_req_id = external_req_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_request = lora_request
        self.lora_name = lora_request.lora_name if lora_request is not None else None
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        # Sampling parameters kept only for tracing attributes.
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

        # Stream Interval
        self.stream_interval = stream_interval
        self.sent_tokens_offset = 0  # Offset of sent tokens

        # Streaming input queue
        self.streaming_input = stream_input
        self.input_chunk_queue: deque[StreamingUpdate] | None = (
            deque() if stream_input else None
        )

    def apply_streaming_update(self, update: StreamingUpdate) -> None:
        """Extend the prompt with a queued streaming-input chunk.

        Also refreshes ``prompt_len``, the stats arrival time, and puts the
        request back into the prefilling state.
        """
        self.streaming_input = not update.final
        # TODO also include relevant output tokens in new prompt here
        #     (match scheduler behavior).
        if update.prompt:
            self.prompt = (
                (self.prompt + update.prompt) if self.prompt else update.prompt
            )
        # Build a new list rather than extending in place: the current list is
        # the one that came with the request and is also referenced by every
        # RequestOutput already emitted, which must not change after delivery.
        new_token_ids = update.prompt_token_ids or []
        if self.prompt_token_ids:
            self.prompt_token_ids = [*self.prompt_token_ids, *new_token_ids]
        else:
            self.prompt_token_ids = new_token_ids
        self.prompt_len = len(self.prompt_token_ids)
        if self.stats is not None:
            self.stats.arrival_time = update.arrival_time
        self.is_prefilling = True

    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
    ) -> "RequestState":
        """Create the state for a newly added EngineCoreRequest.

        Generation requests (``sampling_params`` set) get a logprobs processor
        and detokenizer; the tokenizer is dropped when ``detokenize`` is off.
        Pooling requests take their output kind from ``pooling_params``.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        assert request.external_req_id is not None
        return cls(
            request_id=request.request_id,
            external_req_id=request.external_req_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_request=request.lora_request,
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
            stream_interval=stream_interval,
            stream_input=request.resumable,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
        routed_experts: np.ndarray | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        """Build the output to deliver for this step, or None to send nothing.

        Nothing is sent when the request is unfinished in FINAL_ONLY mode,
        when stream-interval batching is still accumulating tokens, or when
        the parent request (parallel sampling) has no outputs ready yet.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        # Stream-interval batching only applies to generated tokens; pooling
        # requests have no detokenizer and deliver a single output.
        if self.stream_interval > 1 and pooling_output is None:
            assert self.detokenizer is not None

            # Send output request only when
            # 1. It has finished, or
            # 2. It is the first token, or
            # 3. It has reached the stream interval number of tokens
            if not (
                finished
                or self.sent_tokens_offset == 0
                or self.detokenizer.num_output_tokens() - self.sent_tokens_offset
                >= self.stream_interval
            ):
                return None

            if self.output_kind == RequestOutputKind.DELTA:
                # Send tokens from the offset in DELTA mode, otherwise all
                # tokens are sent.
                new_token_ids = self.detokenizer.output_token_ids[
                    self.sent_tokens_offset :
                ]
                self.sent_tokens_offset = self.detokenizer.num_output_tokens()

        external_req_id = self.external_req_id

        if pooling_output is not None:
            return self._new_request_output(
                external_req_id,
                [self._new_pooling_output(pooling_output)],
                finished,
            )

        output = self._new_completion_output(
            new_token_ids, finish_reason, stop_reason, routed_experts
        )

        if self.parent_req is None:
            outputs = [output]
        else:
            outputs, finished = self.parent_req.get_outputs(self.request_id, output)
            if not outputs:
                return None
            external_req_id = self.parent_req.external_req_id

        return self._new_request_output(
            external_req_id, outputs, finished, kv_transfer_params
        )

    def _new_request_output(
        self,
        external_req_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        """Wrap outputs in a PoolingRequestOutput or a RequestOutput."""
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=external_req_id,
                outputs=first_output,
                num_cached_tokens=self.num_cached_tokens,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=external_req_id,  # request_id is what was provided externally
            lora_request=self.lora_request,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        routed_experts: np.ndarray | None = None,
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step.

        In DELTA mode, text and logprobs cover only ``token_ids``; otherwise
        the full accumulated token ids (and logprobs) are reported.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            num_new = len(token_ids)
            # num_new == 0 (e.g. the empty final output produced on abort)
            # must yield no logprobs: logprobs[-0:] would return all of them.
            logprobs = logprobs[-num_new:] if num_new else logprobs[:0]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            routed_experts=routed_experts,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

    def _new_pooling_output(self, pooling_output: torch.Tensor) -> PoolingOutput:
        """Wrap a pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

413
414
415
416

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: TokenizerLike | None,
        log_stats: bool,
        stream_interval: int = 1,
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        self.stream_interval = stream_interval
        self.request_states: dict[str, RequestState] = {}
        self.parent_requests: dict[str, ParentRequest] = {}
        # external request id -> internal request ids that share it
        self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
        self.lora_states = LoRARequestStates(log_stats)
        self.tracing_enabled: bool = False

        # Set whenever no request states remain.
        self._requests_drained = asyncio.Event()
        self._requests_drained.set()

    def get_num_unfinished_requests(self):
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        return len(self.request_states) > 0

    async def wait_for_requests_to_drain(self) -> None:
        """Wait until all tracked requests have finished or been aborted."""
        if not self.request_states:
            return
        await self._requests_drained.wait()

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""
        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(self, request_ids: Iterable[str], internal: bool) -> list[str]:
        """Abort a list of requests.

        The request_ids may be either external request IDs (those passed to
        InputProcessor.process_inputs()) or internal request IDs (those randomly
        generated when creating the EngineCoreRequest).

        If an external request ID is provided, and that external request ID
        was used for multiple requests, all requests associated with that external
        request ID are aborted.

        In the case of parallel sampling, a request ID may be used to identify
        a parent request, in which case the associated child requests are aborted
        also.
        """
        internal_req_ids = []
        for request_id in request_ids:
            if internal:
                # Internal ID - this may be a parent request
                internal_req_ids.append(request_id)

                # Remove internal ID from the external->internal mapping
                if req_state := self.request_states.get(request_id):
                    external_req_id = req_state.external_req_id
                    internal_ids = self.external_req_ids[external_req_id]
                    internal_ids.remove(request_id)
                    if not internal_ids:
                        del self.external_req_ids[external_req_id]
            elif internal_ids := self.external_req_ids.pop(request_id, []):
                # External ID - abort all requests in the external->internal mapping
                internal_req_ids.extend(internal_ids)

        request_ids_to_abort = []
        for request_id in internal_req_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.request_finished(request_id, req_state.lora_name)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Set pooling_output is not None to
                        # correctly enter the abort pooling branch
                        pooling_output=EMPTY_CPU_TENSOR
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs, internal=True)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        if not self.request_states:
            self._requests_drained.set()
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        """Start tracking a request, or feed a streaming update to a tracked one."""
        request_id = request.request_id
        req_state = self.request_states.get(request_id)
        if req_state is not None:
            self._update_streaming_request_state(req_state, request, prompt)
            return

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
            stream_interval=self.stream_interval,
        )
        if self._requests_drained.is_set():
            self._requests_drained.clear()
        self.request_states[request_id] = req_state
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

        # Track the external_req_id -> [internal_req_id, ...] mapping
        self.external_req_ids[req_state.external_req_id].append(request_id)

    def _update_streaming_request_state(
        self, req_state: RequestState, request: EngineCoreRequest, prompt: str | None
    ) -> None:
        """Queue a streaming update instead of immediately applying it."""
        if not request.resumable:
            # Final request - just mark completion, don't add its dummy tokens.
            if req_state.input_chunk_queue is None:
                # Engine already finished - emit final output and clean up.
                self._finish_request(req_state)
                if req_state.queue is not None:
                    # Emit a final output with finished=True
                    # to unblock the generate() loop.
                    req_state.queue.put(STREAM_FINISHED)
            elif req_state.input_chunk_queue:
                req_state.input_chunk_queue[-1].final = True
            else:
                req_state.streaming_input = False
            return

        update = StreamingUpdate(
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            arrival_time=request.arrival_time,
        )

        # Apply request updates now if the last input already completed.
        if req_state.input_chunk_queue is None:
            req_state.apply_streaming_update(update)
            req_state.input_chunk_queue = deque()
        else:
            # Queue the streaming update otherwise.
            req_state.input_chunk_queue.append(update)

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            routed_experts = engine_core_output.routed_experts
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
                routed_experts,
            ):
                if req_state.streaming_input:
                    request_output.finished = False

                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                if req_state.streaming_input:
                    if req_state.input_chunk_queue:
                        update = req_state.input_chunk_queue.popleft()
                        req_state.apply_streaming_update(update)
                    else:
                        req_state.input_chunk_queue = None
                else:
                    self._finish_request(req_state)
                    if not engine_core_output.finished:
                        # If req not finished in EngineCore, but Detokenizer
                        # detected stop string, abort needed in EngineCore.
                        reqs_to_abort.append(req_id)

                    # Track per-request stats
                    self._update_stats_from_finished(
                        req_state, finish_reason, iteration_stats
                    )
                    if self.tracing_enabled:
                        self.do_tracing(engine_core_output, req_state, iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def _finish_request(self, req_state: RequestState) -> None:
        """Drop all bookkeeping for a finished request."""
        req_id = req_state.request_id
        self.request_states.pop(req_id)

        internal_ids = self.external_req_ids[req_state.external_req_id]
        internal_ids.remove(req_id)
        if not internal_ids:
            del self.external_req_ids[req_state.external_req_id]

        # Remove parent request if applicable.
        parent_req = req_state.parent_req
        if parent_req and not parent_req.child_requests:
            self.parent_requests.pop(parent_req.request_id, None)

        if not self.request_states:
            self._requests_drained.set()

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        self.lora_states.update_scheduler_stats(scheduler_stats)

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Emit an "llm_request" span for a finished request.

        Requires stats logging (``req_state.stats``) and ``iteration_stats``.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None

        metrics = req_state.stats
        arrival_time_ns = int(metrics.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        prompt_length = length_from_prompt_token_ids_or_embeds(
            req_state.prompt_token_ids, req_state.prompt_embeds
        )

        # Calculate timing metrics
        e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
        queued_time = metrics.scheduled_ts - metrics.queued_ts
        prefill_time = metrics.first_token_ts - metrics.scheduled_ts
        decode_time = metrics.last_token_ts - metrics.first_token_ts
        inference_time = metrics.last_token_ts - metrics.scheduled_ts

        # Build attributes dict
        attributes: dict[str, Any] = {
            SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN: (
                metrics.first_token_latency
            ),
            SpanAttributes.GEN_AI_LATENCY_E2E: e2e_time,
            SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE: queued_time,
            SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS: prompt_length,
            SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS: (
                metrics.num_generation_tokens
            ),
            SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL: prefill_time,
            SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE: decode_time,
            SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE: inference_time,
            SpanAttributes.GEN_AI_REQUEST_ID: req_state.external_req_id,
        }

        # Add optional request parameters. Compare against None instead of
        # relying on truthiness so that legitimate zero values (notably
        # temperature=0.0 for greedy sampling) are still reported.
        optional_params = (
            (SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p),
            (SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param),
            (SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature),
            (SpanAttributes.GEN_AI_REQUEST_N, req_state.n),
        )
        for attr_key, attr_value in optional_params:
            if attr_value is not None:
                attributes[attr_key] = attr_value

        instrument_manual(
            span_name="llm_request",
            start_time=arrival_time_ns,
            attributes=attributes,
            context=trace_context,
            kind=SpanKind.SERVER,
        )

    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
    ):
        """Fold one engine output into the iteration stats (if enabled)."""
        if iteration_stats is None:
            return

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            self.lora_states,
            req_state.lora_name,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
    ):
        """Record per-request stats for a finished request (if enabled)."""
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=req_state.prompt_len,
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
            num_cached_tokens=req_state.num_cached_tokens,
        )
        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )