# "vscode:/vscode.git/clone" did not exist on "acf8aeb79e23c32217dd37b5e96847302ae4d0b7"
# output_processor.py 22.3 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, Optional, Union, cast
8

9
10
11
12
import torch

from vllm.outputs import (CompletionOutput, PoolingOutput,
                          PoolingRequestOutput, RequestOutput)
13
from vllm.sampling_params import RequestOutputKind
14
15
from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
                          extract_trace_context)
16
from vllm.transformers_utils.tokenizer import AnyTokenizer
17
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
18
19
20
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
21
from vllm.v1.engine.parallel_sampling import ParentRequest
22
23
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
                                   RequestStateStats)
24
25


26
27
28
29
30
31
32
33
34
35
36
class RequestOutputCollector:
    """
    Single-slot hand-off between the output processor and one request's
    asyncio generate task.

    The producer calls ``put``; the consumer calls ``get`` (awaiting) or
    ``get_nowait``. At most one item is pending at a time. If the producer
    gets ahead of the consumer, later outputs are merged into the pending
    one with ``add`` (aggregating deltas when the request streams in DELTA
    mode). An exception always replaces whatever is pending.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Pending outputs are aggregated on merge only in DELTA mode.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        # The undelivered item, or None when the slot is empty.
        self.output: Optional[Union[RequestOutput, PoolingRequestOutput,
                                    Exception]] = None
        # Set while ``self.output`` holds something for the consumer.
        self.ready = asyncio.Event()

    def put(self, output: Union[RequestOutput, PoolingRequestOutput,
                                Exception]) -> None:
        """Deposit ``output`` without blocking.

        If the slot is empty, or ``output`` is an exception, ``output``
        becomes the pending item and the consumer is woken. Otherwise it is
        merged into a pending RequestOutput / PoolingRequestOutput. If the
        pending item is an exception, the new output is discarded.
        """
        pending = self.output
        if pending is not None and not isinstance(output, Exception):
            if isinstance(pending, (RequestOutput, PoolingRequestOutput)):
                # Merge rather than overwrite, so outputs for different
                # request indexes (n > 1) do not clobber each other.
                pending.add(output, aggregate=self.aggregate)
            return
        self.output = output
        self.ready.set()

    async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
        """Wait until an item is pending, take it, and return it.

        Raises:
            Exception: re-raises the pending item if it is an exception.
        """
        while self.output is None:
            await self.ready.wait()
        result, self.output = self.output, None
        self.ready.clear()
        if isinstance(result, Exception):
            raise result
        return result

    def get_nowait(
            self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Take the pending item without waiting.

        Returns None when nothing is pending.

        Raises:
            Exception: re-raises the pending item if it is an exception.
        """
        result = self.output
        if result is None:
            return None
        self.output = None
        self.ready.clear()
        if isinstance(result, Exception):
            raise result
        return result


74
75
@dataclass
class OutputProcessorOutput:
    """Result of a single ``OutputProcessor.process_outputs`` call."""

    # Outputs returned directly to the caller. These are only collected for
    # requests that have no per-request output queue.
    request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
    # Ids of requests that finished during output processing but are not yet
    # finished in the engine core (e.g. a stop string was detected by the
    # detokenizer), and so still need to be aborted there.
    reqs_to_abort: list[str]
78
79
80
81
82
83
84


