request.py 11.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from collections import deque
7
from collections.abc import Callable, Mapping
8
from dataclasses import dataclass
9
from typing import TYPE_CHECKING, Any
10

11
12
import torch

13
from vllm.multimodal.inputs import MultiModalFeatureSpec
14
from vllm.pooling_params import PoolingParams
15
from vllm.sampling_params import SamplingParams
16
from vllm.utils import length_from_prompt_token_ids_or_embeds
17
18
19
20
21
22
from vllm.v1.engine import (
    EngineCoreEvent,
    EngineCoreEventType,
    EngineCoreRequest,
    FinishReason,
)
23
from vllm.v1.structured_output.request import StructuredOutputRequest
24
from vllm.v1.utils import ConstantList
25

26
if TYPE_CHECKING:
27
    from vllm.lora.request import LoRARequest
28
    from vllm.v1.core.kv_cache_utils import BlockHash
29

30

31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
@dataclass
class StreamingUpdate:
    """Payload describing new input for an existing streaming session.

    Holds just the subset of request fields that ``from_request`` copies
    over when a session is continued.
    """

    mm_features: list[MultiModalFeatureSpec] | None
    prompt_token_ids: list[int] | None
    max_tokens: int
    arrival_time: float
    sampling_params: SamplingParams | None

    @classmethod
    def from_request(cls, request: "Request") -> "StreamingUpdate | None":
        """Snapshot the streaming-relevant fields of ``request``.

        Returns:
            A new ``StreamingUpdate`` when ``request.resumable`` is truthy,
            otherwise ``None``.
        """
        if request.resumable:
            # Only touch the remaining attributes once we know the request
            # is resumable.
            snapshot = dict(
                mm_features=request.mm_features,
                prompt_token_ids=request.prompt_token_ids,
                max_tokens=request.max_tokens,
                arrival_time=request.arrival_time,
                sampling_params=request.sampling_params,
            )
            return cls(**snapshot)
        return None


