output_processor.py 27.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections import defaultdict
6
from collections.abc import Iterable
7
from dataclasses import dataclass
8
from typing import Any, cast
9

10
import numpy as np
11
12
import torch

13
from vllm.lora.request import LoRARequest
14
15
16
17
18
19
from vllm.outputs import (
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
20
from vllm.sampling_params import RequestOutputKind
21
from vllm.tokenizers import TokenizerLike
22
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
23
from vllm.utils import length_from_prompt_token_ids_or_embeds
24
25
26
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
27
from vllm.v1.engine.parallel_sampling import ParentRequest
28
29
30
31
32
33
from vllm.v1.metrics.stats import (
    IterationStats,
    LoRARequestStates,
    RequestStateStats,
    SchedulerStats,
)
34

35
36
37
# Shared zero-element CPU tensor. OutputProcessor.abort_requests() passes it as
# a placeholder pooling_output for requests without a detokenizer, so that the
# final abort output goes through the pooling branch of make_request_output().
EMPTY_CPU_TENSOR = torch.empty((0,), device=torch.device("cpu"))

38

39
40
41
42
43
44
45
46
47
class RequestOutputCollector:
    """
    Single-slot mailbox carrying streamed outputs of one request from the
    output-processing loop to the asyncio task that consumes them.

    In DELTA mode, a RequestOutput that arrives while another is still
    pending is merged into the pending one, so a slow consumer does not lose
    deltas.
    """

    def __init__(self, output_kind: RequestOutputKind, request_id: str):
        # Merge pending RequestOutputs in aggregate mode only for DELTA output.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.request_id = request_id
        # The pending item; None means the slot is empty.
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        # Set while an item is pending; cleared when the consumer takes it.
        self.ready = asyncio.Event()

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Store or merge ``output`` without blocking.

        An exception always replaces whatever is pending. Any other
        combination not handled below (e.g. a non-exception arriving while an
        exception is pending) leaves the pending item unchanged.
        """
        pending = self.output
        if pending is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(pending, RequestOutput) and isinstance(output, RequestOutput):
            # Merge so outputs with different request indexes (n > 1) do not
            # overwrite each other.
            pending.add(output, aggregate=self.aggregate)
        elif isinstance(pending, PoolingRequestOutput) and isinstance(
            output, PoolingRequestOutput
        ):
            # Pooling outputs are not merged: the newest replaces the pending.
            self.output = output

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Wait for a pending item and take it, emptying the slot.

        Raises:
            Exception: the pending item, if it is an exception.
        """
        while self.output is None:
            await self.ready.wait()
        item = self.output
        self.output = None
        self.ready.clear()
        if isinstance(item, Exception):
            raise item
        return item

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Take the pending item if there is one, else return None.

        Raises:
            Exception: the pending item, if it is an exception.
        """
        item = self.output
        if item is None:
            return None
        self.output = None
        self.ready.clear()
        if isinstance(item, Exception):
            raise item
        return item


91
92
@dataclass
class OutputProcessorOutput:
    """Result of one OutputProcessor.process_outputs() call."""

    # Outputs for requests without a queue (returned directly to the caller).
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # Internal ids of requests finished by the output processor (e.g. on a
    # detected stop string) that EngineCore has not marked finished, so they
    # still need to be aborted in EngineCore.
    reqs_to_abort: list[str]
95
96
97
98
99
100


class RequestState:
    """Per-request bookkeeping used by OutputProcessor.

    Holds the prompt, the detokenizer and logprobs processor (generation
    requests only), optional stats, and the streaming state that decides
    when and what to emit for each engine step.
    """

    def __init__(
        self,
        request_id: str,
        external_req_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_request: LoRARequest | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
    ):
        # Internal engine request id and the id supplied by the caller.
        self.request_id = request_id
        self.external_req_id = external_req_id

        # Parallel-sampling parent (if any) and this request's output index.
        self.parent_req = parent_req
        self.request_index = request_index

        self.lora_request = lora_request
        self.lora_name = lora_request.lora_name if lora_request is not None else None
        self.output_kind = output_kind

        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )

        # from_new_request() sets both of these to None for pooling requests.
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param

        # Sampling parameters recorded for tracing.
        self.top_p = top_p
        self.n = n
        self.temperature = temperature

        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

        # Stream interval: emit at most once every `stream_interval` tokens
        # (plus the first and the final output).
        self.stream_interval = stream_interval
        # Number of output tokens already covered by an emitted output.
        self.sent_tokens_offset = 0

    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
    ) -> "RequestState":
        """Build the state for a newly added EngineCoreRequest.

        Requests with ``sampling_params`` get a detokenizer and a logprobs
        processor (created without a tokenizer when ``detokenize`` is off).
        Otherwise the request must carry ``pooling_params``, and neither is
        created.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        assert request.external_req_id is not None
        return cls(
            request_id=request.request_id,
            external_req_id=request.external_req_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_request=request.lora_request,
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
            stream_interval=stream_interval,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
        routed_experts: np.ndarray | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        """Build the output to emit for this step, or None to emit nothing.

        None is returned when the mode is FINAL_ONLY and the request has not
        finished, when stream-interval throttling skips this step, or when a
        parallel-sampling parent has no outputs to emit yet.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        # Throttling applies only to token streams. Pooling requests have no
        # detokenizer, so they must bypass this path (the abort path also
        # passes a placeholder pooling_output for them).
        if self.stream_interval > 1 and pooling_output is None:
            assert self.detokenizer is not None
            output_token_ids = self.detokenizer.output_token_ids

            # Send output request only when
            # 1. It has finished, or
            # 2. It is the first token, or
            # 3. It has reached the stream interval number of tokens
            if not (
                finished
                or self.sent_tokens_offset == 0
                or len(output_token_ids) - self.sent_tokens_offset
                >= self.stream_interval
            ):
                return None

            if self.output_kind == RequestOutputKind.DELTA:
                # Send every token since the last emitted output, not just
                # this step's tokens. Other modes send all tokens anyway.
                new_token_ids = output_token_ids[self.sent_tokens_offset :]
            # Advance the offset in every mode. If it only moved in DELTA mode,
            # it would stay 0 in CUMULATIVE mode, condition 2 would hold on
            # every step, and throttling would never take effect.
            self.sent_tokens_offset = len(output_token_ids)

        external_req_id = self.external_req_id

        if pooling_output is not None:
            return self._new_request_output(
                external_req_id,
                [self._new_pooling_output(pooling_output)],
                finished,
            )

        output = self._new_completion_output(
            new_token_ids, finish_reason, stop_reason, routed_experts
        )

        if self.parent_req is None:
            outputs = [output]
        else:
            # Parallel sampling: the parent decides which outputs to emit and
            # whether the combined request is finished.
            outputs, finished = self.parent_req.get_outputs(self.request_id, output)
            if not outputs:
                return None
            external_req_id = self.parent_req.external_req_id

        return self._new_request_output(
            external_req_id, outputs, finished, kv_transfer_params
        )

    def _new_request_output(
        self,
        external_req_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        """Wrap ``outputs`` (non-empty) in a RequestOutput or PoolingRequestOutput."""
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=external_req_id,
                outputs=first_output,
                num_cached_tokens=self.num_cached_tokens,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=external_req_id,  # request_id is what was provided externally
            lora_request=self.lora_request,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        routed_experts: np.ndarray | None = None,
    ) -> CompletionOutput:
        """Build the CompletionOutput for this request's current step.

        Text comes from ``detokenizer.get_next_output_text(finished, delta)``.
        In non-DELTA modes ``token_ids`` is replaced by all output token ids.
        In DELTA mode logprobs are trimmed to the entries matching
        ``token_ids``.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            # Keep only the trailing entries for this delta's tokens. Use an
            # explicit start index: `logprobs[-len(token_ids):]` would be
            # `logprobs[-0:]` (the whole list) when token_ids is empty, as in
            # the final abort output.
            start = max(0, len(logprobs) - len(token_ids))
            logprobs = logprobs[start:]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            routed_experts=routed_experts,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

    def _new_pooling_output(self, pooling_output: torch.Tensor) -> PoolingOutput:
        """Wrap a pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

354
355
356
357

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: TokenizerLike | None,
        log_stats: bool,
        stream_interval: int = 1,
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        self.stream_interval = stream_interval
        # Live per-request state, keyed by internal request id.
        self.request_states: dict[str, RequestState] = {}
        # Parallel-sampling parents, keyed by the parent's request id.
        self.parent_requests: dict[str, ParentRequest] = {}
        # external_req_id -> internal request ids registered under it.
        self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
        self.lora_states = LoRARequestStates(log_stats)
        self.tracer: Tracer | None = None
        # Set whenever request_states becomes empty; cleared by add_request().
        self._requests_drained = asyncio.Event()
        self._requests_drained.set()

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests still being tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if any request is still being tracked."""
        return len(self.request_states) > 0

    async def wait_for_requests_to_drain(self) -> None:
        """Wait until every tracked request has finished or been aborted."""
        if not self.request_states:
            return
        await self._requests_drained.wait()

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""
        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(self, request_ids: Iterable[str], internal: bool) -> list[str]:
        """Abort a list of requests.

        The request_ids may be either external request IDs (those passed to
        InputProcessor.process_inputs()) or internal request IDs (those randomly
        generated when creating the EngineCoreRequest).

        If an external request ID is provided, and that external request ID
        was used for multiple requests, all requests associated with that external
        request ID are aborted.

        In the case of parallel sampling, a request ID may be used to identify
        a parent request, in which case the associated child requests are aborted
        also.

        Returns:
            The internal ids of the requests that were actually aborted.
        """
        internal_req_ids: list[str] = []
        for request_id in request_ids:
            if internal:
                # Internal ID - this may be a parent request
                internal_req_ids.append(request_id)

                # Remove internal ID from the external->internal mapping
                if req_state := self.request_states.get(request_id):
                    external_req_id = req_state.external_req_id
                    internal_ids = self.external_req_ids[external_req_id]
                    internal_ids.remove(request_id)
                    if not internal_ids:
                        del self.external_req_ids[external_req_id]
            elif internal_ids := self.external_req_ids.pop(request_id, []):
                # External ID - abort all requests in the external->internal mapping
                internal_req_ids.extend(internal_ids)

        request_ids_to_abort: list[str] = []
        for request_id in internal_req_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.request_finished(request_id, req_state.lora_name)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Requests without a detokenizer (pooling) get a
                        # non-None pooling_output so make_request_output()
                        # takes its pooling branch.
                        pooling_output=EMPTY_CPU_TENSOR
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
                # Same cleanup process_outputs() performs for finished requests:
                # drop the parent once it has no remaining child requests.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs, internal=True)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        if not self.request_states:
            self._requests_drained.set()
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        """Start tracking ``request``.

        Raises:
            ValueError: if a request with the same internal id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
            stream_interval=self.stream_interval,
        )
        if self._requests_drained.is_set():
            self._requests_drained.clear()
        self.request_states[request_id] = req_state
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

        # Track the external_req_id -> [internal_req_id, ...] mapping
        self.external_req_ids[req_state.external_req_id].append(request_id)

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            routed_experts = engine_core_output.routed_experts
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
                routed_experts,
            ):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)

                internal_ids = self.external_req_ids[req_state.external_req_id]
                internal_ids.remove(req_id)
                if not internal_ids:
                    del self.external_req_ids[req_state.external_req_id]

                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not self.request_states:
                    self._requests_drained.set()
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        """Forward scheduler stats to the LoRA request-state tracker."""
        self.lora_states.update_scheduler_stats(scheduler_stats)

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Emit an ``llm_request`` server span for a finished request.

        The span starts at the request's arrival time. It carries latency
        breakdowns, token counts and the request's sampling parameters.
        Requires stats logging (``req_state.stats`` and ``iteration_stats``).
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        # Reuse the value computed once in RequestState.__init__.
        prompt_length = req_state.prompt_len
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            span.set_attribute(
                SpanAttributes.GEN_AI_REQUEST_ID, req_state.external_req_id
            )
            # Compare against None rather than using truthiness, so valid
            # falsy values such as temperature=0.0 (greedy) are still recorded.
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
                )
            if req_state.temperature is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
                )
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)

    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
    ):
        """Fold one step's output into ``iteration_stats`` (no-op if None)."""
        if iteration_stats is None:
            return

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            self.lora_states,
            req_state.lora_name,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
    ):
        """Record a finished request in the stats (no-op if stats are off)."""
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=req_state.prompt_len,
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
            num_cached_tokens=req_state.num_cached_tokens,
        )
        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )