output_processor.py 27.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections import defaultdict
6
from collections.abc import Iterable
7
from dataclasses import dataclass
8
from typing import Any, cast
9

10
11
import torch

12
from vllm.lora.request import LoRARequest
13
14
15
16
17
18
from vllm.outputs import (
    CompletionOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
19
from vllm.sampling_params import RequestOutputKind
20
from vllm.tokenizers import TokenizerLike
21
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
22
from vllm.utils import length_from_prompt_token_ids_or_embeds
23
24
25
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
26
from vllm.v1.engine.parallel_sampling import ParentRequest
27
28
29
30
31
32
from vllm.v1.metrics.stats import (
    IterationStats,
    LoRARequestStates,
    RequestStateStats,
    SchedulerStats,
)
33
34


35
36
37
38
39
40
41
42
43
class RequestOutputCollector:
    """
    Single-slot hand-off between the output-processing loop and the asyncio
    generate task that consumes one request's outputs.

    At most one item is pending at a time. If the producer puts again before
    the consumer has taken the pending item:
      * two RequestOutputs are merged via ``RequestOutput.add`` (aggregating
        when the output kind is DELTA, so no streamed tokens are lost);
      * a newer PoolingRequestOutput replaces the older one;
      * an Exception always overwrites whatever is pending;
      * any other combination leaves the pending item unchanged.
    """

    def __init__(self, output_kind: RequestOutputKind, request_id: str):
        # Deltas must be concatenated rather than replaced when the
        # consumer lags behind the producer.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.request_id = request_id
        # The pending item (None when nothing is waiting to be consumed).
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        # Set while an item is pending; cleared when the item is taken.
        self.ready = asyncio.Event()

    def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
        """Non-blocking put operation."""
        pending = self.output
        if pending is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
            return

        merge_completions = isinstance(pending, RequestOutput) and isinstance(
            output, RequestOutput
        )
        if merge_completions:
            # Merging (instead of replacing) keeps outputs that carry
            # different request indexes (n > 1) from overriding each other.
            pending.add(output, aggregate=self.aggregate)
        elif isinstance(pending, PoolingRequestOutput) and isinstance(
            output, PoolingRequestOutput
        ):
            self.output = output

    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Wait until an item is pending, take it, and return it.

        Raises the pending item instead of returning it if it is an Exception.
        """
        while self.output is None:
            await self.ready.wait()
        taken, self.output = self.output, None
        self.ready.clear()
        if isinstance(taken, Exception):
            raise taken
        return taken

    def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None:
        """Take the pending item without waiting; return None if there is none.

        Raises the pending item instead of returning it if it is an Exception.
        """
        taken = self.output
        if taken is None:
            return None
        self.output = None
        self.ready.clear()
        if isinstance(taken, Exception):
            raise taken
        return taken


87
88
@dataclass
class OutputProcessorOutput:
    """Result of a single ``OutputProcessor.process_outputs`` call."""

    # Outputs handed back directly to the caller; only requests without a
    # RequestOutputCollector queue land here (queued requests are delivered
    # through their queue instead).
    request_outputs: list[RequestOutput | PoolingRequestOutput]
    # Request ids that the output processor finished (e.g. on a detected stop
    # string) while the engine core had not yet reported them as finished.
    reqs_to_abort: list[str]
91
92
93
94
95
96


class RequestState:
    """Per-request bookkeeping used by ``OutputProcessor``.

    Holds the prompt, the detokenizer and logprobs state (generation requests
    only), stream-interval state and stats for one internal request, and
    builds the ``RequestOutput`` / ``PoolingRequestOutput`` objects emitted
    for it.
    """

    def __init__(
        self,
        request_id: str,
        external_req_id: str,
        parent_req: ParentRequest | None,
        request_index: int,
        lora_request: LoRARequest | None,
        output_kind: RequestOutputKind,
        prompt: str | None,
        prompt_token_ids: list[int] | None,
        prompt_embeds: torch.Tensor | None,
        logprobs_processor: LogprobsProcessor | None,
        detokenizer: IncrementalDetokenizer | None,
        max_tokens_param: int | None,
        arrival_time: float,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
        top_p: float | None = None,
        n: int | None = None,
        temperature: float | None = None,
    ):
        self.request_id = request_id
        self.external_req_id = external_req_id

        self.parent_req = parent_req
        self.request_index = request_index

        self.lora_request = lora_request
        self.lora_name = lora_request.lora_name if lora_request is not None else None
        self.output_kind = output_kind

        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids

        self.prompt_embeds = prompt_embeds
        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds
        )
        # Both are None for pooling requests (no sampling params).
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param

        # Sampling parameters kept only for tracing attributes.
        self.top_p = top_p
        self.n = n
        self.temperature = temperature

        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None

        # Stream interval: emit outputs every `stream_interval` tokens.
        self.stream_interval = stream_interval
        # Number of output tokens already covered by an emitted output.
        self.sent_tokens_offset = 0

    @classmethod
    def from_new_request(
        cls,
        tokenizer: TokenizerLike | None,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        request_index: int,
        queue: RequestOutputCollector | None,
        log_stats: bool,
        stream_interval: int,
    ) -> "RequestState":
        """Build the state for a newly added EngineCoreRequest.

        Generation requests (with sampling params) get a logprobs processor
        and an incremental detokenizer; pooling requests get neither.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        assert request.external_req_id is not None

        return cls(
            request_id=request.request_id,
            external_req_id=request.external_req_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_request=request.lora_request,
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
            stream_interval=stream_interval,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: torch.Tensor | None,
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput | None:
        """Build the output to emit for this step, or None to emit nothing.

        Returns None in FINAL_ONLY mode before the request finishes, when the
        stream interval has not yet been reached, or when a parallel-sampling
        parent has no aggregated output ready yet.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        # Stream-interval throttling only applies to token outputs. Pooling
        # requests have no detokenizer and produce a single output, so
        # they must bypass this (previously they tripped the assert).
        if self.stream_interval > 1 and pooling_output is None:
            assert self.detokenizer is not None
            num_output_tokens = len(self.detokenizer.output_token_ids)

            # Send output request only when
            # 1. It has finished, or
            # 2. It is the first token, or
            # 3. It has reached the stream interval number of tokens
            if not (
                finished
                or self.sent_tokens_offset == 0
                or num_output_tokens - self.sent_tokens_offset >= self.stream_interval
            ):
                return None

            if self.output_kind == RequestOutputKind.DELTA:
                # Send tokens from the offset in DELTA mode, otherwise all
                # tokens are sent.
                new_token_ids = self.detokenizer.output_token_ids[
                    self.sent_tokens_offset :
                ]
            # Advance the offset in every mode. Updating it only for DELTA
            # left it at 0 in CUMULATIVE mode, so condition 2 above was always
            # true and the stream interval never throttled anything.
            self.sent_tokens_offset = num_output_tokens

        external_req_id = self.external_req_id

        if pooling_output is not None:
            return self._new_request_output(
                external_req_id,
                [self._new_pooling_output(pooling_output)],
                finished,
            )

        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            outputs, finished = self.parent_req.get_outputs(self.request_id, output)
            if not outputs:
                return None
            external_req_id = self.parent_req.external_req_id

        return self._new_request_output(
            external_req_id, outputs, finished, kv_transfer_params
        )

    def _new_request_output(
        self,
        external_req_id: str,
        outputs: list[CompletionOutput] | list[PoolingOutput],
        finished: bool,
        kv_transfer_params: dict[str, Any] | None = None,
    ) -> RequestOutput | PoolingRequestOutput:
        """Wrap per-sample outputs in a RequestOutput/PoolingRequestOutput."""
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            # Prompt embeddings are currently not supported by pooling requests.
            assert self.prompt_token_ids is not None
            return PoolingRequestOutput(
                request_id=external_req_id,
                outputs=first_output,
                num_cached_tokens=self.num_cached_tokens,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        # If prompt embeds were used, put placeholder prompt token ids
        prompt_token_ids = self.prompt_token_ids
        if prompt_token_ids is None and self.prompt_embeds is not None:
            prompt_token_ids = [0] * len(self.prompt_embeds)

        return RequestOutput(
            request_id=external_req_id,  # request_id is what was provided externally
            lora_request=self.lora_request,
            prompt=self.prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
            metrics=self.stats,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: FinishReason | None,
        stop_reason: int | str | None,
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step (delta or cumulative)."""
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            # Keep only the logprobs belonging to this delta's tokens.
            # `logprobs[-0:]` would return the WHOLE list, so an empty delta
            # (e.g. the final abort output) must produce an empty slice.
            logprobs = logprobs[-len(token_ids) :] if token_ids else logprobs[:0]

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None,
        )

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

348
349
350
351

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: TokenizerLike | None,
        log_stats: bool,
        stream_interval: int = 1,
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        self.stream_interval = stream_interval
        # Internal request id -> per-request state.
        self.request_states: dict[str, RequestState] = {}
        # Parent request id -> ParentRequest (parallel sampling).
        self.parent_requests: dict[str, ParentRequest] = {}
        # External request id -> internal request ids created for it.
        self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
        self.lora_states = LoRARequestStates(log_stats)
        self.tracer: Tracer | None = None
        # Set whenever there are no tracked requests; awaited by
        # wait_for_requests_to_drain().
        self._requests_drained = asyncio.Event()
        self._requests_drained.set()

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests still being tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if any request is still being tracked."""
        return len(self.request_states) > 0

    async def wait_for_requests_to_drain(self) -> None:
        """Wait until every tracked request has finished or been aborted."""
        if not self.request_states:
            return
        await self._requests_drained.wait()

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""

        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(self, request_ids: Iterable[str], internal: bool) -> list[str]:
        """Abort a list of requests.

        The request_ids may be either external request IDs (those passed to
        InputProcessor.process_inputs()) or internal request IDs (those randomly
        generated when creating the EngineCoreRequest).

        If an external request ID is provided, and that external request ID
        was used for multiple requests, all requests associated with that external
        request ID are aborted.

        In the case of parallel sampling, a request ID may be used to identify
        a parent request, in which case the associated child requests are aborted
        also.

        Returns the internal request ids whose state was removed by this call
        (including the children of any aborted parent request).
        """

        internal_req_ids = []
        for request_id in request_ids:
            if internal:
                # Internal ID - this may be a parent request
                internal_req_ids.append(request_id)

                # Remove internal ID from the external->internal mapping
                if req_state := self.request_states.get(request_id):
                    external_req_id = req_state.external_req_id
                    internal_ids = self.external_req_ids[external_req_id]
                    internal_ids.remove(request_id)
                    if not internal_ids:
                        del self.external_req_ids[external_req_id]
            elif internal_ids := self.external_req_ids.pop(request_id, []):
                # External ID - abort all requests in the external->internal mapping
                internal_req_ids.extend(internal_ids)

        request_ids_to_abort = []
        for request_id in internal_req_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.request_finished(request_id, req_state.lora_name)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                    request_output := req_state.make_request_output(
                        new_token_ids=[],
                        # Pooling requests (no detokenizer) need a non-None
                        # pooling_output so make_request_output takes the
                        # pooling branch; an empty tensor is enough.
                        pooling_output=torch.empty(0, device="cpu")
                        if req_state.detokenizer is None
                        else None,
                        finish_reason=FinishReason.ABORT,
                        stop_reason=None,
                        kv_transfer_params=None,
                    )
                ):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs, internal=True)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)

        if not self.request_states:
            self._requests_drained.set()
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None = None,
        request_index: int = 0,
        queue: RequestOutputCollector | None = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: if the internal request id is already being tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(
            tokenizer=self.tokenizer,
            request=request,
            prompt=prompt,
            parent_req=parent_req,
            request_index=request_index,
            queue=queue,
            log_stats=self.log_stats,
            stream_interval=self.stream_interval,
        )

        if self._requests_drained.is_set():
            self._requests_drained.clear()
        self.request_states[request_id] = req_state
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

        # Track the external_req_id -> [internal_req_id, ...] mapping
        self.external_req_ids[req_state.external_req_id].append(request_id)

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(
                req_state, engine_core_output, engine_core_timestamp, iteration_stats
            )

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP
                )
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                new_token_ids,
                pooling_output,
                finish_reason,
                stop_reason,
                kv_transfer_params,
            ):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)

                internal_ids = self.external_req_ids[req_state.external_req_id]
                internal_ids.remove(req_id)
                if not internal_ids:
                    del self.external_req_ids[req_state.external_req_id]

                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not self.request_states:
                    self._requests_drained.set()
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(
                    req_state, finish_reason, iteration_stats
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
        """Forward scheduler stats to the LoRA request-state tracker."""
        self.lora_states.update_scheduler_stats(scheduler_stats)

    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
        req_state: RequestState,
        iteration_stats: IterationStats | None,
    ) -> None:
        """Emit an "llm_request" span with latency, usage and request attributes.

        Requires stats to be enabled (req_state.stats and iteration_stats).
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        # Same value as length_from_prompt_token_ids_or_embeds(...) computed
        # once in RequestState.__init__.
        prompt_length = req_state.prompt_len
        with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=trace_context,
            start_time=arrival_time_nano_seconds,
        ) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency,
            )
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
            span.set_attribute(
                SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                metrics.num_generation_tokens,
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
            )
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
            )

            # meta
            span.set_attribute(
                SpanAttributes.GEN_AI_REQUEST_ID, req_state.external_req_id
            )
            # Compare against None rather than truthiness: temperature=0.0
            # (greedy decoding) is a valid, meaningful value to record.
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
                )
            if req_state.temperature is not None:
                span.set_attribute(
                    SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
                )
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)

    def _update_stats_from_output(
        self,
        req_state: RequestState,
        engine_core_output: EngineCoreOutput,
        engine_core_timestamp: float | None,
        iteration_stats: IterationStats | None,
    ):
        """Record per-step stats for one output; no-op when stats are off."""
        if iteration_stats is None:
            return

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
            engine_core_output,
            engine_core_timestamp,
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
            self.lora_states,
            req_state.lora_name,
        )

    def _update_stats_from_finished(
        self,
        req_state: RequestState,
        finish_reason: FinishReason | None,
        iteration_stats: IterationStats | None,
    ):
        """Record stats for a finished request; no-op when stats are off."""
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=req_state.prompt_len,
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
            num_cached_tokens=req_state.num_cached_tokens,
        )
        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
        )