request.py 11.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from collections import deque
7
from collections.abc import Callable, Mapping
8
from dataclasses import dataclass
9
from typing import TYPE_CHECKING, Any
10

11
import torch
12
from typing_extensions import deprecated
13

14
from vllm.multimodal.inputs import MultiModalFeatureSpec
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams
17
from vllm.utils import length_from_prompt_token_ids_or_embeds
18
19
20
21
22
23
from vllm.v1.engine import (
    EngineCoreEvent,
    EngineCoreEventType,
    EngineCoreRequest,
    FinishReason,
)
24
from vllm.v1.structured_output.request import StructuredOutputRequest
25
from vllm.v1.utils import ConstantList
26

27
if TYPE_CHECKING:
28
    from vllm.lora.request import LoRARequest
29
    from vllm.v1.core.kv_cache_utils import BlockHash
30

31

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
@dataclass
class StreamingUpdate:
    """Input delta used to continue an existing streaming session.

    Holds only the request fields that a new chunk of streamed input
    supplies: multimodal features, prompt tokens, the token budget, the
    arrival time and the sampling parameters.
    """

    mm_features: list[MultiModalFeatureSpec] | None
    prompt_token_ids: list[int] | None
    max_tokens: int
    arrival_time: float
    sampling_params: SamplingParams | None

    @classmethod
    def from_request(cls, request: "Request") -> "StreamingUpdate | None":
        """Snapshot the streaming-relevant fields of ``request``.

        Returns ``None`` when ``request.resumable`` is false; otherwise a new
        ``StreamingUpdate`` referencing (not copying) the request's values.
        """
        if request.resumable:
            return cls(
                mm_features=request.mm_features,
                prompt_token_ids=request.prompt_token_ids,
                max_tokens=request.max_tokens,
                arrival_time=request.arrival_time,
                sampling_params=request.sampling_params,
            )
        return None


