"vllm/vscode:/vscode.git/clone" did not exist on "234a65b781d9dc51d28aebb208096baa8fe0458e"
request.py 9.12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
7
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
8

9
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
10
from vllm.pooling_params import PoolingParams
11
from vllm.sampling_params import SamplingParams
12
from vllm.utils import is_list_of
13
14
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                            EngineCoreRequest, FinishReason)
15
from vllm.v1.structured_output.request import StructuredOutputRequest
16
from vllm.v1.utils import ConstantList
17

18
if TYPE_CHECKING:
19
    from vllm.lora.request import LoRARequest
20
    from vllm.v1.core.kv_cache_utils import BlockHash
21

22
23
24
25
26
27

class Request:
    """Mutable state tracked for a single request.

    Holds the request's sampling/pooling parameters, its prompt and output
    token ids, multi-modal inputs, current ``RequestStatus``, recorded
    ``EngineCoreEvent``s and, when a ``block_hasher`` is supplied, the block
    hashes it returns.
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int],
        multi_modal_kwargs: Optional[list[MultiModalKwargsItem]],
        multi_modal_hashes: Optional[list[str]],
        multi_modal_placeholders: Optional[list[PlaceholderRange]],
        sampling_params: Optional[SamplingParams],
        pooling_params: Optional[PoolingParams],
        eos_token_id: Optional[int],
        client_index: int = 0,
        arrival_time: Optional[float] = None,
        lora_request: Optional["LoRARequest"] = None,
        structured_output_request: Optional["StructuredOutputRequest"] = None,
        cache_salt: Optional[str] = None,
        priority: int = 0,
        block_hasher: Optional[Callable[["Request"],
                                        list["BlockHash"]]] = None,
    ) -> None:
        """Initialize the request state.

        Args:
            request_id: Identifier of the request.
            prompt_token_ids: Prompt token ids; a copy seeds
                ``all_token_ids``.
            multi_modal_kwargs: Multi-modal input items, or None for none.
            multi_modal_hashes: Hashes of the multi-modal items; when
                non-empty it must have the same length as the items.
            multi_modal_placeholders: Placeholder ranges; must have the same
                length as the multi-modal items.
            sampling_params: Parameters for generative requests.
            pooling_params: Parameters for pooling requests. Checked first,
                so it takes precedence when both kinds of params are set.
            eos_token_id: EOS token id; can differ per request because of
                LoRA.
            client_index: Client index stored on the request.
            arrival_time: Arrival timestamp; defaults to ``time.time()``.
            lora_request: LoRA request, if any.
            structured_output_request: Structured-output state, if any.
            cache_salt: Optional salt stored on the request.
            priority: Priority value stored on the request.
            block_hasher: Optional callable invoked with this request that
                returns a list of block hashes. It is called once at the end
                of construction and again after every output-token append.

        Raises:
            ValueError: If both ``sampling_params`` and ``pooling_params``
                are None.
        """
        self.request_id = request_id
        self.client_index = client_index
        self.priority = priority
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.structured_output_request = structured_output_request
        self.arrival_time = (arrival_time
                             if arrival_time is not None else time.time())

        self.status = RequestStatus.WAITING
        self.use_structured_output = False
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: Union[int, str, None] = None

        # P/D: Connector-specific KV transfer parameters.
        self.kv_transfer_params: Optional[dict[str, Any]] = None

        if pooling_params is not None:
            # Pooling models: cap the request at a single token.
            self.max_tokens = 1
        elif sampling_params is not None:
            # Generative models.
            assert sampling_params.max_tokens is not None
            self.max_tokens = sampling_params.max_tokens
            if sampling_params.guided_decoding is not None:
                # Guided decoding starts in WAITING_FOR_FSM instead of
                # WAITING.
                self.status = RequestStatus.WAITING_FOR_FSM
                self.use_structured_output = True

            if sampling_params.extra_args is not None:
                self.kv_transfer_params = sampling_params.extra_args.get(
                    "kv_transfer_params")
        else:
            raise ValueError(
                "sampling_params and pooling_params can't both be unset")

        self.prompt_token_ids = prompt_token_ids
        self.num_prompt_tokens = len(self.prompt_token_ids)
        self._output_token_ids: list[int] = []
        # Prompt followed by generated tokens; starts as a copy of the prompt.
        self._all_token_ids: list[int] = self.prompt_token_ids.copy()
        self.num_output_placeholders = 0  # Used in async scheduling.
        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: Optional[str] = cache_salt

        # Multi-modal related
        self.mm_positions = multi_modal_placeholders or []
        self.mm_kwargs = multi_modal_kwargs or []
        self.mm_hashes: list[str] = multi_modal_hashes or []
        self.num_encoder_inputs = len(self.mm_kwargs)
        self.has_encoder_inputs = self.num_encoder_inputs > 0

        # Sanity check
        assert len(self.mm_kwargs) == len(self.mm_positions)
        if self.mm_hashes:
            assert len(self.mm_kwargs) == len(self.mm_hashes)

        # Read-only views.
        # Prevent directly appending to these lists since they should also
        # be updated simultaneously (append_output_token_ids does both).
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

        # State
        # The number of tokens with prefix cache hits.
        self.num_cached_tokens = -1

        # The number of NaNs in logits. A value greater than 0
        # indicates that the output is corrupted
        self.num_nans_in_logits = 0

        # Block hashes produced by block_hasher. This must stay last:
        # block_hasher receives the fully initialized request.
        self.block_hashes: list["BlockHash"] = []
        self.get_hash_new_full_blocks: Optional[Callable[
            [], list["BlockHash"]]] = None
        if block_hasher is not None:
            self.get_hash_new_full_blocks = partial(block_hasher, self)
            self.block_hashes = self.get_hash_new_full_blocks()

    @classmethod
    def from_engine_core_request(
        cls, request: EngineCoreRequest,
        block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
    ) -> "Request":
        """Build a ``Request`` from an ``EngineCoreRequest``.

        A ``StructuredOutputRequest`` is attached whenever the request has
        sampling params; otherwise it is None.

        Raises:
            AssertionError: If ``is_list_of`` rejects ``request.mm_kwargs``
                as a list of ``MultiModalKwargsItem``.
        """
        if request.mm_kwargs is not None:
            mm_kwargs_lst = list(request.mm_kwargs)
            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), (
                "mm_kwargs was not updated in EngineCore.add_request")
        else:
            mm_kwargs_lst = None

        if request.sampling_params is not None:
            structured_output_request = StructuredOutputRequest(
                sampling_params=request.sampling_params)
        else:
            structured_output_request = None

        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            multi_modal_kwargs=mm_kwargs_lst,
            multi_modal_hashes=request.mm_hashes,
            multi_modal_placeholders=request.mm_placeholders,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            eos_token_id=request.eos_token_id,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
            structured_output_request=structured_output_request,
            cache_salt=request.cache_salt,
            priority=request.priority,
            block_hasher=block_hasher,
        )

    def append_output_token_ids(
        self,
        token_ids: Union[int, list[int]],
    ) -> None:
        """Append one or more generated tokens.

        Both the output list and the all-token list are extended together.
        If a block hasher was supplied, the hashes returned by
        ``get_hash_new_full_blocks()`` are appended to ``block_hashes``.
        """
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

        if self.get_hash_new_full_blocks is not None:
            self.block_hashes.extend(self.get_hash_new_full_blocks())

    @property
    def is_output_corrupted(self) -> bool:
        """True if any NaNs were recorded in the logits."""
        return self.num_nans_in_logits > 0

    @property
    def num_tokens(self) -> int:
        """Number of prompt plus output tokens."""
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        """Number of prompt, output and speculative tokens."""
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of output tokens appended so far."""
        return len(self._output_token_ids)

    def is_finished(self) -> bool:
        """Return whether the current status is a finished status."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> Union[FinishReason, None]:
        """Return the finish reason for the current status, or None."""
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_tokens(self, input_id: int) -> int:
        """Return the placeholder length of multi-modal input ``input_id``."""
        assert input_id < len(self.mm_positions)
        num_tokens = self.mm_positions[input_id].length
        return num_tokens

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: Optional[float] = None,
    ) -> None:
        """Append a new event; ``timestamp`` goes to ``new_event``."""
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> Optional[list[EngineCoreEvent]]:
        """Return and clear the recorded events, or None if there are none."""
        if not self.events:
            return None
        events, self.events = self.events, []
        return events


class RequestStatus(enum.IntEnum):
    """Status of a request.

    Members get increasing values from ``enum.auto()`` in declaration order.
    ``is_finished`` relies on that order: every member declared after
    ``PREEMPTED`` counts as finished. New non-finished statuses must
    therefore be declared before ``PREEMPTED``, and new finished ones after
    it.
    """

    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    def __str__(self):
        return self.name

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True if ``status`` is declared after ``PREEMPTED``."""
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(
            status: "RequestStatus") -> Union[FinishReason, None]:
        """Return the ``FinishReason`` mapped to ``status``, or None."""
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP: dict[RequestStatus, FinishReason] = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
}