# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import enum
import time
from collections.abc import Callable, Mapping
from functools import partial
from typing import TYPE_CHECKING, Any, Optional

import torch

from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import (
    EngineCoreEvent,
    EngineCoreEventType,
    EngineCoreRequest,
    FinishReason,
)
from vllm.v1.structured_output.request import StructuredOutputRequest
from vllm.v1.utils import ConstantList

if TYPE_CHECKING:
    from vllm.lora.request import LoRARequest
    from vllm.v1.core.kv_cache_utils import BlockHash

class Request:
    """State for a single request as tracked by the v1 engine core.

    Holds the prompt and generated token ids, the scheduling status, prefix
    cache bookkeeping (block hashes, cached/computed token counts),
    multimodal features, and the engine events recorded for this request.
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int] | None,
        sampling_params: SamplingParams | None,
        pooling_params: PoolingParams | None,
        eos_token_id: int | None,
        client_index: int = 0,
        arrival_time: float | None = None,
        prompt_embeds: torch.Tensor | None = None,
        mm_features: list[MultiModalFeatureSpec] | None = None,
        lora_request: Optional["LoRARequest"] = None,
        cache_salt: str | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
    ) -> None:
        """Initialize the request state.

        Args:
            request_id: Identifier of the request; also used as a
                tie-breaker in ``__lt__``.
            prompt_token_ids: Prompt token ids, or ``None`` when the prompt
                is supplied only as ``prompt_embeds``. In that case
                ``all_token_ids`` starts with placeholder zeros.
            sampling_params: Parameters for generative requests.
            pooling_params: Parameters for pooling requests. This is checked
                before ``sampling_params``. A pooling request gets
                ``max_tokens == 1``.
            eos_token_id: End-of-sequence token id for this request. It may
                differ per request because of LoRA.
            client_index: Client index, stored as-is.
            arrival_time: Arrival timestamp; defaults to ``time.time()``.
            prompt_embeds: Optional prompt embeddings.
            mm_features: Multimodal feature specs; ``None`` means none.
            lora_request: Optional LoRA request.
            cache_salt: Optional cache salt, stored as-is.
            priority: Scheduling priority; lower values sort first in
                ``__lt__``.
            trace_headers: Optional tracing headers, stored as-is.
            block_hasher: Optional callable that is invoked with this
                request at construction and after every output-token
                append. The hashes it returns are accumulated in
                ``block_hashes``.

        Raises:
            ValueError: If both ``sampling_params`` and ``pooling_params``
                are ``None``.
            AssertionError: If a generative request's
                ``sampling_params.max_tokens`` is ``None``.
        """
        self.request_id = request_id
        self.client_index = client_index
        self.priority = priority
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.structured_output_request = StructuredOutputRequest.from_sampling_params(
            sampling_params
        )
        self.arrival_time = arrival_time if arrival_time is not None else time.time()

        self.status = RequestStatus.WAITING
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: int | str | None = None

        # P/D: Connector-specific KV transfer parameters.
        self.kv_transfer_params: dict[str, Any] | None = None

        if pooling_params is not None:
            # Pooling models.
            self.max_tokens = 1
        elif sampling_params is not None:
            # Generative models.
            assert sampling_params.max_tokens is not None
            self.max_tokens = sampling_params.max_tokens
            if self.structured_output_request is not None:
                # Start in WAITING_FOR_FSM instead of WAITING (presumably until
                # the structured-output state is ready; confirm with scheduler).
                self.status = RequestStatus.WAITING_FOR_FSM

            if sampling_params.extra_args is not None:
                self.kv_transfer_params = sampling_params.extra_args.get(
                    "kv_transfer_params"
                )
        else:
            raise ValueError("sampling_params and pooling_params can't both be unset")

        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            prompt_token_ids, prompt_embeds
        )
        self._output_token_ids: list[int] = []
        # Embeddings-only prompts have no real ids. Zero placeholders keep
        # len(_all_token_ids), and therefore num_tokens, consistent with
        # num_prompt_tokens.
        self._all_token_ids: list[int] = (
            self.prompt_token_ids.copy()
            if self.prompt_token_ids is not None
            else [0] * self.num_prompt_tokens
        )

        # Used in async scheduling.
        self.num_output_placeholders = 0
        # Used in forced preemption (reset_prefix_cache) with async scheduling.
        self.discard_latest_async_tokens = False

        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: str | None = cache_salt

        # Multi-modal related
        self.mm_features = mm_features or []
        self.num_encoder_inputs = len(self.mm_features)
        self.has_encoder_inputs = self.num_encoder_inputs > 0

        # Read-only views
        # Prevent directly appending to these lists since
        # they should also be updated simultaneously.
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)
        # trace_headers
        self.trace_headers = trace_headers
        # State
        # The number of tokens with prefix cache hits
        # (-1 is presumably a "not yet determined" sentinel; confirm).
        self.num_cached_tokens = -1

        # The number of NaNs in logits. A value greater than 0
        # indicates that the output is corrupted
        self.num_nans_in_logits = 0

        # The number of times this request has been preempted by the scheduler.
        self.num_preemptions = 0

        # The number of tokens that have been computed remotely.
        self.num_external_computed_tokens = 0

        self.block_hashes: list[BlockHash] = []
        self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
        if block_hasher is not None:
            # Bind the hasher to this request. It runs here, after the token
            # lists and mm_features are set, and again on each token append.
            self.get_hash_new_full_blocks = partial(block_hasher, self)
            self.block_hashes = self.get_hash_new_full_blocks()

        self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()

    @classmethod
    def from_engine_core_request(
        cls,
        request: EngineCoreRequest,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None,
    ) -> "Request":
        """Build a ``Request`` by copying the fields of an ``EngineCoreRequest``.

        Args:
            request: Source request whose fields are forwarded unchanged.
            block_hasher: Forwarded to ``__init__`` as ``block_hasher``.

        Returns:
            A new ``Request``.
        """
        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            eos_token_id=request.eos_token_id,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
            cache_salt=request.cache_salt,
            priority=request.priority,
            trace_headers=request.trace_headers,
            block_hasher=block_hasher,
        )

    def append_output_token_ids(
        self,
        token_ids: int | list[int],
    ) -> None:
        """Append one or more generated token ids.

        Both the output-token list and the full token list are updated
        together, so the read-only views stay consistent. If a block hasher
        is configured, ``block_hashes`` is then extended with whatever it
        returns.

        Args:
            token_ids: A single token id or a list of token ids.
        """
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

        if self.get_hash_new_full_blocks is not None:
            self.block_hashes.extend(self.get_hash_new_full_blocks())

    @property
    def use_structured_output(self) -> bool:
        """True when this request has a structured-output request attached."""
        return self.structured_output_request is not None

    @property
    def num_tokens(self) -> int:
        """Number of prompt tokens plus generated output tokens."""
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        """``num_tokens`` plus the number of entries in ``spec_token_ids``."""
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of generated output tokens appended so far."""
        return len(self._output_token_ids)

    def get_skip_reading_prefix_cache(self) -> bool:
        """Resolve whether this request should skip reading the prefix cache.

        An explicit (non-``None``) setting on ``sampling_params`` wins. Next
        comes an explicit setting on ``pooling_params``. Otherwise this
        returns ``False``.
        """
        if (
            self.sampling_params is not None
            and self.sampling_params.skip_reading_prefix_cache is not None
        ):
            return self.sampling_params.skip_reading_prefix_cache
        elif (
            self.pooling_params is not None
            and self.pooling_params.skip_reading_prefix_cache is not None
        ):
            return self.pooling_params.skip_reading_prefix_cache
        return False

    def is_finished(self) -> bool:
        """Return True if the current status is a finished status."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> FinishReason | None:
        """Return the finish reason for the current status, or None."""
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_embeds(self, input_id: int) -> int:
        """Return the number of encoder embeddings for multimodal input ``input_id``.

        ``get_num_embeds`` is read as an attribute, not called. It is
        presumably a property on the placeholder range (confirm in
        ``vllm.multimodal.inputs``).

        Args:
            input_id: Index into ``mm_features``. It must be non-negative
                and less than ``len(mm_features)``.
        """
        # Reject negative ids too. Otherwise Python would silently index from
        # the end of mm_features and return the wrong input's count.
        assert 0 <= input_id < len(self.mm_features)
        return self.mm_features[input_id].mm_position.get_num_embeds

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: float | None = None,
    ) -> None:
        """Append a new ``EngineCoreEvent`` of ``event_type`` to ``events``."""
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> list[EngineCoreEvent] | None:
        """Return and clear the recorded events, or None if there are none."""
        if not self.events:
            return None
        events, self.events = self.events, []
        return events

    def __lt__(self, other: "Request") -> bool:
        """
        Compare two requests based on priority, arrival time, and request ID.
        Used in priority scheduling. Object identity is the final tie-breaker,
        so two distinct requests never compare equal.
        """
        if self.priority != other.priority:
            return self.priority < other.priority
        if self.arrival_time != other.arrival_time:
            return self.arrival_time < other.arrival_time
        if self.request_id != other.request_id:
            return self.request_id < other.request_id
        return id(self) < id(other)

class RequestStatus(enum.IntEnum):
    """Status of a request.

    Member order is significant. ``is_finished`` treats every member declared
    after ``PREEMPTED`` as finished. New non-finished states must therefore be
    declared before ``PREEMPTED``, and new finished states after it.
    """

    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()
    FINISHED_ERROR = enum.auto()

    def __str__(self):
        """Render as the bare member name (e.g. ``RUNNING``)."""
        return self.name

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True if ``status`` is one of the FINISHED_* members."""
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(status: "RequestStatus") -> FinishReason | None:
        """Map a finished status to its ``FinishReason``.

        Returns None for statuses that are not in the module-level
        finished-reason table, i.e. the non-finished ones.
        """
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP: dict[RequestStatus, FinishReason] = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
}