# SPDX-License-Identifier: Apache-2.0

import asyncio
from dataclasses import dataclass
from typing import Dict, List, Optional

from vllm.outputs import RequestOutput
from vllm.transformers_utils.detokenizer_utils import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.engine.detokenizer import (DetokenizerOutput,
                                        IncrementalDetokenizer)
from vllm.v1.metrics.stats import IterationStats, RequestStateStats


@dataclass
class OutputProcessorOutput:
    """Result of a single OutputProcessor.process_outputs() call."""

    # RequestOutputs for requests that were added without a queue (the
    # LLMEngine path). Requests that have a queue receive their outputs
    # through that queue instead, so they do not appear here.
    request_outputs: List[RequestOutput]
    # Ids of requests that the detokenizer marked finished (e.g. on a stop
    # string) while EngineCore had not yet finished them; the caller must
    # abort these in EngineCore.
    reqs_to_abort: List[str]
    # Stats accumulated while processing this batch of EngineCoreOutputs.
    iteration_stats: IterationStats


class RequestState:
    """Per-request bookkeeping held by OutputProcessor while a request is live.

    Stores the prompt and its token ids, the request's incremental
    detokenizer, whether the request is still in its first (prefill) step,
    per-request stats, and the optional queue its outputs are delivered to.
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: List[int],
        detokenizer: IncrementalDetokenizer,
        arrival_time: float,
        queue: Optional[asyncio.Queue[RequestOutput]],
    ):
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        # Cached so stats code does not recompute it on every step.
        self.prompt_len = len(prompt_token_ids)
        self.detokenizer = detokenizer
        # Set to False by OutputProcessor after the first EngineCoreOutput
        # for this request has been processed.
        self.is_prefilling = True
        # None: outputs are returned from process_outputs() (LLMEngine).
        # Otherwise: outputs are put into this queue (AsyncLLM).
        self.queue = queue

        # Per-request stats start with last_token_time at the arrival time.
        self.stats = RequestStateStats(last_token_time=arrival_time)

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        queue: Optional[asyncio.Queue[RequestOutput]] = None,
    ) -> "RequestState":
        """Build the state for a freshly submitted request.

        Args:
            tokenizer: Tokenizer used to build the request's
                IncrementalDetokenizer.
            request: The request as sent to EngineCore.
            queue: Optional output queue (AsyncLLM); None for LLMEngine.

        Returns:
            A new RequestState for ``request``.
        """
        return cls(
            request_id=request.request_id,
            prompt=request.prompt,
            prompt_token_ids=request.prompt_token_ids,
            detokenizer=IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            arrival_time=request.arrival_time,
            queue=queue,
        )


class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: BaseTokenizerGroup,
        log_stats: bool,
    ):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        # Live requests keyed by request id. Entries are removed when a
        # request finishes (process_outputs) or is aborted (abort_requests).
        self.request_states: Dict[str, RequestState] = {}

    def is_request_active(self, request_id: str) -> bool:
        """Return True if ``request_id`` is still being tracked."""
        return request_id in self.request_states

    def get_num_unfinished_requests(self) -> int:
        """Return how many requests are still being tracked."""
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        """Return True if at least one request is still being tracked."""
        return len(self.request_states) > 0

    def abort_requests(
        self,
        request_ids: List[str],
    ) -> None:
        """Stop tracking the given requests.

        Ids that are not tracked are ignored, so aborting a request that
        already finished or was already aborted is a no-op.
        """
        for request_id in request_ids:
            self.request_states.pop(request_id, None)

    def add_request(
        self,
        request: EngineCoreRequest,
        queue: Optional[asyncio.Queue[RequestOutput]] = None,
    ) -> None:
        """Start tracking a new request.

        Args:
            request: The request as sent to EngineCore.
            queue: If given (AsyncLLM), outputs for this request are put
                into the queue instead of being returned from
                process_outputs().

        Raises:
            ValueError: If a request with the same id is already tracked.
        """
        request_id = request.request_id
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        self.request_states[request_id] = RequestState.from_new_request(
            tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
            request=request,
            queue=queue)

    def process_outputs(
        self,
        engine_core_outputs: List[EngineCoreOutput],
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects:
            * If there is a queue (for usage with AsyncLLM),
              put the RequestOutput objects into the queue for
              handling by the per-request generate() tasks.

            * If there is no queue (for usage with LLMEngine),
              return a list of RequestOutput objects.

        Args:
            engine_core_outputs: Outputs from one EngineCore step.
            iteration_stats: Stats object to accumulate into. If None, a
                new IterationStats is created.

        Returns:
            OutputProcessorOutput with the queue-less RequestOutputs, the
            ids of requests that must be aborted in EngineCore, and the
            iteration stats.

        ****************** NOTE FOR DEVELOPERS ******************

        vLLM V1 minimizes the number of python loops over the full
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, implement a
        method called XXXClass.update_from_output() to be called
        within the loop below. For examples, see:
            * IterationStats.update_from_output()
            * IncrementalDetokenizer.update_from_output()

        TODO(rob): add a Protocol that makes update_from_output explicit.

        **********************************************************
        """

        request_outputs: List[RequestOutput] = []
        reqs_to_abort: List[str] = []
        # Compare against None explicitly so a caller-supplied stats object
        # is never replaced based on its truthiness.
        if iteration_stats is None:
            iteration_stats = IterationStats(self.log_stats)

        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            iteration_stats.update_from_output(engine_core_output,
                                               req_state.is_prefilling,
                                               req_state.prompt_len,
                                               req_state.stats)
            # Only the first output of a request counts as prefill.
            req_state.is_prefilling = False

            # 2) Detokenize the token ids into text.
            detokenizer_output = req_state.detokenizer.update_from_output(
                engine_core_output)

            # 3) Create and handle RequestOutput objects.
            if detokenizer_output is not None:
                request_output = self._make_request_output(
                    req_state, detokenizer_output)

                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put_nowait(request_output)
                else:
                    # LLMEngine: return list of RequestOutputs.
                    request_outputs.append(request_output)

                # Free completed requests.
                if request_output.finished:
                    assert detokenizer_output.finish_reason is not None

                    self.request_states.pop(req_id)
                    if not engine_core_output.finished:
                        # If req not finished in EngineCore, but Detokenizer
                        # detected stop string, abort needed in EngineCore.
                        reqs_to_abort.append(req_id)

                    # Track per-request stats
                    iteration_stats.update_from_finished_request(
                        detokenizer_output.finish_reason, request_output,
                        req_state.stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
            iteration_stats=iteration_stats,
        )

    @staticmethod
    def _make_request_output(
        request_state: RequestState,
        detokenizer_output: DetokenizerOutput,
    ) -> RequestOutput:
        """Build a RequestOutput from a request's state and new detok output.

        When the detokenizer output is finished, the finish and stop
        reasons are also copied onto the first completion output.
        """
        request_output = RequestOutput.new(
            request_state.request_id,
            request_state.prompt,
            request_state.prompt_token_ids,
            detokenizer_output.output_text,
            detokenizer_output.token_ids,
            detokenizer_output.finished,
        )
        if detokenizer_output.finished:
            completion_output = request_output.outputs[0]
            completion_output.finish_reason = str(
                detokenizer_output.finish_reason)
            completion_output.stop_reason = detokenizer_output.stop_reason

        return request_output