# request.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from collections import deque
7
from collections.abc import Callable, Mapping
8
from dataclasses import dataclass
9
from functools import partial
10
from typing import TYPE_CHECKING, Any
11

12
13
import torch

14
from vllm.multimodal.inputs import MultiModalFeatureSpec
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams
17
from vllm.utils import length_from_prompt_token_ids_or_embeds
18
19
20
21
22
23
from vllm.v1.engine import (
    EngineCoreEvent,
    EngineCoreEventType,
    EngineCoreRequest,
    FinishReason,
)
24
from vllm.v1.structured_output.request import StructuredOutputRequest
25
from vllm.v1.utils import ConstantList
26

27
if TYPE_CHECKING:
28
    from vllm.lora.request import LoRARequest
29
    from vllm.v1.core.kv_cache_utils import BlockHash
30

@dataclass
class StreamingUpdate:
    """Input payload used to continue an existing streaming session.

    Holds only the subset of request fields that is needed to feed new input
    into a session that is already being tracked, instead of a full Request.
    """

    mm_features: list[MultiModalFeatureSpec] | None
    prompt_token_ids: list[int] | None
    max_tokens: int
    arrival_time: float
    sampling_params: SamplingParams | None

    @classmethod
    def from_request(cls, request: "Request") -> "StreamingUpdate | None":
        """Snapshot the streaming-relevant fields of ``request``.

        Returns None when ``request.resumable`` is falsy. Otherwise returns a
        new StreamingUpdate whose fields reference (not copy) the request's
        ``mm_features``, ``prompt_token_ids``, ``max_tokens``,
        ``arrival_time`` and ``sampling_params``.
        """
        if request.resumable:
            # Positional arguments follow the dataclass field order above.
            return cls(
                request.mm_features,
                request.prompt_token_ids,
                request.max_tokens,
                request.arrival_time,
                request.sampling_params,
            )
        return None


