request.py 11.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from collections import deque
7
from collections.abc import Callable, Mapping
8
from dataclasses import dataclass
9
from functools import partial
10
from typing import TYPE_CHECKING, Any, Optional
11

12
13
import torch

14
from vllm.multimodal.inputs import MultiModalFeatureSpec
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams
17
from vllm.utils import length_from_prompt_token_ids_or_embeds
18
19
20
21
22
23
from vllm.v1.engine import (
    EngineCoreEvent,
    EngineCoreEventType,
    EngineCoreRequest,
    FinishReason,
)
24
from vllm.v1.structured_output.request import StructuredOutputRequest
25
from vllm.v1.utils import ConstantList
26

27
if TYPE_CHECKING:
28
    from vllm.lora.request import LoRARequest
29
    from vllm.v1.core.kv_cache_utils import BlockHash
30

31

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
@dataclass
class StreamingUpdate:
    """Snapshot of the request fields carried into a streaming continuation.

    Holds just the inputs that change when an already-running streaming
    session receives more data, rather than a full request object.
    """

    # Field order is part of the positional constructor interface.
    mm_features: list[MultiModalFeatureSpec] | None
    prompt_token_ids: list[int] | None
    max_tokens: int
    arrival_time: float
    sampling_params: SamplingParams | None

    @classmethod
    def from_request(cls, request: "Request") -> "StreamingUpdate | None":
        """Build an update from ``request``, or ``None`` if it is not resumable.

        Only ``request.resumable`` is read when the request is not resumable;
        the remaining attributes are copied by reference otherwise.
        """
        if request.resumable:
            snapshot = {
                "mm_features": request.mm_features,
                "prompt_token_ids": request.prompt_token_ids,
                "max_tokens": request.max_tokens,
                "arrival_time": request.arrival_time,
                "sampling_params": request.sampling_params,
            }
            return cls(**snapshot)
        return None


