# request.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from collections import deque
7
from collections.abc import Callable, Mapping
8
from dataclasses import dataclass
9
from functools import partial
10
from typing import TYPE_CHECKING, Any
11

12
13
import torch

14
from vllm.multimodal.inputs import MultiModalFeatureSpec
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams
17
from vllm.utils import length_from_prompt_token_ids_or_embeds
18
19
20
21
22
23
from vllm.v1.engine import (
    EngineCoreEvent,
    EngineCoreEventType,
    EngineCoreRequest,
    FinishReason,
)
24
from vllm.v1.structured_output.request import StructuredOutputRequest
25
from vllm.v1.utils import ConstantList
26

27
if TYPE_CHECKING:
28
    from vllm.lora.request import LoRARequest
29
    from vllm.v1.core.kv_cache_utils import BlockHash
30

@dataclass
class StreamingUpdate:
    """Lightweight data for streaming session continuation.

    Contains only the fields needed to update an existing streaming session
    with new input data.
    """

    mm_features: list[MultiModalFeatureSpec] | None
    prompt_token_ids: list[int] | None
    max_tokens: int
    arrival_time: float
    sampling_params: SamplingParams | None

    @classmethod
    def from_request(cls, request: "Request") -> "StreamingUpdate | None":
        """Build a continuation update from ``request``.

        Args:
            request: The request whose input fields are captured.

        Returns:
            ``None`` when ``request.resumable`` is false. Otherwise, a new
            ``StreamingUpdate`` holding the request's ``mm_features``,
            ``prompt_token_ids``, ``max_tokens``, ``arrival_time`` and
            ``sampling_params``. The values are stored by reference, not
            copied, so later in-place mutation of those objects is visible
            through the update.
        """
        if not request.resumable:
            return None
        return cls(
            mm_features=request.mm_features,
            prompt_token_ids=request.prompt_token_ids,
            max_tokens=request.max_tokens,
            arrival_time=request.arrival_time,
            sampling_params=request.sampling_params,
        )