class Request:
    """Engine-core state for a single request.

    Tracks the prompt, the generated output tokens, the scheduling status and
    the bookkeeping the scheduler reads and updates: computed/cached token
    counts, speculative token ids, prefix-cache block hashes and recorded
    engine-core events.
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int] | None,
        sampling_params: SamplingParams | None,
        pooling_params: PoolingParams | None,
        client_index: int = 0,
        arrival_time: float | None = None,
        prompt_embeds: torch.Tensor | None = None,
        mm_features: list[MultiModalFeatureSpec] | None = None,
        lora_request: "LoRARequest | None" = None,
        cache_salt: str | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
        resumable: bool = False,
        reasoning_ended: bool | None = None,
    ) -> None:
        """Initialize request state.

        If ``pooling_params`` is set, the request is treated as a pooling
        request with ``max_tokens = 1``. Otherwise ``sampling_params`` must be
        set, and its ``max_tokens`` becomes the token budget.

        Raises:
            ValueError: if both ``sampling_params`` and ``pooling_params``
                are ``None``.
            AssertionError: if ``sampling_params.max_tokens`` is ``None``
                for a generative request.
        """
        self.request_id = request_id
        self.client_index = client_index
        self.priority = priority
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.structured_output_request = StructuredOutputRequest.from_sampling_params(
            sampling_params
        )
        if self.structured_output_request is not None:
            self.structured_output_request.reasoning_ended = reasoning_ended
        self.arrival_time = arrival_time if arrival_time is not None else time.time()

        self.status = RequestStatus.WAITING
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: int | str | None = None

        # P/D: Connector-specific KV transfer parameters.
        self.kv_transfer_params: dict[str, Any] | None = None

        if pooling_params is not None:
            # Pooling models.
            self.max_tokens = 1
        elif sampling_params is not None:
            # Generative models.
            assert sampling_params.max_tokens is not None
            self.max_tokens = sampling_params.max_tokens
            if self.structured_output_request is not None:
                # Structured-output requests start in a dedicated waiting
                # state instead of plain WAITING.
                self.status = RequestStatus.WAITING_FOR_FSM

            if sampling_params.extra_args is not None:
                self.kv_transfer_params = sampling_params.extra_args.get(
                    "kv_transfer_params"
                )
        else:
            raise ValueError("sampling_params and pooling_params can't both be unset")

        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            prompt_token_ids, prompt_embeds
        )
        self._output_token_ids: list[int] = []
        # When only prompt embeddings are given, the prompt part of the full
        # token list is filled with placeholder zeros of the same length.
        self._all_token_ids: list[int] = (
            self.prompt_token_ids.copy()
            if self.prompt_token_ids is not None
            else [0] * self.num_prompt_tokens
        )

        # Used in async scheduling.
        self.num_output_placeholders = 0
        # Used in forced preemption (reset_prefix_cache) with async scheduling.
        self.discard_latest_async_tokens = False

        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: str | None = cache_salt

        # Multi-modal related
        self.mm_features = mm_features or []

        # Read-only views.
        # Callers must not append to these lists directly, because the
        # output list and the full list have to be updated together
        # (see append_output_token_ids).
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

        # trace_headers
        self.trace_headers = trace_headers

        # State
        # The number of tokens with prefix cache hits (-1 presumably means
        # "not determined yet" — TODO confirm with the scheduler).
        self.num_cached_tokens = -1

        # True if this request is scheduled as a non-final prefill chunk.
        self.is_prefill_chunk = False

        # The number of NaNs in logits. A value greater than 0
        # indicates that the output is corrupted.
        self.num_nans_in_logits = 0

        # The number of times this request has been preempted by the scheduler.
        self.num_preemptions = 0

        # The number of tokens that have been computed remotely.
        self.num_external_computed_tokens = 0

        self.block_hashes: list[BlockHash] = []
        # Store the block hasher without binding self to avoid creating a
        # reference cycle (Request -> partial -> Request) that prevents
        # immediate garbage collection via reference counting.
        self._block_hasher: Callable[[Request], list[BlockHash]] | None = block_hasher
        # Hash the prompt's full blocks now. The order relative to the
        # assignments above is preserved because the hasher receives `self`.
        self.update_block_hashes()

        self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()

        # Used for streaming
        self.resumable = resumable
        # None entry in the queue means finished.
        self.streaming_queue: deque[StreamingUpdate | None] | None = None

    @property
    @deprecated(
        "Request.eos_token_id will be removed in v0.18. "
        "Please use Request.sampling_params.eos_token_id instead."
    )
    def eos_token_id(self) -> int | None:
        """Deprecated alias for ``sampling_params.eos_token_id``.

        Returns ``None`` when the request has no sampling params.
        """
        if self.sampling_params is None:
            return None

        return self.sampling_params.eos_token_id

    @classmethod
    def from_engine_core_request(
        cls,
        request: EngineCoreRequest,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None,
    ) -> "Request":
        """Build a ``Request`` from an ``EngineCoreRequest``.

        Every field is forwarded unchanged; ``block_hasher`` is attached so
        block hashes can be maintained as tokens are appended.
        """
        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
            cache_salt=request.cache_salt,
            priority=request.priority,
            trace_headers=request.trace_headers,
            block_hasher=block_hasher,
            resumable=request.resumable,
            reasoning_ended=request.reasoning_ended,
        )

    def append_output_token_ids(
        self,
        token_ids: int | list[int],
    ) -> None:
        """Append generated token(s) to the output and full token lists.

        Both lists are updated together, then block hashes are refreshed so
        any newly completed blocks get hashed.
        """
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

        self.update_block_hashes()

    def update_block_hashes(self) -> None:
        """Compute block hashes for any new full blocks and append them."""
        if self._block_hasher is not None:
            self.block_hashes.extend(self._block_hasher(self))

    @property
    def use_structured_output(self) -> bool:
        """Whether the request carries a structured-output request."""
        return self.structured_output_request is not None

    @property
    def num_tokens(self) -> int:
        """Prompt plus output token count."""
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        """``num_tokens`` plus the number of pending ``spec_token_ids``."""
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of generated output tokens."""
        return len(self._output_token_ids)

    @property
    def num_encoder_inputs(self) -> int:
        """Number of multimodal feature entries attached to the request."""
        return len(self.mm_features)

    @property
    def has_encoder_inputs(self) -> bool:
        """Whether the request has at least one multimodal feature."""
        return self.num_encoder_inputs > 0

    def get_skip_reading_prefix_cache(self) -> bool:
        """Resolve whether this request should skip reading the prefix cache.

        An explicit setting on ``sampling_params`` takes precedence over one
        on ``pooling_params``; if neither sets it, return ``False``.
        """
        if (
            self.sampling_params is not None
            and self.sampling_params.skip_reading_prefix_cache is not None
        ):
            return self.sampling_params.skip_reading_prefix_cache
        elif (
            self.pooling_params is not None
            and self.pooling_params.skip_reading_prefix_cache is not None
        ):
            return self.pooling_params.skip_reading_prefix_cache
        return False

    def is_finished(self) -> bool:
        """Whether the current status is a finished status."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> FinishReason | None:
        """Finish reason mapped from the current status, or ``None``."""
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_embeds(self, input_id: int) -> int:
        """Return the embedding count of multimodal input ``input_id``.

        ``input_id`` indexes ``self.mm_features`` and must be in range.
        """
        # Also reject negative ids: Python indexing would otherwise silently
        # return a feature counted from the end of the list.
        assert 0 <= input_id < len(self.mm_features)
        return self.mm_features[input_id].mm_position.get_num_embeds()

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: float | None = None,
    ) -> None:
        """Record an engine-core event of ``event_type`` for this request."""
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> list[EngineCoreEvent] | None:
        """Return and clear the recorded events, or ``None`` if there are none."""
        if not self.events:
            return None
        events, self.events = self.events, []
        return events

    def __lt__(self, other: "Request") -> bool:
        """
        Compare two requests based on priority, arrival time, and request ID.
        Used in priority scheduling.
        """
        if self.priority != other.priority:
            return self.priority < other.priority
        if self.arrival_time != other.arrival_time:
            return self.arrival_time < other.arrival_time
        if self.request_id != other.request_id:
            return self.request_id < other.request_id
        # Final tie-break on object identity keeps the ordering total.
        return id(self) < id(other)


class RequestStatus(enum.IntEnum):
    """Status of a request.

    Declaration order matters: ``auto()`` assigns increasing integers, and
    every member declared after ``PREEMPTED`` counts as finished in
    ``is_finished``.
    """

    # Not-yet-running states.
    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    WAITING_FOR_STREAMING_REQ = enum.auto()
    # Scheduled / interrupted states.
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()
    FINISHED_ERROR = enum.auto()

    def __str__(self) -> str:
        # Print the bare member name (e.g. "RUNNING").
        return self.name

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True when ``status`` is ordered after ``PREEMPTED``."""
        return RequestStatus.PREEMPTED < status

    @staticmethod
    def get_finished_reason(status: "RequestStatus") -> FinishReason | None:
        """Look up the finish reason for ``status``; ``None`` if unmapped."""
        return _FINISHED_REASON_MAP.get(status)


# Status -> finish reason lookup used by RequestStatus.get_finished_reason;
# statuses missing from this table resolve to None there.
# NOTE: "Ignored" requests are those whose prompt is longer than the model's
# length cap, so, as in the OpenAI API, their finish reason is "length".
# NOTE(review): WAITING_FOR_STREAMING_REQ is not a finished status according
# to RequestStatus.is_finished, yet it reports STOP here — presumably so a
# streaming chunk can end with a STOP reason while the session stays open;
# confirm against the streaming scheduler.
_FINISHED_REASON_MAP: dict[RequestStatus, FinishReason] = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
    RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP,
}