# output_processor.py 18.7 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from collections.abc import Iterable
6
from dataclasses import dataclass
7
from typing import Any, Optional, Union, cast
8

9
10
11
12
import torch

from vllm.outputs import (CompletionOutput, PoolingOutput,
                          PoolingRequestOutput, RequestOutput)
13
14
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
15
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
16
17
18
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
19
from vllm.v1.engine.parallel_sampling import ParentRequest
20
21
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
                                   RequestStateStats)
22
23


24
25
26
27
28
29
30
31
32
33
34
class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.

    At most one pending item is held at a time in ``self.output``; the
    ``self.ready`` event is set while an item is pending and cleared
    whenever the item is taken.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Passed through to ``add()`` when a new output has to be merged
        # into one that has not been consumed yet.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        # The pending item: an output, an exception to re-raise in the
        # consumer, or None when nothing is pending.
        self.output: Optional[Union[RequestOutput, PoolingRequestOutput,
                                    Exception]] = None
        self.ready = asyncio.Event()

    def put(self, output: Union[RequestOutput, PoolingRequestOutput,
                                Exception]) -> None:
        """Non-blocking put operation.

        An exception always replaces whatever is pending. A regular
        output is stored directly when nothing is pending, and is
        otherwise merged into the pending output. If the pending item is
        an exception, a regular output is dropped.
        """
        if self.output is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
            # This ensures that request outputs with different request indexes
            # (if n > 1) do not override each other.
            self.output.add(output, aggregate=self.aggregate)

    async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
        """Get operation blocks on put event.

        Raises:
            Exception: the pending item itself, if it is an exception.
        """
        # Re-check after every wake-up: the item may already have been
        # taken (and the event cleared) by the time this task resumes.
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def get_nowait(
            self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Non-blocking get operation.

        Returns:
            The pending output, or None when nothing is pending.

        Raises:
            Exception: the pending item itself, if it is an exception.
        """
        output = self.output
        if output is not None:
            self.output = None
            self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output


72
73
74
@dataclass
class OutputProcessorOutput:
    """Result of one ``OutputProcessor.process_outputs`` call."""

    # Outputs that were not handed to a per-request queue and are
    # therefore returned to the caller directly.
    request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
    # Ids of requests that finished in the output processor while the
    # corresponding engine core output was not marked finished.
    reqs_to_abort: list[str]
77
78
79
80
81
82
83


class RequestState:
    """Per-request state needed to build outputs for a single request.

    Holds the prompt, the optional detokenizer / logprobs processor, the
    optional output queue and the optional per-request stats, and knows
    how to assemble ``RequestOutput`` / ``PoolingRequestOutput`` objects
    from them.
    """

    def __init__(
        self,
        request_id: str,
        parent_req: Optional[ParentRequest],
        request_index: int,
        lora_name: Optional[str],
        output_kind: RequestOutputKind,
        prompt: Optional[str],
        prompt_token_ids: list[int],
        logprobs_processor: Optional[LogprobsProcessor],
        detokenizer: Optional[IncrementalDetokenizer],
        max_tokens_param: Optional[int],
        arrival_time: float,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ):
        self.request_id = request_id
        self.parent_req = parent_req
        self.request_index = request_index
        self.lora_name = lora_name
        self.output_kind = output_kind
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_len = len(prompt_token_ids)
        self.logprobs_processor = logprobs_processor
        self.detokenizer = detokenizer
        self.max_tokens_param = max_tokens_param
        # Starts True; OutputProcessor.process_outputs flips it to False
        # once an engine core output for this request has been handled.
        self.is_prefilling = True
        self.queue = queue
        self.num_cached_tokens = 0

        # Stats are only tracked when logging is enabled.
        self.stats = RequestStateStats(
            arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest],
        request_index: int,
        queue: Optional[RequestOutputCollector],
        log_stats: bool,
    ) -> "RequestState":
        """Build the state for a newly added request.

        Requests with sampling params get a logprobs processor and a
        detokenizer; otherwise the request must carry pooling params and
        both helpers are left as None.
        """
        if sampling_params := request.sampling_params:
            if not sampling_params.detokenize:
                # Detokenization disabled: the helpers below are created
                # without a tokenizer.
                tokenizer = None
            output_kind = sampling_params.output_kind
            logprobs_processor = LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            detokenizer = IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            )
            max_tokens_param = sampling_params.max_tokens
        else:
            logprobs_processor = None
            detokenizer = None
            max_tokens_param = None
            assert request.pooling_params is not None
            output_kind = request.pooling_params.output_kind

        return cls(
            request_id=request.request_id,
            parent_req=parent_req,
            request_index=request_index,
            lora_name=(request.lora_request.name
                       if request.lora_request is not None else None),
            output_kind=output_kind,
            prompt=prompt,
            prompt_token_ids=request.prompt_token_ids,
            logprobs_processor=logprobs_processor,
            detokenizer=detokenizer,
            max_tokens_param=max_tokens_param,
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

    def make_request_output(
        self,
        new_token_ids: list[int],
        pooling_output: Optional[torch.Tensor],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
        kv_transfer_params: Optional[dict[str, Any]] = None,
    ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
        """Build the output for this step, or None if nothing is due.

        None is returned for an unfinished request in FINAL_ONLY mode,
        and when the parent request yields no outputs for this step.
        """
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

        if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

        request_id = self.request_id
        if pooling_output is not None:
            return self._new_request_output(
                request_id, [self._new_pooling_output(pooling_output)],
                finished)

        output = self._new_completion_output(new_token_ids, finish_reason,
                                             stop_reason)

        if self.parent_req is None:
            outputs = [output]
        else:
            # The parent decides which id, which child outputs and which
            # finished flag the combined output carries.
            request_id, outputs, finished = self.parent_req.get_outputs(
                request_id, output)
            if not outputs:
                return None

        return self._new_request_output(request_id, outputs, finished,
                                        kv_transfer_params)

    def _new_request_output(
        self,
        request_id: str,
        outputs: Union[list[CompletionOutput], list[PoolingOutput]],
        finished: bool,
        kv_transfer_params: Optional[dict[str, Any]] = None,
    ) -> Union[RequestOutput, PoolingRequestOutput]:
        """Wrap ``outputs`` (must be non-empty) in the matching request
        output type, chosen by the type of the first element."""
        first_output = outputs[0]
        if isinstance(first_output, PoolingOutput):
            assert len(outputs) == 1
            return PoolingRequestOutput(
                request_id=request_id,
                outputs=first_output,
                prompt_token_ids=self.prompt_token_ids,
                finished=finished,
            )
        assert self.logprobs_processor is not None
        if self.output_kind == RequestOutputKind.DELTA:
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = self.logprobs_processor.prompt_logprobs

        return RequestOutput(
            request_id=request_id,
            prompt=self.prompt,
            prompt_token_ids=self.prompt_token_ids,
            prompt_logprobs=prompt_logprobs,
            outputs=cast(list[CompletionOutput], outputs),
            finished=finished,
            kv_transfer_params=kv_transfer_params,
            num_cached_tokens=self.num_cached_tokens,
        )

    @staticmethod
    def _tail_logprobs(logprobs: Optional[list[Any]],
                       num_new_tokens: int) -> Optional[list[Any]]:
        """Return the logprobs entries for the last ``num_new_tokens``.

        A None or empty ``logprobs`` is returned unchanged. Zero new
        tokens yields an empty slice: ``logprobs[-0:]`` would be the
        whole list, which is not what a delta with no tokens should
        carry.
        """
        if not logprobs:
            return logprobs
        if num_new_tokens <= 0:
            return logprobs[:0]
        return logprobs[-num_new_tokens:]

    def _new_completion_output(
        self,
        token_ids: list[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Union[int, str, None],
    ) -> CompletionOutput:
        """Build the CompletionOutput for this step.

        In DELTA mode the output carries only ``token_ids`` and the
        logprobs for those tokens; otherwise it carries the
        detokenizer's full output token ids and all logprobs.
        """
        assert self.detokenizer is not None
        assert self.logprobs_processor is not None
        finished = finish_reason is not None
        delta = self.output_kind == RequestOutputKind.DELTA

        # Prepare text and token_ids, based on delta mode
        text = self.detokenizer.get_next_output_text(finished, delta)
        if not delta:
            token_ids = self.detokenizer.output_token_ids

        # Prepare logprobs, based on delta mode
        logprobs = self.logprobs_processor.logprobs
        if delta:
            logprobs = self._tail_logprobs(logprobs, len(token_ids))

        return CompletionOutput(
            index=self.request_index,
            text=text,
            token_ids=token_ids,
            logprobs=logprobs,
            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
            finish_reason=str(finish_reason) if finished else None,
            stop_reason=stop_reason if finished else None)

    def _new_pooling_output(
        self,
        pooling_output: torch.Tensor,
    ) -> PoolingOutput:
        """Wrap a pooling tensor in a PoolingOutput."""
        return PoolingOutput(data=pooling_output)

273
274
275
276
277
278

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: TokenizerGroup,
        log_stats: bool,
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Live requests, keyed by request id.
        self.request_states: dict[str, RequestState] = {}
        # Parent requests of live child requests, keyed by parent id.
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests currently being tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True while at least one request is being tracked."""
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""

        for state in self.request_states.values():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> list[str]:
        """Stop tracking the given requests.

        An id that names a tracked request is removed and, if the
        request has a queue, a final ABORT output is put on it. An id
        that names a parent request has its children aborted first and
        is then removed. Unknown ids are ignored.

        Returns:
            The ids of the tracked requests that were actually removed
            (children of aborted parents included).
        """
        request_ids_to_abort = []
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None:
                    # A request without a detokenizer cannot build a
                    # completion output (RequestState.from_new_request
                    # leaves it unset only for pooling requests), so
                    # pass an empty tensor to take the pooling branch
                    # of make_request_output instead.
                    pooling_output = (torch.empty(0, device="cpu")
                                      if req_state.detokenizer is None else
                                      None)
                    request_output = req_state.make_request_output(
                        [], pooling_output, FinishReason.ABORT, None, None)
                    if request_output:
                        req_state.queue.put(request_output)
            elif parent := self.parent_requests.get(request_id):
                # Abort children prior to removing the parent.
                if parent.child_requests:
                    # Copy first: the collection may change while the
                    # children are being aborted.
                    child_reqs = list(parent.child_requests)
                    child_reqs = self.abort_requests(child_reqs)
                    request_ids_to_abort.extend(child_reqs)
                self.parent_requests.pop(request_id, None)
        return request_ids_to_abort

    def add_request(
        self,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest] = None,
        request_index: int = 0,
        queue: Optional[RequestOutputCollector] = None,
    ) -> None:
        """Start tracking a new request.

        Raises:
            ValueError: if a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        # No tokenizer group configured -> the request state is built
        # without a tokenizer.
        tokenizer = None if not self.tokenizer else \
            self.tokenizer.get_lora_tokenizer(request.lora_request)

        req_state = RequestState.from_new_request(tokenizer=tokenizer,
                                                  request=request,
                                                  prompt=prompt,
                                                  parent_req=parent_req,
                                                  request_index=request_index,
                                                  queue=queue,
                                                  log_stats=self.log_stats)
        self.request_states[request_id] = req_state
        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        NOTE FOR DEVELOPERS

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, do it from
        within the loop below.
        """

        request_outputs: Union[list[RequestOutput],
                               list[PoolingRequestOutput]] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            # Must run before is_prefilling is cleared below, since the
            # stats update reads it.
            self._update_stats_from_output(req_state, engine_core_output,
                                           engine_core_timestamp,
                                           iteration_stats)

            new_token_ids = engine_core_output.new_token_ids
            pooling_output = engine_core_output.pooling_output
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason
            kv_transfer_params = engine_core_output.kv_transfer_params
            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
            req_state.is_prefilling = False

            if pooling_output is None:
                assert req_state.detokenizer is not None
                assert req_state.logprobs_processor is not None
                # 2) Detokenize the token ids into text and perform stop checks.
                stop_string = req_state.detokenizer.update(
                    new_token_ids, finish_reason == FinishReason.STOP)
                if stop_string:
                    finish_reason = FinishReason.STOP
                    stop_reason = stop_string

                # 3) Compute sample and prompt logprobs for request,
                # if required.
                req_state.logprobs_processor.update_from_output(
                    engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := req_state.make_request_output(
                    new_token_ids, pooling_output, finish_reason, stop_reason,
                    kv_transfer_params):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

            # Free completed requests.
            if finish_reason is not None:
                self.request_states.pop(req_id)
                # Remove parent request if applicable.
                parent_req = req_state.parent_req
                if parent_req and not parent_req.child_requests:
                    self.parent_requests.pop(parent_req.request_id, None)
                if not engine_core_output.finished:
                    # If req not finished in EngineCore, but Detokenizer
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats
                self._update_stats_from_finished(req_state, finish_reason,
                                                 iteration_stats)

        self.lora_states.update_iteration_stats(iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )

    def _update_stats_from_output(self, req_state: RequestState,
                                  engine_core_output: EngineCoreOutput,
                                  engine_core_timestamp: Optional[float],
                                  iteration_stats: Optional[IterationStats]):
        """Fold one engine core output into the iteration stats.

        Does nothing when ``iteration_stats`` is None.
        """
        if iteration_stats is None:
            return

        lora_stats = self.lora_states.get_stats(req_state)

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(engine_core_output,
                                           engine_core_timestamp,
                                           req_state.is_prefilling,
                                           req_state.prompt_len,
                                           req_state.stats, lora_stats)

    def _update_stats_from_finished(self, req_state: RequestState,
                                    finish_reason: Optional[FinishReason],
                                    iteration_stats: Optional[IterationStats]):
        """Record a finished request in the iteration, LoRA and parent
        request stats.

        Does nothing when ``iteration_stats`` is None.
        """
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(
            finish_reason=finish_reason,
            num_prompt_tokens=len(req_state.prompt_token_ids),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats)
        self.lora_states.finish_request(req_state)

        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats,
            req_state.stats.num_generation_tokens)