58
59
60
61
class Request:
    """Mutable per-request state.

    Holds the request's parameters, prompt and generated token ids, current
    ``RequestStatus``, recorded events and assorted bookkeeping counters.
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int] | None,
        sampling_params: SamplingParams | None,
        pooling_params: PoolingParams | None,
        client_index: int = 0,
        arrival_time: float | None = None,
        prompt_embeds: torch.Tensor | None = None,
        mm_features: list[MultiModalFeatureSpec] | None = None,
        lora_request: "LoRARequest | None" = None,
        cache_salt: str | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
        resumable: bool = False,
        reasoning_ended: bool | None = None,
    ) -> None:
        """Initialize the request state.

        Args:
            request_id: Identifier of the request.
            prompt_token_ids: Prompt token ids, or None (in which case the
                all-token list is seeded with zeros of the prompt length).
            sampling_params: Parameters for generative requests.
            pooling_params: Parameters for pooling requests. Checked first:
                when it is not None, ``max_tokens`` is set to 1.
            client_index: Stored as-is on the instance.
            arrival_time: Arrival timestamp; defaults to ``time.time()``.
            prompt_embeds: Optional prompt embeddings.
            mm_features: Multi-modal features; None becomes an empty list.
            lora_request: Optional LoRA request.
            cache_salt: Optional cache salt.
            priority: Priority used by ``__lt__`` (lower sorts first).
            trace_headers: Optional trace headers.
            block_hasher: Optional callable invoked with this request that
                returns block hashes to append to ``block_hashes``.
            resumable: Whether the request is resumable (streaming).
            reasoning_ended: Copied onto the structured output request when
                one exists.

        Raises:
            ValueError: If both ``sampling_params`` and ``pooling_params``
                are None.
        """
        self.request_id = request_id
        self.client_index = client_index
        self.priority = priority
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.structured_output_request = StructuredOutputRequest.from_sampling_params(
            sampling_params
        )
        if self.structured_output_request is not None:
            self.structured_output_request.reasoning_ended = reasoning_ended
        self.arrival_time = arrival_time if arrival_time is not None else time.time()

        self.status = RequestStatus.WAITING
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: int | str | None = None

        # P/D: Connector-specific KV transfer parameters.
        self.kv_transfer_params: dict[str, Any] | None = None

        if pooling_params is not None:
            # Pooling models.
            self.max_tokens = 1
        elif sampling_params is not None:
            # Generative models.
            assert sampling_params.max_tokens is not None
            self.max_tokens = sampling_params.max_tokens
            if self.structured_output_request is not None:
                self.status = RequestStatus.WAITING_FOR_FSM

            if sampling_params.extra_args is not None:
                self.kv_transfer_params = sampling_params.extra_args.get(
                    "kv_transfer_params"
                )
        else:
            raise ValueError("sampling_params and pooling_params can't both be unset")

        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        # Cache per-block prompt-embed hashes to avoid rehashing the same
        # tensor slices when generating extra keys.
        self._prompt_embeds_per_block_hashes: dict[tuple[int, int], bytes] = {}
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            prompt_token_ids, prompt_embeds
        )
        self._output_token_ids: list[int] = []
        # Copy the prompt ids so later appends never mutate the caller's list;
        # without token ids, use zero placeholders of the prompt length.
        self._all_token_ids: list[int] = (
            self.prompt_token_ids.copy()
            if self.prompt_token_ids is not None
            else [0] * self.num_prompt_tokens
        )

        # Used in async scheduling.
        self.num_output_placeholders = 0
        # Used in forced preemption (reset_prefix_cache) with async scheduling.
        self.discard_latest_async_tokens = False

        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: str | None = cache_salt

        # Multi-modal related
        self.mm_features = mm_features or []

        # Read-only views
        # Prevent directly appending to these lists since
        # they should also be updated simultaneously.
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)
        # trace_headers
        self.trace_headers = trace_headers
        # State
        # The number of tokens with prefix cache hits.
        self.num_cached_tokens = -1

        # True if this request is scheduled as a non-final prefill chunk.
        self.is_prefill_chunk = False

        # The number of NaNs in logits. A value greater than 0
        # indicates that the output is corrupted
        self.num_nans_in_logits = 0

        # The number of times this request has been preempted by the scheduler.
        self.num_preemptions = 0

        # The number of tokens that have been computed remotely.
        self.num_external_computed_tokens = 0

        self.block_hashes: list[BlockHash] = []
        # Store the block hasher without binding self to avoid creating a
        # reference cycle (Request -> partial -> Request) that prevents
        # immediate garbage collection via reference counting.
        self._block_hasher: Callable[[Request], list[BlockHash]] | None = block_hasher
        # The hasher receives `self`, so it must run only after the token
        # lists and the other attributes above have been assigned.
        self.update_block_hashes()

        self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()

        # Used for streaming
        self.resumable = resumable
        # None entry in the queue means finished.
        self.streaming_queue: deque[StreamingUpdate | None] | None = None

    @classmethod
    def from_engine_core_request(
        cls,
        request: EngineCoreRequest,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None,
    ) -> "Request":
        """Build a ``Request`` from the fields of an ``EngineCoreRequest``.

        Args:
            request: Source of every constructor argument except the hasher.
            block_hasher: Passed through to the constructor unchanged.
        """
        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
            cache_salt=request.cache_salt,
            priority=request.priority,
            trace_headers=request.trace_headers,
            block_hasher=block_hasher,
            resumable=request.resumable,
            reasoning_ended=request.reasoning_ended,
        )

    def append_output_token_ids(
        self,
        token_ids: int | list[int],
    ) -> None:
        """Append one token id or a list of them, then refresh block hashes.

        The ids are added to both the output-token list and the all-token
        list so the two stay in sync.
        """
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

        self.update_block_hashes()

    def update_block_hashes(self) -> None:
        """Compute block hashes for any new full blocks and append them."""
        # No-op when no hasher was supplied.
        if self._block_hasher is not None:
            self.block_hashes.extend(self._block_hasher(self))

    @property
    def use_structured_output(self) -> bool:
        """True when a structured output request is attached."""
        return self.structured_output_request is not None

    @property
    def num_tokens(self) -> int:
        """Length of the all-token list (prompt plus appended output)."""
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        """``num_tokens`` plus the number of entries in ``spec_token_ids``."""
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of output token ids appended so far."""
        return len(self._output_token_ids)

    @property
    def num_encoder_inputs(self) -> int:
        """Number of entries in ``mm_features``."""
        return len(self.mm_features)

    @property
    def has_encoder_inputs(self) -> bool:
        """True when there is at least one entry in ``mm_features``."""
        return self.num_encoder_inputs > 0

    def get_skip_reading_prefix_cache(self) -> bool:
        """Resolve the ``skip_reading_prefix_cache`` flag.

        Uses the value from ``sampling_params`` when it is set (not None),
        otherwise the value from ``pooling_params`` when that is set, and
        falls back to False.
        """
        if (
            self.sampling_params is not None
            and self.sampling_params.skip_reading_prefix_cache is not None
        ):
            return self.sampling_params.skip_reading_prefix_cache
        elif (
            self.pooling_params is not None
            and self.pooling_params.skip_reading_prefix_cache is not None
        ):
            return self.pooling_params.skip_reading_prefix_cache
        return False

    def is_finished(self) -> bool:
        """Return ``RequestStatus.is_finished`` for the current status."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> FinishReason | None:
        """Return ``RequestStatus.get_finished_reason`` for the current status."""
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_embeds(self, input_id: int) -> int:
        """Return the embed count of the ``input_id``-th multi-modal feature."""
        assert input_id < len(self.mm_features)
        return self.mm_features[input_id].mm_position.get_num_embeds()

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: float | None = None,
    ) -> None:
        """Append a new ``EngineCoreEvent`` of ``event_type`` to ``events``."""
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> list[EngineCoreEvent] | None:
        """Return the recorded events and reset the list.

        Returns:
            The accumulated events, or None when there are none.
        """
        if not self.events:
            return None
        # Hand the existing list to the caller and start a fresh one.
        events, self.events = self.events, []
        return events

    def __lt__(self, other: "Request") -> bool:
        """
        Compare two requests based on priority, arrival time, and request ID.
        Used in priority scheduling.

        Falls back to comparing ``id()`` of the two objects when all three
        fields are equal.
        """
        if self.priority != other.priority:
            return self.priority < other.priority
        if self.arrival_time != other.arrival_time:
            return self.arrival_time < other.arrival_time
        if self.request_id != other.request_id:
            return self.request_id < other.request_id
        return id(self) < id(other)

294
295

class RequestStatus(enum.IntEnum):
    """Status of a request.

    Member order is significant: ``is_finished`` treats every member
    declared after ``PREEMPTED`` as a finished status, so new non-finished
    statuses must be added before ``PREEMPTED`` and new finished statuses
    after it.
    """

    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    WAITING_FOR_STREAMING_REQ = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()
    FINISHED_ERROR = enum.auto()
    FINISHED_REPETITION = enum.auto()

    def __str__(self) -> str:
        """Return the member name (e.g. ``"RUNNING"``)."""
        return self.name

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True when ``status`` compares greater than ``PREEMPTED``."""
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(status: "RequestStatus") -> "FinishReason | None":
        """Look up ``status`` in ``_FINISHED_REASON_MAP``.

        Returns:
            The mapped finish reason, or None when ``status`` has no entry.
        """
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
# NOTE(review): WAITING_FOR_STREAMING_REQ is declared before PREEMPTED, so
# RequestStatus.is_finished() is False for it, yet it maps to STOP here.
# Presumably this lets a streaming session that is waiting for its next
# input report a stop reason; confirm against the callers before changing.
_FINISHED_REASON_MAP = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
    RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP,
    RequestStatus.FINISHED_REPETITION: FinishReason.REPETITION,
}