59
60
61
62
class Request:
    """Per-request state: tokens, status, events and bookkeeping counters.

    The prompt and generated token ids are kept in private lists and exposed
    through read-only ``ConstantList`` views; use
    :meth:`append_output_token_ids` to add tokens so both lists (and the
    block hashes, when a hasher is configured) stay in sync.
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int] | None,
        sampling_params: SamplingParams | None,
        pooling_params: PoolingParams | None,
        eos_token_id: int | None,
        client_index: int = 0,
        arrival_time: float | None = None,
        prompt_embeds: torch.Tensor | None = None,
        mm_features: list[MultiModalFeatureSpec] | None = None,
        lora_request: Optional["LoRARequest"] = None,
        cache_salt: str | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
        resumable: bool = False,
    ) -> None:
        """Initialise the request state.

        Args:
            request_id: Identifier of the request; also the last explicit
                tie-breaker in ``__lt__``.
            prompt_token_ids: Prompt token ids, copied into the all-token
                list. May be ``None``, in which case the all-token list is
                filled with ``num_prompt_tokens`` zeros.
            sampling_params: Sampling parameters. Its ``max_tokens`` must be
                set when ``pooling_params`` is ``None``.
            pooling_params: Pooling parameters. When given, ``max_tokens``
                is fixed to 1 and ``sampling_params`` is not consulted for it.
            eos_token_id: End-of-sequence token id for this request.
            client_index: Stored as-is on the instance.
            arrival_time: Arrival timestamp; defaults to ``time.time()``.
            prompt_embeds: Passed with ``prompt_token_ids`` to
                ``length_from_prompt_token_ids_or_embeds`` to obtain
                ``num_prompt_tokens``.
            mm_features: Multi-modal features; ``None`` becomes ``[]``.
            lora_request: Stored as-is on the instance.
            cache_salt: Stored as-is on the instance.
            priority: Primary sort key in ``__lt__`` (lower sorts first).
            trace_headers: Stored as-is on the instance.
            block_hasher: Optional callable bound to this request; invoked
                once here to seed ``block_hashes`` and again after every
                ``append_output_token_ids`` call.
            resumable: Stored as-is; read by ``StreamingUpdate.from_request``.

        Raises:
            ValueError: If both ``sampling_params`` and ``pooling_params``
                are ``None``.
            AssertionError: If ``sampling_params.max_tokens`` is ``None`` for
                a request without ``pooling_params``.
        """
        self.request_id = request_id
        self.client_index = client_index
        self.priority = priority
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.structured_output_request = StructuredOutputRequest.from_sampling_params(
            sampling_params
        )
        self.arrival_time = arrival_time if arrival_time is not None else time.time()

        self.status = RequestStatus.WAITING
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: int | str | None = None

        # P/D: Connector-specific KV transfer parameters.
        self.kv_transfer_params: dict[str, Any] | None = None

        if pooling_params is not None:
            # Pooling models.
            self.max_tokens = 1
        elif sampling_params is not None:
            # Generative models.
            assert sampling_params.max_tokens is not None
            self.max_tokens = sampling_params.max_tokens
            if self.structured_output_request is not None:
                self.status = RequestStatus.WAITING_FOR_FSM

            if sampling_params.extra_args is not None:
                self.kv_transfer_params = sampling_params.extra_args.get(
                    "kv_transfer_params"
                )
        else:
            raise ValueError("sampling_params and pooling_params can't both be unset")

        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            prompt_token_ids, prompt_embeds
        )
        self._output_token_ids: list[int] = []
        # Without prompt token ids, fill the prompt positions with zeros so
        # the list length still equals the number of prompt tokens.
        self._all_token_ids: list[int] = (
            self.prompt_token_ids.copy()
            if self.prompt_token_ids is not None
            else [0] * self.num_prompt_tokens
        )

        # Used in async scheduling.
        self.num_output_placeholders = 0
        # Used in forced preemption (reset_prefix_cache) with async scheduling.
        self.discard_latest_async_tokens = False

        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        # Number of tokens currently stored in the KV cache for this request.
        # This may differ from `num_computed_tokens` when KV compression is
        # enabled (e.g., token-shared prefill compression).
        self.num_kv_tokens = 0
        self.cache_salt: str | None = cache_salt

        # Multi-modal related
        self.mm_features = mm_features or []

        # Read-only views
        # Prevent directly appending to these lists since
        # they should also be updated simultaneously.
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)
        # trace_headers
        self.trace_headers = trace_headers
        # State
        # The number of tokens with prefix cache hits.
        self.num_cached_tokens = -1

        # The number of NaNs in logits. A value greater than 0
        # indicates that the output is corrupted
        self.num_nans_in_logits = 0

        # The number of times this request has been preempted by the scheduler.
        self.num_preemptions = 0

        # The number of tokens that have been computed remotely.
        self.num_external_computed_tokens = 0

        self.block_hashes: list[BlockHash] = []
        self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
        if block_hasher is not None:
            # The hasher receives this request, so it must be bound only
            # after the token lists above have been populated.
            self.get_hash_new_full_blocks = partial(block_hasher, self)
            self.block_hashes = self.get_hash_new_full_blocks()

        self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()

        # Used for streaming
        self.resumable = resumable
        # None entry in the queue means finished.
        self.streaming_queue: deque[StreamingUpdate | None] | None = None

    @classmethod
    def from_engine_core_request(
        cls,
        request: EngineCoreRequest,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None,
    ) -> "Request":
        """Create a ``Request`` by copying the fields of ``request``.

        Args:
            request: Source of every constructor argument except the hasher.
            block_hasher: Forwarded unchanged to ``__init__``.
        """
        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            eos_token_id=request.eos_token_id,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
            cache_salt=request.cache_salt,
            priority=request.priority,
            trace_headers=request.trace_headers,
            block_hasher=block_hasher,
            resumable=request.resumable,
        )

    def append_output_token_ids(
        self,
        token_ids: int | list[int],
    ) -> None:
        """Append one token id, or a list of them, to the request.

        Both the output-token list and the all-token list are updated. If a
        block hasher was configured, ``block_hashes`` is then extended with
        whatever it returns.
        """
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

        if self.get_hash_new_full_blocks is not None:
            self.block_hashes.extend(self.get_hash_new_full_blocks())

    @property
    def use_structured_output(self) -> bool:
        """Whether a structured-output request is attached."""
        return self.structured_output_request is not None

    @property
    def num_tokens(self) -> int:
        """Length of the all-token list (prompt plus appended outputs)."""
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        """``num_tokens`` plus the number of entries in ``spec_token_ids``."""
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of output token ids appended so far."""
        return len(self._output_token_ids)

    @property
    def num_encoder_inputs(self) -> int:
        """Number of multi-modal features attached to the request."""
        return len(self.mm_features)

    @property
    def has_encoder_inputs(self) -> bool:
        """Whether the request has at least one multi-modal feature."""
        return self.num_encoder_inputs > 0

    def get_skip_reading_prefix_cache(self) -> bool:
        """Resolve the ``skip_reading_prefix_cache`` flag for this request.

        An explicit (non-``None``) value on ``sampling_params`` wins, then an
        explicit value on ``pooling_params``; otherwise ``False``.
        """
        if (
            self.sampling_params is not None
            and self.sampling_params.skip_reading_prefix_cache is not None
        ):
            return self.sampling_params.skip_reading_prefix_cache
        elif (
            self.pooling_params is not None
            and self.pooling_params.skip_reading_prefix_cache is not None
        ):
            return self.pooling_params.skip_reading_prefix_cache
        return False

    def is_finished(self) -> bool:
        """Whether the current status counts as finished."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> FinishReason | None:
        """Finish reason for the current status, or ``None`` if unmapped."""
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_embeds(self, input_id: int) -> int:
        """Return ``get_num_embeds`` of the ``input_id``-th feature's position.

        NOTE(review): ``get_num_embeds`` is read without being called, so it
        is presumably a property on ``mm_position`` - confirm.
        """
        assert input_id < len(self.mm_features)
        return self.mm_features[input_id].mm_position.get_num_embeds

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: float | None = None,
    ) -> None:
        """Append a new ``EngineCoreEvent`` of ``event_type`` to ``events``."""
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> list[EngineCoreEvent] | None:
        """Return and clear the pending events, or ``None`` if there are none."""
        if not self.events:
            return None
        # Hand the existing list to the caller and start a fresh one.
        events, self.events = self.events, []
        return events

    def __lt__(self, other: "Request") -> bool:
        """
        Compare two requests based on priority, arrival time, and request ID.
        Used in priority scheduling.
        """
        if self.priority != other.priority:
            return self.priority < other.priority
        if self.arrival_time != other.arrival_time:
            return self.arrival_time < other.arrival_time
        if self.request_id != other.request_id:
            return self.request_id < other.request_id
        # Last resort so that distinct objects never compare as equal.
        return id(self) < id(other)

288
289

class RequestStatus(enum.IntEnum):
    """Status of a request.

    Member order is significant: ``is_finished`` treats every member
    declared after ``PREEMPTED`` as a finished status.
    """

    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    WAITING_FOR_STREAMING_REQ = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()
    FINISHED_ERROR = enum.auto()

    def __str__(self) -> str:
        """Return the member name (e.g. ``"RUNNING"``) instead of its value."""
        return self.name

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Whether ``status`` is declared after ``PREEMPTED``."""
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(status: "RequestStatus") -> FinishReason | None:
        """Look up ``status`` in ``_FINISHED_REASON_MAP``; ``None`` if absent."""
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
# NOTE(review): WAITING_FOR_STREAMING_REQ also has an entry here although
# RequestStatus.is_finished() does not treat it as finished - presumably
# intentional for streaming sessions; confirm.
_FINISHED_REASON_MAP = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
    RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP,
}