class RequestState:
    """Output-side bookkeeping for one in-flight request.

    Holds everything needed to turn the request's EngineCoreOutputs into
    RequestOutput / PoolingRequestOutput objects:
    - the prompt;
    - the incremental detokenizer and logprobs processor (sampling requests
      only);
    - the optional parent request used for n > 1 fan-out;
    - the optional per-request output queue;
    - a few sampling parameters recorded for request tracing.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: Optional[ParentRequest],
        request_index: int,
        lora_name: Optional[str],
        output_kind: RequestOutputKind,
        prompt: Optional[str],
        prompt_token_ids: list[int],
        logprobs_processor: Optional[LogprobsProcessor],
        detokenizer: Optional[IncrementalDetokenizer],
        max_tokens_param: Optional[int],
        arrival_time: float,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        temperature: Optional[float] = None,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        # Position of this request among its parent's children (0 if none);
        # used as the CompletionOutput index.
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_len = len(prompt_token_ids)
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        # Sampling parameters recorded for tracing (OutputProcessor.do_tracing).
        self.top_p = top_p
        self.n = n
        self.temperature = temperature
        # True until the first EngineCoreOutput for this request is processed.
        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        self.stats = RequestStateStats(
            arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest],
        request_index: int,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ) -> "RequestState":
        """Create the state for a newly added request.

        Sampling requests (``request.sampling_params`` set) get a logprobs
        processor and an incremental detokenizer. Neither receives the
        tokenizer when ``sampling_params.detokenize`` is False.

        Any other request must carry ``pooling_params``. It gets no
        detokenizer or logprobs processor, and takes its output kind from
        the pooling params.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
            top_p = sampling_params.top_p
            n = sampling_params.n
            temperature = sampling_params.temperature
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            top_p = None
            n = None
            temperature = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(request.lora_request.name
                       if request.lora_request is not None else None),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            top_p=top_p,
            n=n,
            temperature=temperature,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: Optional[torch.Tensor],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
        kv_transfer_params: Optional[dict[str, Any]] = None,
    ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Build the output to emit for this step.

        Args:
            new_token_ids: Token ids produced in this step.
            pooling_output: Pooling result, or None for sampling requests.
            finish_reason: Set when the request finished in this step.
            stop_reason: Stop token/string, only reported once finished.
            kv_transfer_params: Forwarded unchanged into the RequestOutput.

        Returns:
            The output to emit, or None when nothing should be emitted:
            - the request is unfinished and ``output_kind`` is FINAL_ONLY; or
            - the parent request (n > 1) reports no outputs yet.
            Pooling outputs are returned directly and bypass parent
            aggregation.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)],
                finished)

        output = self._new_completion_output(new_token_ids, finish_reason,
                                             stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            # The parent may rewrite request_id/finished and may withhold
            # outputs entirely.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output)
            if not outputs:
                return None

        return self._new_request_output(request_id, outputs, finished,
                                        kv_transfer_params)

    def _new_request_output(
        self,
        request_id: str,
        outputs: Union[list[CompletionOutput], list[PoolingOutput]],
        finished: bool,
        kv_transfer_params: Optional[dict[str, Any]] = None,
    ) -> Union[RequestOutput, PoolingRequestOutput]:
        """Wrap ``outputs`` in a request-level output object.

        A single PoolingOutput yields a PoolingRequestOutput. Completion
        outputs yield a RequestOutput carrying prompt logprobs. In DELTA mode
        the prompt logprobs are popped from the logprobs processor, so the
        processor no longer holds them afterwards.
        """
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=self.prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
        )

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step.

        In DELTA mode, ``token_ids`` are this step's new tokens, and text and
        logprobs are limited to them. Otherwise the detokenizer's full
        ``output_token_ids`` and the cumulative logprobs are used.
        ``finish_reason`` and ``stop_reason`` are only reported once
        finished.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode: only the entries for this
        # step's new tokens. The empty case needs an explicit guard, because
        # ``logprobs[-0:]`` is the WHOLE list. Without it, a zero-token delta
        # (e.g. the ABORT output built with no new token ids) would resend
        # every accumulated logprob.
        logprobs = self.logprobs_processor.logprobs
        if delta and logprobs:
            logprobs = logprobs[-len(token_ids):] if token_ids else []

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None)

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap the raw pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

