"vscode:/vscode.git/clone" did not exist on "4572a06afe96d0a6d5d3efacf130c71505dd2bc9"
request.py 8.59 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
7
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
8

9
from vllm.multimodal.inputs import MultiModalFeatureSpec
10
from vllm.pooling_params import PoolingParams
11
from vllm.sampling_params import SamplingParams
12
13
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                            EngineCoreRequest, FinishReason)
14
from vllm.v1.structured_output.request import StructuredOutputRequest
15
from vllm.v1.utils import ConstantList
16

17
if TYPE_CHECKING:
18
    from vllm.lora.request import LoRARequest
19
    from vllm.v1.core.kv_cache_utils import BlockHash
20

21
22
23
24
25
26

class Request:
    """Engine-side state of a single generation or pooling request.

    Holds the prompt and generated token ids, scheduling status, multi-modal
    inputs, prefix-cache block hashes and the engine-core events recorded for
    the request.
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int],
        sampling_params: Optional[SamplingParams],
        pooling_params: Optional[PoolingParams],
        eos_token_id: Optional[int],
        client_index: int = 0,
        arrival_time: Optional[float] = None,
        mm_features: Optional[list[MultiModalFeatureSpec]] = None,
        lora_request: Optional["LoRARequest"] = None,
        structured_output_request: Optional["StructuredOutputRequest"] = None,
        cache_salt: Optional[str] = None,
        priority: int = 0,
        block_hasher: Optional[Callable[["Request"],
                                        list["BlockHash"]]] = None,
    ) -> None:
        """Create a request.

        Args:
            request_id: Identifier of the request.
            prompt_token_ids: Token ids of the prompt.
            sampling_params: Sampling parameters (generative models).
            pooling_params: Pooling parameters (pooling models). Checked
                first, so it takes precedence when both are given.
            eos_token_id: End-of-sequence token id for this request.
            client_index: Client index, stored as-is.
            arrival_time: Arrival timestamp; ``time.time()`` when omitted.
            mm_features: Multi-modal feature specs; ``None`` means none.
            lora_request: Optional LoRA request.
            structured_output_request: Structured-output state, if any.
            cache_salt: Optional salt stored on the request (presumably
                consumed by prefix-cache hashing -- confirm with caller).
            priority: Priority value, stored as-is.
            block_hasher: Optional callable returning block hashes for this
                request. It is invoked once here and again after every
                output-token append.

        Raises:
            ValueError: If both ``sampling_params`` and ``pooling_params``
                are ``None``.
        """
        self.request_id = request_id
        self.client_index = client_index
        self.priority = priority
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.structured_output_request = structured_output_request
        self.arrival_time = (arrival_time
                             if arrival_time is not None else time.time())

        self.status = RequestStatus.WAITING
        self.use_structured_output = False
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: Union[int, str, None] = None

        # P/D: Connector-specific KV transfer parameters.
        self.kv_transfer_params: Optional[dict[str, Any]] = None

        if pooling_params is not None:
            # Pooling models: a single step is all that is produced.
            self.max_tokens = 1
        elif sampling_params is not None:
            # Generative models.
            assert sampling_params.max_tokens is not None
            self.max_tokens = sampling_params.max_tokens
            if sampling_params.guided_decoding is not None:
                # Guided decoding starts in WAITING_FOR_FSM instead of
                # WAITING and is flagged as using structured output.
                self.status = RequestStatus.WAITING_FOR_FSM
                self.use_structured_output = True

            if sampling_params.extra_args is not None:
                self.kv_transfer_params = sampling_params.extra_args.get(
                    "kv_transfer_params")
        else:
            raise ValueError(
                "sampling_params and pooling_params can't both be unset")

        self.prompt_token_ids = prompt_token_ids
        self.num_prompt_tokens = len(self.prompt_token_ids)
        self._output_token_ids: list[int] = []
        # Copy so that appending outputs never mutates the caller's prompt.
        self._all_token_ids: list[int] = self.prompt_token_ids.copy()
        self.num_output_placeholders = 0  # Used in async scheduling.
        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: Optional[str] = cache_salt

        # Multi-modal related
        self.mm_features = mm_features or []
        self.num_encoder_inputs = len(self.mm_features)
        self.has_encoder_inputs = self.num_encoder_inputs > 0
        # TODO(sfeng33): Remove these legacy fields after clearing out all
        # references in scheduler and model runner
        self.mm_positions = [f.mm_position for f in self.mm_features]
        self.mm_kwargs = [f.data for f in self.mm_features]
        self.mm_hashes = [f.identifier for f in self.mm_features]

        # Read-only views.
        # Prevent directly appending to these lists since they must be
        # updated together (see append_output_token_ids).
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

        # State
        # The number of tokens with prefix cache hits (-1 presumably means
        # "not computed yet" -- confirm with the scheduler).
        self.num_cached_tokens = -1

        # The number of NaNs in logits. A value greater than 0
        # indicates that the output is corrupted.
        self.num_nans_in_logits = 0

        # BlockHash is only imported under TYPE_CHECKING, so keep these
        # annotations as strings.
        self.block_hashes: list["BlockHash"] = []
        self.get_hash_new_full_blocks: Optional[Callable[
            [], list["BlockHash"]]] = None
        if block_hasher is not None:
            self.get_hash_new_full_blocks = partial(block_hasher, self)
            self.block_hashes = self.get_hash_new_full_blocks()

    @classmethod
    def from_engine_core_request(
        cls, request: EngineCoreRequest,
        block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
    ) -> "Request":
        """Build a Request from an EngineCoreRequest.

        A StructuredOutputRequest is created only when the incoming request
        has truthy ``sampling_params``; otherwise it is ``None``.
        """
        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            eos_token_id=request.eos_token_id,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
            structured_output_request=(
                StructuredOutputRequest(
                    sampling_params=request.sampling_params)
                if request.sampling_params else None),
            cache_salt=request.cache_salt,
            priority=request.priority,
            block_hasher=block_hasher,
        )

    def append_output_token_ids(
        self,
        token_ids: Union[int, list[int]],
    ) -> None:
        """Append one or more generated tokens.

        Both the output-only list and the prompt+output list are updated, so
        the read-only views stay consistent. When a block hasher is set, the
        hashes it returns are appended to ``block_hashes``.
        """
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

        if self.get_hash_new_full_blocks is not None:
            self.block_hashes.extend(self.get_hash_new_full_blocks())

    @property
    def is_output_corrupted(self) -> bool:
        """True if any NaN was observed in the logits."""
        return self.num_nans_in_logits > 0

    @property
    def num_tokens(self) -> int:
        """Number of prompt plus output tokens."""
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        """Number of prompt plus output tokens, plus speculative tokens."""
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of generated tokens."""
        return len(self._output_token_ids)

    def is_finished(self) -> bool:
        """Return True if the request is in a FINISHED_* status."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> Union[FinishReason, None]:
        """Return the finish reason, or None if the request is unfinished."""
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_tokens(self, input_id: int) -> int:
        """Return the placeholder length of multi-modal input ``input_id``."""
        assert input_id < len(self.mm_positions)
        return self.mm_positions[input_id].length

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: Optional[float] = None,
    ) -> None:
        """Record an engine-core event of ``event_type`` on this request."""
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> Optional[list[EngineCoreEvent]]:
        """Return and clear the recorded events; None if there are none."""
        if not self.events:
            return None
        events, self.events = self.events, []
        return events

196
197

class RequestStatus(enum.IntEnum):
    """Status of a request.

    Member order is significant: ``is_finished`` treats every member declared
    after ``PREEMPTED`` as finished. New unfinished states must be declared
    before ``PREEMPTED`` and new finished states after it.
    """

    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    def __str__(self):
        return self.name

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True if ``status`` is declared after ``PREEMPTED``."""
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(
            status: "RequestStatus") -> Union[FinishReason, None]:
        """Return the FinishReason mapped to ``status``.

        Returns None when ``status`` has no entry in the module-level
        finished-reason map.
        """
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP: dict[RequestStatus, FinishReason] = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
}