class Request:
    """Engine-core bookkeeping for a single request.

    Tracks the prompt and generated token ids, the scheduling status, the
    prefix-cache block hashes and a set of per-request counters used by the
    scheduler (computed/cached/external tokens, preemptions, NaN counts).
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int] | None,
        sampling_params: SamplingParams | None,
        pooling_params: PoolingParams | None,
        eos_token_id: int | None,
        client_index: int = 0,
        arrival_time: float | None = None,
        prompt_embeds: torch.Tensor | None = None,
        mm_features: list[MultiModalFeatureSpec] | None = None,
        lora_request: "LoRARequest | None" = None,
        cache_salt: str | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
        resumable: bool = False,
        reasoning_ended: bool | None = None,
    ) -> None:
        """Initialize request state.

        Args:
            request_id: Identifier of the request; also the third tie-breaker
                in ``__lt__``.
            prompt_token_ids: Prompt token ids, or None when only
                ``prompt_embeds`` is provided.
            sampling_params: Parameters for generative requests.
            pooling_params: Parameters for pooling requests. When set, it
                takes precedence over ``sampling_params``: ``max_tokens`` is
                forced to 1 and the sampling branch is skipped.
            eos_token_id: EOS token id for this request (may differ per LoRA).
            client_index: Index of the client that submitted the request.
            arrival_time: Arrival timestamp; defaults to ``time.time()``.
            prompt_embeds: Optional prompt embeddings.
            mm_features: Multi-modal feature specs; None becomes ``[]``.
            lora_request: Optional LoRA adapter request.
            cache_salt: Optional salt stored on the request (presumably mixed
                into prefix-cache hashing; confirm against block_hasher).
            priority: Scheduling priority; lower values sort first in
                ``__lt__``.
            trace_headers: Optional tracing headers, stored as-is.
            block_hasher: Optional callback that computes block hashes for
                this request. It is bound to ``self`` and invoked once at the
                end of construction and again after every output append.
            resumable: Whether the request can be resumed as a streaming
                session (see ``streaming_queue``).
            reasoning_ended: Copied onto the structured-output request, when
                one exists.

        Raises:
            ValueError: If both ``sampling_params`` and ``pooling_params``
                are None.
            AssertionError: If ``sampling_params.max_tokens`` is None for a
                generative request.
        """
        self.request_id = request_id
        self.client_index = client_index
        self.priority = priority
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.structured_output_request = StructuredOutputRequest.from_sampling_params(
            sampling_params
        )
        if self.structured_output_request is not None:
            self.structured_output_request.reasoning_ended = reasoning_ended
        self.arrival_time = arrival_time if arrival_time is not None else time.time()

        self.status = RequestStatus.WAITING
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: int | str | None = None

        # P/D: Connector-specific KV transfer parameters.
        self.kv_transfer_params: dict[str, Any] | None = None

        if pooling_params is not None:
            # Pooling models.
            self.max_tokens = 1
        elif sampling_params is not None:
            # Generative models.
            assert sampling_params.max_tokens is not None
            self.max_tokens = sampling_params.max_tokens
            if self.structured_output_request is not None:
                # Structured-output requests start in a dedicated waiting
                # state instead of plain WAITING.
                self.status = RequestStatus.WAITING_FOR_FSM

            if sampling_params.extra_args is not None:
                self.kv_transfer_params = sampling_params.extra_args.get(
                    "kv_transfer_params"
                )
        else:
            raise ValueError("sampling_params and pooling_params can't both be unset")

        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            prompt_token_ids, prompt_embeds
        )
        self._output_token_ids: list[int] = []
        # When no prompt token ids are given, zeros stand in for the prompt so
        # that len(_all_token_ids) still equals num_prompt_tokens.
        self._all_token_ids: list[int] = (
            self.prompt_token_ids.copy()
            if self.prompt_token_ids is not None
            else [0] * self.num_prompt_tokens
        )

        # Used in async scheduling.
        self.num_output_placeholders = 0
        # Used in forced preemption (reset_prefix_cache) with async scheduling.
        self.discard_latest_async_tokens = False

        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: str | None = cache_salt

        # Multi-modal related
        self.mm_features = mm_features or []

        # Read-only views
        # Prevent directly appending to these lists since
        # they should also be updated simultaneously.
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)
        # trace_headers
        self.trace_headers = trace_headers
        # State
        # The number of tokens with prefix cache hits.
        # (-1 presumably means "not yet determined"; confirm with scheduler.)
        self.num_cached_tokens = -1

        # True if this request is scheduled as a non-final prefill chunk.
        self.is_prefill_chunk = False

        # The number of NaNs in logits. A value greater than 0
        # indicates that the output is corrupted
        self.num_nans_in_logits = 0

        # The number of times this request has been preempted by the scheduler.
        self.num_preemptions = 0

        # The number of tokens that have been computed remotely.
        self.num_external_computed_tokens = 0

        self.block_hashes: list[BlockHash] = []
        self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
        # The hasher is called on this partially-constructed object, so it
        # must stay after every field above is assigned.
        if block_hasher is not None:
            self.get_hash_new_full_blocks = partial(block_hasher, self)
            self.block_hashes = self.get_hash_new_full_blocks()

        self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()

        # Used for streaming
        self.resumable = resumable
        # None entry in the queue means finished.
        self.streaming_queue: deque[StreamingUpdate | None] | None = None

    @classmethod
    def from_engine_core_request(
        cls,
        request: EngineCoreRequest,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None,
    ) -> "Request":
        """Build a Request by copying the fields of an EngineCoreRequest.

        Args:
            request: Incoming engine-core request; its fields are passed
                one-to-one to ``__init__``.
            block_hasher: Forwarded unchanged to ``__init__``.

        Returns:
            The newly constructed Request.
        """
        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            eos_token_id=request.eos_token_id,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
            cache_salt=request.cache_salt,
            priority=request.priority,
            trace_headers=request.trace_headers,
            block_hasher=block_hasher,
            resumable=request.resumable,
            reasoning_ended=request.reasoning_ended,
        )

    def append_output_token_ids(
        self,
        token_ids: int | list[int],
    ) -> None:
        """Append generated token(s) to both the output and all-token lists.

        If a block hasher is configured, ``block_hashes`` is then extended
        with whatever ``get_hash_new_full_blocks`` returns (presumably the
        hashes of blocks that became full with these tokens).

        Args:
            token_ids: A single token id or a list of token ids.
        """
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

        if self.get_hash_new_full_blocks is not None:
            self.block_hashes.extend(self.get_hash_new_full_blocks())

    @property
    def use_structured_output(self) -> bool:
        """True if a structured-output request was created for this request."""
        return self.structured_output_request is not None

    @property
    def num_tokens(self) -> int:
        """Number of prompt plus output tokens."""
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        """Number of prompt plus output tokens, plus speculative tokens."""
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of output tokens appended so far."""
        return len(self._output_token_ids)

    @property
    def num_encoder_inputs(self) -> int:
        """Number of multi-modal features attached to the request."""
        return len(self.mm_features)

    @property
    def has_encoder_inputs(self) -> bool:
        """True if the request has at least one multi-modal feature."""
        return self.num_encoder_inputs > 0

    def get_skip_reading_prefix_cache(self) -> bool:
        """Resolve whether this request should skip reading the prefix cache.

        The first non-None ``skip_reading_prefix_cache`` found on
        ``sampling_params``, then ``pooling_params``, wins; default False.
        """
        if (
            self.sampling_params is not None
            and self.sampling_params.skip_reading_prefix_cache is not None
        ):
            return self.sampling_params.skip_reading_prefix_cache
        elif (
            self.pooling_params is not None
            and self.pooling_params.skip_reading_prefix_cache is not None
        ):
            return self.pooling_params.skip_reading_prefix_cache
        return False

    def is_finished(self) -> bool:
        """True if the current status is a finished status."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> FinishReason | None:
        """FinishReason for the current status, or None if there is none."""
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_embeds(self, input_id: int) -> int:
        """Number of embeddings for the ``input_id``-th multi-modal feature."""
        assert input_id < len(self.mm_features)
        return self.mm_features[input_id].mm_position.get_num_embeds()

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: float | None = None,
    ) -> None:
        """Append an EngineCoreEvent of ``event_type`` to ``self.events``."""
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> list[EngineCoreEvent] | None:
        """Return and clear the recorded events; None if there are none."""
        if not self.events:
            return None
        events, self.events = self.events, []
        return events

    def __lt__(self, other: "Request") -> bool:
        """
        Compare two requests based on priority, arrival time, and request ID.
        Used in priority scheduling.

        Falls back to object identity so that distinct requests with equal
        keys still have a consistent order.
        """
        if self.priority != other.priority:
            return self.priority < other.priority
        if self.arrival_time != other.arrival_time:
            return self.arrival_time < other.arrival_time
        if self.request_id != other.request_id:
            return self.request_id < other.request_id
        return id(self) < id(other)

class RequestStatus(enum.IntEnum):
    """Status of a request.

    Member order is significant: ``is_finished`` treats every status declared
    after ``PREEMPTED`` as finished, so new statuses must be inserted on the
    correct side of ``PREEMPTED``.
    """

    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    WAITING_FOR_STREAMING_REQ = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()
    FINISHED_ERROR = enum.auto()

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True if ``status`` is declared after ``PREEMPTED``."""
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(status: "RequestStatus") -> FinishReason | None:
        """Look up ``status`` in ``_FINISHED_REASON_MAP``.

        Returns None for statuses that have no entry in the map.
        """
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the finish
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP: dict[RequestStatus, FinishReason] = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
    # NOTE(review): WAITING_FOR_STREAMING_REQ is declared before PREEMPTED, so
    # it is not a finished status, yet it still maps to STOP — presumably so
    # the output segment of a streaming session ends with "stop" while the
    # session waits for more input; confirm with the streaming code path.
    RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP,
}