289
290
291
292

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(self, tokenizer: TokenizerGroup, log_stats: bool):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Live requests (not yet finished or aborted), keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        # Parents of fanned-out (n > 1) requests, keyed by parent request id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()
        # Tracing is skipped while this is None. It is presumably assigned by
        # the owning engine when tracing is enabled -- confirm at call sites.
        self.tracer: Optional[Tracer] = None

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests still being tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if any request is still being tracked."""
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks.

        Every tracked request is expected to have an output queue
        (AsyncLLM usage); this is asserted.
        """
        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Stop tracking the given requests.

        Each id may name a tracked request or a parent (n > 1) request. A
        parent id aborts all of its current children, then drops the parent.
        A tracked request with a queue receives a final ABORT output, if
        ``make_request_output`` produces one. Unknown ids are ignored.

        Returns:
            Ids of the requests whose state was actually removed (children
            in the case of a parent id).
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
                        request_output := req_state.make_request_output(
                            [], None, FinishReason.ABORT, None, None)):
                    req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest] = None,
        request_index: int = 0,
        queue: Optional[RequestOutputCollector] = None,
    ) -> None:
        """Start tracking a new request.

        Outputs for the request go to ``queue`` when one is given; otherwise
        they are returned from ``process_outputs``.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        # Resolve the (possibly LoRA-specific) tokenizer, if one is set.
        tokenizer = None if not self.tokenizer else \
            self.tokenizer.get_lora_tokenizer(request.lora_request)

        req_state = RequestState.from_new_request(tokenizer=tokenizer,
                                                  request=request,
                                                  prompt=prompt,
                                                  parent_req=parent_req,
                                                  request_index=request_index,
                                                  queue=queue,
                                                  log_stats=self.log_stats)
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: Union[list[RequestOutput],
                               list[PoolingRequestOutput]] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration. This runs before
            # is_prefilling is cleared below, so the first output is
            # still seen as a prefill step.
            self._update_stats_from_output(req_state, engine_core_output,
                                           engine_core_timestamp,
                                           iteration_stats)

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP)
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(
                    engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                    new_token_ids, pooling_output, finish_reason, stop_reason,
                    kv_transfer_params):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(req_state, finish_reason,
                                                 iteration_stats)
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state,
                                    iteration_stats)
        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def do_tracing(self, engine_core_output: EngineCoreOutput,
                   req_state: RequestState,
                   iteration_stats: Optional[IterationStats]) -> None:
        """Emit an ``llm_request`` server span for a finished request.

        Requires stats logging (``req_state.stats`` and ``iteration_stats``)
        and a tracer; both are asserted. The span starts at the request's
        arrival time, in nanoseconds. It carries latency, token-usage and
        request-parameter attributes.
        """
        assert req_state.stats is not None
        assert iteration_stats is not None
        assert self.tracer is not None

        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
        trace_context = extract_trace_context(engine_core_output.trace_headers)
        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as span:
            metrics = req_state.stats
            e2e_time = iteration_stats.iteration_timestamp - \
                       metrics.arrival_time
            queued_time = metrics.scheduled_ts - metrics.queued_ts
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
            decode_time = metrics.last_token_ts - metrics.first_token_ts
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics.first_token_latency)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                               queued_time)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                               len(req_state.prompt_token_ids))
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                               metrics.num_generation_tokens)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
                prefill_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
                decode_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
                inference_time)

            # meta
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                               req_state.request_id)
            # Compare against None rather than relying on truthiness: a
            # falsy value such as temperature=0.0 is a valid setting and
            # must still be recorded. None means "not a sampling request".
            if req_state.top_p is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                                   req_state.top_p)
            if req_state.max_tokens_param is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                                   req_state.max_tokens_param)
            if req_state.temperature is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                                   req_state.temperature)
            if req_state.n is not None:
                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                                   req_state.n)

    def _update_stats_from_output(self, req_state: RequestState,
                                  engine_core_output: EngineCoreOutput,
                                  engine_core_timestamp: Optional[float],
                                  iteration_stats: Optional[IterationStats]):
        """Fold one EngineCoreOutput into the iteration stats.

        No-op when ``iteration_stats`` is None (stats logging disabled).
        """
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(engine_core_output,
                                           engine_core_timestamp,
                                           req_state.is_prefilling,
                                           req_state.prompt_len,
                                           req_state.stats, lora_stats)

    def _update_stats_from_finished(self, req_state: RequestState,
                                    finish_reason: Optional[FinishReason],
                                    iteration_stats: Optional[IterationStats]):
        """Record a finished request in the iteration, LoRA and parent stats.

        No-op when ``iteration_stats`` is None (stats logging disabled).
        """
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=len(req_state.prompt_token_ids),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats)
        self.lora_states.finish_request(req_state)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats,
            req_state.stats.num_generation_tokens)