# output_processor.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, Optional, Union, cast
8

9
10
11
12
import torch

from vllm.outputs import (CompletionOutput, PoolingOutput,
                          PoolingRequestOutput, RequestOutput)
13
from vllm.sampling_params import RequestOutputKind
14
15
from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
                          extract_trace_context)
16
17
18
19
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
20
from vllm.v1.engine.parallel_sampling import ParentRequest
21
22
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
                                   RequestStateStats)
23
24


25
26
27
28
29
30
31
32
33
34
35
class RequestOutputCollector:
    """
    Per-request hand-off point between the output-processing loop (the
    producer) and the consuming asyncio generate task.

    At most one item is pending at a time. If the producer gets ahead of
    the consumer, later RequestOutputs are merged into the pending one via
    its ``add()`` method (with ``aggregate=True`` for DELTA output kinds).
    An Exception always replaces whatever is pending and is raised to the
    consumer on its next get.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Only DELTA streams ask ``add()`` to aggregate pending outputs.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.output: Optional[Union[RequestOutput, PoolingRequestOutput,
                                    Exception]] = None
        # Set while ``self.output`` holds an item waiting to be taken.
        self.ready = asyncio.Event()

    def put(self, output: Union[RequestOutput, PoolingRequestOutput,
                                Exception]) -> None:
        """Non-blocking put operation.

        An Exception, or any item arriving while nothing is pending, becomes
        the pending item and wakes the consumer. Otherwise the new output is
        merged into a pending (Pooling)RequestOutput; if the pending item is
        an Exception, the new output is dropped.
        """
        pending = self.output
        if isinstance(output, Exception) or pending is None:
            self.output = output
            self.ready.set()
            return
        if isinstance(pending, (RequestOutput, PoolingRequestOutput)):
            # Merge instead of overwriting so that request outputs with
            # different request indexes (if n > 1) do not override each other.
            pending.add(output, aggregate=self.aggregate)

    async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
        """Wait until an item is pending, then remove and return it.

        Raises the pending item instead if it is an Exception.
        """
        while self.output is None:
            await self.ready.wait()
        return self._take_pending()

    def get_nowait(
            self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Non-blocking variant of ``get()``; returns None if nothing is
        pending."""
        return self._take_pending()

    def _take_pending(
            self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Remove the pending item (if any), raising it if it is an error."""
        item = self.output
        if item is not None:
            self.output = None
            self.ready.clear()
        if isinstance(item, Exception):
            raise item
        return item


73
74
@dataclass
class OutputProcessorOutput:
    """Result of a single ``OutputProcessor.process_outputs()`` call."""

    # Outputs handed back directly to the caller; only requests without an
    # output queue (the LLMEngine path) land here.
    request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
    # Ids of requests finished by the output processor (e.g. on a detected
    # stop string) that EngineCore has not finished and must abort.
    reqs_to_abort: list[str]
77
78
79
80
81
82
83


class RequestState:
    """Per-request bookkeeping used by the OutputProcessor.

    Holds the prompt, the incremental detokenizer and logprobs processor
    (sampling requests only), optional per-request stats and the optional
    queue outputs are delivered to, and builds the RequestOutput /
    PoolingRequestOutput objects emitted for each engine step.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: Optional[ParentRequest],
        request_index: int,
        lora_name: Optional[str],
        output_kind: RequestOutputKind,
        prompt: Optional[str],
        prompt_token_ids: list[int],
        logprobs_processor: Optional[LogprobsProcessor],
        detokenizer: Optional[IncrementalDetokenizer],
        max_tokens_param: Optional[int],
        arrival_time: float,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        temperature: Optional[float] = None,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_len = len(prompt_token_ids)
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        # Sampling parameters recorded on the request's tracing span.
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        # Cleared once the first EngineCoreOutput for the request is seen.
        self.is_prefilling = True
        self.queue = queue
        # Refreshed from every EngineCoreOutput; reported in RequestOutput.
        self.num_cached_tokens = 0
        # Only allocated when stats logging is enabled.
        self.stats = RequestStateStats(
            arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest],
        request_index: int,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ) -> "RequestState":
        """Build the state for a newly added EngineCoreRequest.

        Sampling requests get a logprobs processor and an incremental
        detokenizer (both created without a tokenizer when
        ``sampling_params.detokenize`` is False). Pooling requests get
        neither and take their output kind from ``pooling_params``.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(request.lora_request.name
                       if request.lora_request is not None else None),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: Optional[torch.Tensor],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
        kv_transfer_params: Optional[dict[str, Any]] = None,
    ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Build the output to emit for this step, if any.

        Returns None when nothing should be emitted: an unfinished request
        in FINAL_ONLY mode, or a child of a parallel-sampling parent whose
        ``get_outputs()`` reports no outputs yet. For children, the request
        id and finished flag come from the parent.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)],
                finished)

        output = self._new_completion_output(new_token_ids, finish_reason,
                                             stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output)
            if not outputs:
                return None

        return self._new_request_output(request_id, outputs, finished,
                                        kv_transfer_params)

    def _new_request_output(
        self,
        request_id: str,
        outputs: Union[list[CompletionOutput], list[PoolingOutput]],
        finished: bool,
        kv_transfer_params: Optional[dict[str, Any]] = None,
    ) -> Union[RequestOutput, PoolingRequestOutput]:
        """Wrap ``outputs`` in a PoolingRequestOutput (exactly one
        PoolingOutput) or a RequestOutput (CompletionOutputs)."""
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=self.prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step.

        In DELTA mode ``token_ids`` are this step's new ids and the sample
        logprobs are trimmed to match them; otherwise the detokenizer's full
        ``output_token_ids`` and all accumulated logprobs are used.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode. A step may carry no new
        # tokens (e.g. the final abort output), in which case no logprobs
        # belong to it; a plain ``logprobs[-0:]`` would resend all of them.
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            logprobs = self._last_n(logprobs, len(token_ids))

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None)

    @staticmethod
    def _last_n(items: list[Any], count: int) -> list[Any]:
        """Return the last ``count`` entries of ``items``.

        Unlike ``items[-count:]``, this returns an empty list when
        ``count`` is 0 (``items[-0:]`` is the whole list).
        """
        return items[-count:] if count > 0 else []

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

288
289
290
291

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Requests still being tracked, keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        # Parallel-sampling parents, keyed by parent request id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()
        # Tracing is skipped while this is None (presumably assigned by the
        # owning engine when tracing is configured).
        self.tracer: Optional[Tracer] = None

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests still being tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if any request is still being tracked."""
        return bool(self.request_states)

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""
        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Stop tracking the given requests.

        Each id may name a tracked request or a parallel-sampling parent.
        A parent's child requests are aborted recursively before the parent
        is dropped. Aborted requests that have a queue receive a final ABORT
        output (when one is produced). Unknown ids are ignored.

        Returns:
            The ids of tracked requests that were removed.
        """
        request_ids_to_abort: list[str] = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
                request_ids_to_abort.append(request_id)

                # Produce final abort output.
                if req_state.queue is not None and (
                        request_output := req_state.make_request_output(
                            [], None, FinishReason.ABORT, None, None)):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent. Iterate over
                # a copy in case aborting a child mutates child_requests.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest] = None,
        request_index: int = 0,
        queue: Optional[RequestOutputCollector] = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        req_state = RequestState.from_new_request(tokenizer=self.tokenizer,
                                                  request=request,
                                                  prompt=prompt,
                                                  parent_req=parent_req,
                                                  request_index=request_index,
                                                  queue=queue,
                                                  log_stats=self.log_stats)
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: list[Union[RequestOutput,
                                    PoolingRequestOutput]] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(req_state, engine_core_output,
                                           engine_core_timestamp,
                                           iteration_stats)

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            # Stats above used the flag's old value; from here on the
            # request is no longer considered to be prefilling.
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP)
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(
                    engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                    new_token_ids, pooling_output, finish_reason, stop_reason,
                    kv_transfer_params):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(req_state, finish_reason,
                                                 iteration_stats)
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state,
                                    iteration_stats)
        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def do_tracing(self, engine_core_output: EngineCoreOutput,
                   req_state: RequestState,
                   iteration_stats: Optional[IterationStats]) -> None:
        """Emit an ``llm_request`` span for a finished request.

        The span starts at the request's arrival time and carries latency
        breakdowns computed from ``req_state.stats`` plus the request's
        sampling parameters. Requires ``req_state.stats``,
        ``iteration_stats`` and ``self.tracer`` to be set.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as span:
            metrics = req_state.stats
            e2e_time = (iteration_stats.iteration_timestamp -
                        metrics.arrival_time)
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                               queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                               req_state.prompt_len)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                               metrics.num_generation_tokens)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
                prefill_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
                decode_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
                inference_time)

            # meta
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                               req_state.request_id)
            # Compare against None rather than truthiness: temperature=0.0
            # (greedy sampling) is a legitimate value that must be recorded.
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                                   req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                                   req_state.max_tokens_param)
            if req_state.temperature is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                                   req_state.temperature)
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                                   req_state.n)

    def _update_stats_from_output(self, req_state: RequestState,
                                  engine_core_output: EngineCoreOutput,
                                  engine_core_timestamp: Optional[float],
                                  iteration_stats: Optional[IterationStats]):
        """Fold one EngineCoreOutput into the iteration stats (no-op when
        stats are not being collected)."""
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(engine_core_output,
                                           engine_core_timestamp,
                                           req_state.is_prefilling,
                                           req_state.prompt_len,
                                           req_state.stats, lora_stats)

    def _update_stats_from_finished(self, req_state: RequestState,
                                    finish_reason: Optional[FinishReason],
                                    iteration_stats: Optional[IterationStats]):
        """Record a finished request in the iteration, LoRA and
        parallel-sampling stats (no-op when stats are not being
        collected)."""
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=req_state.prompt_len,
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats)
        self.lora_states.finish_request(req_state)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats,
            req_state.stats.num_generation_tokens)