# request.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
import time
from collections import deque
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch
from typing_extensions import deprecated

from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import (
    EngineCoreEvent,
    EngineCoreEventType,
    EngineCoreRequest,
    FinishReason,
)
from vllm.v1.structured_output.request import StructuredOutputRequest
from vllm.v1.utils import ConstantList

if TYPE_CHECKING:
    from vllm.lora.request import LoRARequest
    from vllm.v1.core.kv_cache_utils import BlockHash
30

31

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
@dataclass
class StreamingUpdate:
    """Input payload used to continue an existing streaming session.

    Carries just the subset of a ``Request``'s fields that a resumable
    session needs in order to take in a new chunk of input.
    """

    mm_features: list[MultiModalFeatureSpec] | None
    prompt_token_ids: list[int] | None
    max_tokens: int
    arrival_time: float
    sampling_params: SamplingParams | None

    @classmethod
    def from_request(cls, request: "Request") -> "StreamingUpdate | None":
        """Snapshot the continuation fields of ``request``.

        Returns ``None`` when ``request.resumable`` is false. Otherwise the new
        instance references (does not copy) the request's field values.
        """
        if request.resumable:
            return cls(
                mm_features=request.mm_features,
                prompt_token_ids=request.prompt_token_ids,
                max_tokens=request.max_tokens,
                arrival_time=request.arrival_time,
                sampling_params=request.sampling_params,
            )
        return None