class Request:
    """Engine-core state for a single request.

    Holds the prompt and generated token ids, the scheduling status, timing
    events, prefix-cache block hashes, and streaming-session bookkeeping.
    At least one of ``sampling_params`` (generative models) or
    ``pooling_params`` (pooling models) must be provided. When both are set,
    ``pooling_params`` decides ``max_tokens``.
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int] | None,
        sampling_params: SamplingParams | None,
        pooling_params: PoolingParams | None,
        eos_token_id: int | None,
        client_index: int = 0,
        arrival_time: float | None = None,
        prompt_embeds: torch.Tensor | None = None,
        mm_features: list[MultiModalFeatureSpec] | None = None,
        lora_request: "LoRARequest | None" = None,
        cache_salt: str | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
        resumable: bool = False,
    ) -> None:
        """Initialize the request state.

        Args:
            request_id: Identifier of the request. It is also the final
                non-identity tie-breaker in ``__lt__``.
            prompt_token_ids: Prompt token ids. When ``None``, the full token
                list is seeded with ``num_prompt_tokens`` zero placeholders.
            sampling_params: Sampling parameters for generative models.
            pooling_params: Pooling parameters for pooling models.
            eos_token_id: EOS token id for this request (may differ per
                request because of LoRA).
            client_index: Index of the originating client.
            arrival_time: Arrival timestamp; defaults to ``time.time()``.
            prompt_embeds: Optional prompt embeddings, used together with
                ``prompt_token_ids`` to compute ``num_prompt_tokens``.
            mm_features: Multi-modal feature specs; ``None`` becomes ``[]``.
            lora_request: Optional LoRA adapter request.
            cache_salt: Optional salt stored on the request for caching.
            priority: Scheduling priority; lower values sort first in
                ``__lt__``.
            trace_headers: Optional tracing headers.
            block_hasher: Optional callable. It is bound to this request and
                called once here and again after every output append, to
                extend ``block_hashes``.
            resumable: Whether the request can be continued as a streaming
                session (see ``StreamingUpdate.from_request``).

        Raises:
            ValueError: If both ``sampling_params`` and ``pooling_params``
                are ``None``.
            AssertionError: If ``sampling_params.max_tokens`` is ``None`` on
                the generative path.
        """
        self.request_id = request_id
        self.client_index = client_index
        self.priority = priority
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.structured_output_request = StructuredOutputRequest.from_sampling_params(
            sampling_params
        )
        self.arrival_time = arrival_time if arrival_time is not None else time.time()

        self.status = RequestStatus.WAITING
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: int | str | None = None

        # P/D: Connector-specific KV transfer parameters.
        self.kv_transfer_params: dict[str, Any] | None = None

        if pooling_params is not None:
            # Pooling models.
            self.max_tokens = 1
        elif sampling_params is not None:
            # Generative models.
            assert sampling_params.max_tokens is not None
            self.max_tokens = sampling_params.max_tokens
            if self.structured_output_request is not None:
                # Structured-output requests start in a separate waiting
                # state instead of WAITING.
                self.status = RequestStatus.WAITING_FOR_FSM

            if sampling_params.extra_args is not None:
                self.kv_transfer_params = sampling_params.extra_args.get(
                    "kv_transfer_params"
                )
        else:
            raise ValueError("sampling_params and pooling_params can't both be unset")

        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            prompt_token_ids, prompt_embeds
        )
        self._output_token_ids: list[int] = []
        # Without prompt token ids (e.g. embeddings-only prompts) the prompt
        # part is filled with zeros so that len() still counts prompt tokens.
        # NOTE(review): presumably these placeholder ids are never consumed as
        # real token ids — confirm with callers.
        self._all_token_ids: list[int] = (
            self.prompt_token_ids.copy()
            if self.prompt_token_ids is not None
            else [0] * self.num_prompt_tokens
        )

        # Used in async scheduling.
        self.num_output_placeholders = 0
        # Used in forced preemption (reset_prefix_cache) with async scheduling.
        self.discard_latest_async_tokens = False

        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: str | None = cache_salt

        # Multi-modal related
        self.mm_features = mm_features or []

        # Read-only views
        # Prevent directly appending to these lists since
        # they should also be updated simultaneously.
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

        # trace_headers
        self.trace_headers = trace_headers

        # State
        # The number of tokens with prefix cache hits.
        self.num_cached_tokens = -1

        # True if this request is scheduled as a non-final prefill chunk.
        self.is_prefill_chunk = False

        # The number of NaNs in logits. A value greater than 0
        # indicates that the output is corrupted
        self.num_nans_in_logits = 0

        # The number of times this request has been preempted by the scheduler.
        self.num_preemptions = 0

        # The number of tokens that have been computed remotely.
        self.num_external_computed_tokens = 0

        self.block_hashes: list[BlockHash] = []
        self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
        if block_hasher is not None:
            # Bound to self here, so it must run after the token lists above
            # are initialized.
            self.get_hash_new_full_blocks = partial(block_hasher, self)
            self.block_hashes = self.get_hash_new_full_blocks()

        self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()

        # Used for streaming
        self.resumable = resumable
        # None entry in the queue means finished.
        self.streaming_queue: deque[StreamingUpdate | None] | None = None

    @classmethod
    def from_engine_core_request(
        cls,
        request: EngineCoreRequest,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None,
    ) -> "Request":
        """Build a ``Request`` from an ``EngineCoreRequest``.

        Copies the matching fields of ``request`` one-to-one into the
        constructor and attaches ``block_hasher``.
        """
        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            eos_token_id=request.eos_token_id,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
            cache_salt=request.cache_salt,
            priority=request.priority,
            trace_headers=request.trace_headers,
            block_hasher=block_hasher,
            resumable=request.resumable,
        )

    def append_output_token_ids(
        self,
        token_ids: int | list[int],
    ) -> None:
        """Append generated token id(s) to both the output and full token lists.

        Both private lists are updated together, so the read-only views
        ``output_token_ids`` and ``all_token_ids`` stay consistent. When a
        block hasher was supplied, whatever ``get_hash_new_full_blocks``
        returns is appended to ``block_hashes``.

        Args:
            token_ids: A single token id or a list of token ids.
        """
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

        if self.get_hash_new_full_blocks is not None:
            self.block_hashes.extend(self.get_hash_new_full_blocks())

    @property
    def use_structured_output(self) -> bool:
        """Whether the sampling params produced a structured-output request."""
        return self.structured_output_request is not None

    @property
    def num_tokens(self) -> int:
        """Number of prompt plus output tokens (speculative tokens excluded)."""
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        """Number of prompt, output and speculative tokens."""
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of output tokens appended so far."""
        return len(self._output_token_ids)

    @property
    def num_encoder_inputs(self) -> int:
        """Number of multi-modal features attached to the request."""
        return len(self.mm_features)

    @property
    def has_encoder_inputs(self) -> bool:
        """Whether the request has at least one multi-modal feature."""
        return self.num_encoder_inputs > 0

    def get_skip_reading_prefix_cache(self) -> bool:
        """Resolve whether reading from the prefix cache should be skipped.

        A non-``None`` setting on ``sampling_params`` wins. Otherwise a
        non-``None`` setting on ``pooling_params`` is used. Defaults to False.
        """
        if (
            self.sampling_params is not None
            and self.sampling_params.skip_reading_prefix_cache is not None
        ):
            return self.sampling_params.skip_reading_prefix_cache
        elif (
            self.pooling_params is not None
            and self.pooling_params.skip_reading_prefix_cache is not None
        ):
            return self.pooling_params.skip_reading_prefix_cache
        return False

    def is_finished(self) -> bool:
        """Whether the current status counts as finished."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> FinishReason | None:
        """Finish reason mapped from the current status, or ``None``."""
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_embeds(self, input_id: int) -> int:
        """Number of embeddings for the multi-modal input at ``input_id``.

        NOTE(review): ``get_num_embeds`` is read as an attribute (not
        called), so it is assumed to be a property on the placeholder type —
        confirm against ``mm_position``'s class.
        """
        assert input_id < len(self.mm_features)
        return self.mm_features[input_id].mm_position.get_num_embeds

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: float | None = None,
    ) -> None:
        """Append a new ``EngineCoreEvent`` of ``event_type`` to ``events``."""
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> list[EngineCoreEvent] | None:
        """Return the recorded events and reset the list.

        Returns ``None`` (and leaves state unchanged) when no events are
        pending.
        """
        if not self.events:
            return None
        events, self.events = self.events, []
        return events

    def __lt__(self, other: "Request") -> bool:
        """
        Compare two requests based on priority, arrival time, and request ID.
        Used in priority scheduling.
        """
        if self.priority != other.priority:
            return self.priority < other.priority
        if self.arrival_time != other.arrival_time:
            return self.arrival_time < other.arrival_time
        if self.request_id != other.request_id:
            return self.request_id < other.request_id
        # Final tie-break on object identity so distinct objects always order.
        return id(self) < id(other)

class RequestStatus(enum.IntEnum):
    """Status of a request.

    Declaration order matters. ``is_finished`` treats every member declared
    after ``PREEMPTED`` as finished, so unfinished states must be declared
    before ``PREEMPTED`` and finished states after it.
    """

    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    WAITING_FOR_STREAMING_REQ = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()
    FINISHED_ERROR = enum.auto()

    def __str__(self) -> str:
        """Render the status as its member name."""
        return self.name

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Whether ``status`` is declared after ``PREEMPTED``."""
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(status: "RequestStatus") -> FinishReason | None:
        """Look up ``status`` in the module-level finish-reason map.

        Returns ``None`` for statuses that are not mapped.
        """
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
# NOTE: WAITING_FOR_STREAMING_REQ is not a finished status
# (RequestStatus.is_finished returns False for it), yet it maps to STOP.
# Presumably this lets a completed streaming segment report a stop reason
# while the session waits for more input — confirm with the streaming
# scheduler before changing it.
_FINISHED_REASON_MAP: dict[RequestStatus, FinishReason] = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
    RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP,
}