59
60
61
62
class Request:
    """Scheduler-side state for a single generation or pooling request.

    Holds the prompt and generated token ids, the scheduling status, emitted
    engine-core events, prefix-cache block hashes, multi-modal features and,
    for resumable requests, the queue of pending streaming updates.
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int] | None,
        sampling_params: SamplingParams | None,
        pooling_params: PoolingParams | None,
        client_index: int = 0,
        arrival_time: float | None = None,
        prompt_embeds: torch.Tensor | None = None,
        mm_features: list[MultiModalFeatureSpec] | None = None,
        lora_request: "LoRARequest | None" = None,
        cache_salt: str | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
        resumable: bool = False,
        reasoning_ended: bool | None = None,
    ) -> None:
        """Initialize a request.

        Args:
            request_id: Request identifier; also the third ordering key
                used by ``__lt__``.
            prompt_token_ids: Prompt token ids, or ``None`` when the prompt is
                supplied only through ``prompt_embeds``.
            sampling_params: Parameters for generative requests.
            pooling_params: Parameters for pooling requests; checked before
                ``sampling_params`` when deciding ``max_tokens``.
            client_index: Stored as ``self.client_index``.
            arrival_time: Arrival timestamp; defaults to ``time.time()``.
            prompt_embeds: Optional prompt embeddings tensor.
            mm_features: Multi-modal feature specs; ``None`` becomes ``[]``.
            lora_request: Optional LoRA request, stored as-is.
            cache_salt: Optional cache salt, stored as ``self.cache_salt``.
            priority: Scheduling priority; lower values sort first in
                ``__lt__``.
            trace_headers: Optional tracing headers, stored as-is.
            block_hasher: Callable invoked with this request; the hashes it
                returns are appended to ``self.block_hashes``.
            resumable: Whether the request can be continued with streaming
                updates.
            reasoning_ended: Initial ``reasoning_ended`` value forwarded to the
                structured-output request, if one is created.

        Raises:
            ValueError: If both ``sampling_params`` and ``pooling_params`` are
                ``None``.
        """
        self.request_id = request_id
        self.client_index = client_index
        self.priority = priority
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.structured_output_request = StructuredOutputRequest.from_sampling_params(
            sampling_params
        )
        if self.structured_output_request is not None:
            self.structured_output_request.reasoning_ended = reasoning_ended
        self.arrival_time = arrival_time if arrival_time is not None else time.time()

        self.status = RequestStatus.WAITING
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: int | str | None = None

        # P/D: Connector-specific KV transfer parameters.
        self.kv_transfer_params: dict[str, Any] | None = None

        if pooling_params is not None:
            # Pooling models.
            self.max_tokens = 1
        elif sampling_params is not None:
            # Generative models.
            assert sampling_params.max_tokens is not None
            self.max_tokens = sampling_params.max_tokens
            if self.structured_output_request is not None:
                self.status = RequestStatus.WAITING_FOR_FSM

            if sampling_params.extra_args is not None:
                self.kv_transfer_params = sampling_params.extra_args.get(
                    "kv_transfer_params"
                )
        else:
            raise ValueError("sampling_params and pooling_params can't both be unset")

        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        # Cache per-block prompt-embed hashes to avoid rehashing the same
        # tensor slices when generating extra keys.
        self._prompt_embeds_per_block_hashes: dict[tuple[int, int], bytes] = {}
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            prompt_token_ids, prompt_embeds
        )
        self._output_token_ids: list[int] = []
        # When only embeddings were given, placeholder zeros keep the token
        # list the same length as the prompt.
        self._all_token_ids: list[int] = (
            self.prompt_token_ids.copy()
            if self.prompt_token_ids is not None
            else [0] * self.num_prompt_tokens
        )

        # Used in async scheduling.
        self.num_output_placeholders = 0
        # Used in forced preemption (reset_prefix_cache) with async scheduling.
        self.discard_latest_async_tokens = False

        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: str | None = cache_salt

        # Multi-modal related
        self.mm_features = mm_features or []

        # Read-only views
        # Prevent directly appending to these lists since
        # they should also be updated simultaneously.
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

        # trace_headers
        self.trace_headers = trace_headers

        # State
        # The number of tokens with prefix cache hits.
        self.num_cached_tokens = -1

        # True if this request is scheduled as a non-final prefill chunk.
        self.is_prefill_chunk = False

        # The number of NaNs in logits. A value greater than 0
        # indicates that the output is corrupted
        self.num_nans_in_logits = 0

        # The number of times this request has been preempted by the scheduler.
        self.num_preemptions = 0

        # The number of tokens that have been computed remotely.
        self.num_external_computed_tokens = 0

        self.block_hashes: list[BlockHash] = []
        # Store the block hasher without binding self to avoid creating a
        # reference cycle (Request -> partial -> Request) that prevents
        # immediate garbage collection via reference counting.
        self._block_hasher: Callable[[Request], list[BlockHash]] | None = block_hasher
        self.update_block_hashes()

        self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()

        # Used for streaming
        self.resumable = resumable
        # None entry in the queue means finished.
        self.streaming_queue: deque[StreamingUpdate | None] | None = None

    @property
    @deprecated(
        "Request.eos_token_id will be removed in v0.18. "
        "Please use Request.sampling_params.eos_token_id instead."
    )
    def eos_token_id(self) -> int | None:
        """EOS token id from ``sampling_params``, or None if it is unset."""
        if self.sampling_params is None:
            return None

        return self.sampling_params.eos_token_id

    @classmethod
    def from_engine_core_request(
        cls,
        request: EngineCoreRequest,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None,
    ) -> "Request":
        """Build a ``Request`` by copying the fields of an ``EngineCoreRequest``.

        ``block_hasher`` is passed through to ``__init__`` unchanged.
        """
        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
            cache_salt=request.cache_salt,
            priority=request.priority,
            trace_headers=request.trace_headers,
            block_hasher=block_hasher,
            resumable=request.resumable,
            reasoning_ended=request.reasoning_ended,
        )

    def append_output_token_ids(
        self,
        token_ids: int | list[int],
    ) -> None:
        """Append one or more generated token ids.

        The ids go into both the output list and the all-token list, which
        are kept in sync. Block hashes are then refreshed for any newly
        completed blocks.
        """
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

        self.update_block_hashes()

    def update_block_hashes(self) -> None:
        """Compute block hashes for any new full blocks and append them."""
        if self._block_hasher is not None:
            self.block_hashes.extend(self._block_hasher(self))

    @property
    def use_structured_output(self) -> bool:
        """Whether a structured-output request was created for this request."""
        return self.structured_output_request is not None

    @property
    def num_tokens(self) -> int:
        """Number of prompt plus output tokens."""
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        """``num_tokens`` plus the number of pending speculative tokens."""
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of generated output tokens."""
        return len(self._output_token_ids)

    @property
    def num_encoder_inputs(self) -> int:
        """Number of multi-modal feature entries."""
        return len(self.mm_features)

    @property
    def has_encoder_inputs(self) -> bool:
        """Whether the request carries any multi-modal features."""
        return self.num_encoder_inputs > 0

    def get_skip_reading_prefix_cache(self) -> bool:
        """Resolve the ``skip_reading_prefix_cache`` flag.

        Uses the sampling-params value if set, then the pooling-params value
        if set, and otherwise returns False.
        """
        if (
            self.sampling_params is not None
            and self.sampling_params.skip_reading_prefix_cache is not None
        ):
            return self.sampling_params.skip_reading_prefix_cache
        elif (
            self.pooling_params is not None
            and self.pooling_params.skip_reading_prefix_cache is not None
        ):
            return self.pooling_params.skip_reading_prefix_cache
        return False

    def is_finished(self) -> bool:
        """Whether the current status is a finished status."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> FinishReason | None:
        """Finish reason mapped to the current status, or None if unmapped."""
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_embeds(self, input_id: int) -> int:
        """Number of embeddings for the multi-modal input at ``input_id``."""
        assert input_id < len(self.mm_features)
        return self.mm_features[input_id].mm_position.get_num_embeds()

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: float | None = None,
    ) -> None:
        """Append a new engine-core event of ``event_type`` to ``self.events``."""
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> list[EngineCoreEvent] | None:
        """Return the accumulated events and reset the list.

        Returns None (and leaves the list untouched) when there are no events.
        """
        if not self.events:
            return None
        events, self.events = self.events, []
        return events

    def __lt__(self, other: "Request") -> bool:
        """
        Compare two requests based on priority, arrival time, and request ID.
        Used in priority scheduling.
        """
        if self.priority != other.priority:
            return self.priority < other.priority
        if self.arrival_time != other.arrival_time:
            return self.arrival_time < other.arrival_time
        if self.request_id != other.request_id:
            return self.request_id < other.request_id
        # Final tie-breaker so distinct objects never compare as equal.
        return id(self) < id(other)

306
307

class RequestStatus(enum.IntEnum):
    """Status of a request.

    Member order is significant: ``is_finished`` treats every status declared
    after ``PREEMPTED`` as finished. Non-finished statuses must therefore be
    declared before ``PREEMPTED``, and finished ones after it.
    """

    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    WAITING_FOR_STREAMING_REQ = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()
    FINISHED_ERROR = enum.auto()

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True if ``status`` is declared after ``PREEMPTED``."""
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(status: "RequestStatus") -> FinishReason | None:
        """Return the finish reason mapped to ``status``, or None if unmapped."""
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
# NOTE: WAITING_FOR_STREAMING_REQ is not a finished status (it is declared
# before PREEMPTED), yet it maps to STOP so that get_finished_reason()
# still reports a reason for it. Presumably this covers a streaming segment
# that has completed while its session waits for more input — TODO confirm.
_FINISHED_REASON_MAP = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
    RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP,
}