request.py 8.86 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from collections.abc import Callable, Mapping
7
from functools import partial
8
from typing import TYPE_CHECKING, Any, Optional
9

10
11
import torch

12
from vllm.multimodal.inputs import MultiModalFeatureSpec
13
from vllm.pooling_params import PoolingParams
14
from vllm.sampling_params import SamplingParams
15
from vllm.utils import length_from_prompt_token_ids_or_embeds
16
17
18
19
20
21
from vllm.v1.engine import (
    EngineCoreEvent,
    EngineCoreEventType,
    EngineCoreRequest,
    FinishReason,
)
22
from vllm.v1.structured_output.request import StructuredOutputRequest
23
from vllm.v1.utils import ConstantList
24

25
if TYPE_CHECKING:
26
    from vllm.lora.request import LoRARequest
27
    from vllm.v1.core.kv_cache_utils import BlockHash
28

29
30
31
32
33

class Request:
    """Engine-core state for a single generation or pooling request.

    Holds the prompt, the tokens produced so far, the scheduling status,
    recorded lifecycle events and, when a block hasher is supplied, the
    prefix-cache block hashes computed for the request's tokens.
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: list[int] | None,
        sampling_params: SamplingParams | None,
        pooling_params: PoolingParams | None,
        eos_token_id: int | None,
        client_index: int = 0,
        arrival_time: float | None = None,
        prompt_embeds: torch.Tensor | None = None,
        mm_features: list[MultiModalFeatureSpec] | None = None,
        lora_request: Optional["LoRARequest"] = None,
        structured_output_request: Optional["StructuredOutputRequest"] = None,
        cache_salt: str | None = None,
        priority: int = 0,
        trace_headers: Mapping[str, str] | None = None,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
    ) -> None:
        """Create a request.

        Args:
            request_id: Identifier of the request.
            prompt_token_ids: Prompt token ids, or None when the prompt is
                given only as ``prompt_embeds``.
            sampling_params: Parameters for generative models. Used when
                ``pooling_params`` is None.
            pooling_params: Parameters for pooling models. Takes precedence
                over ``sampling_params`` when both are given.
            eos_token_id: EOS token id for this request (may differ per
                request because of LoRA).
            client_index: Index of the client that submitted the request.
            arrival_time: Arrival timestamp; defaults to ``time.time()``.
            prompt_embeds: Optional prompt embeddings.
            mm_features: Multi-modal feature specs; defaults to an empty list.
            lora_request: Optional LoRA request.
            structured_output_request: Optional structured-output state.
            cache_salt: Optional salt stored on the request.
            priority: Scheduling priority value.
            trace_headers: Optional tracing headers.
            block_hasher: Optional callable that, bound to this request,
                returns newly computed block hashes.

        Raises:
            ValueError: If both ``sampling_params`` and ``pooling_params``
                are None.
        """
        self.request_id = request_id
        self.client_index = client_index
        self.priority = priority
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.structured_output_request = structured_output_request
        self.arrival_time = arrival_time if arrival_time is not None else time.time()

        self.status = RequestStatus.WAITING
        self.use_structured_output = False
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: int | str | None = None

        # P/D: Connector-specific KV transfer parameters.
        self.kv_transfer_params: dict[str, Any] | None = None

        if pooling_params is not None:
            # Pooling models.
            self.max_tokens = 1
        elif sampling_params is not None:
            # Generative models. max_tokens is expected to be resolved
            # before the request reaches the engine core.
            assert sampling_params.max_tokens is not None
            self.max_tokens = sampling_params.max_tokens
            if sampling_params.structured_outputs is not None:
                # Structured-output requests start in WAITING_FOR_FSM
                # instead of WAITING.
                self.status = RequestStatus.WAITING_FOR_FSM
                self.use_structured_output = True

            if sampling_params.extra_args is not None:
                self.kv_transfer_params = sampling_params.extra_args.get(
                    "kv_transfer_params"
                )
        else:
            raise ValueError("sampling_params and pooling_params can't both be unset")

        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            prompt_token_ids, prompt_embeds
        )
        self._output_token_ids: list[int] = []
        # When only embeddings are given, fill the token history with zero
        # placeholders so its length still equals num_prompt_tokens.
        self._all_token_ids: list[int] = (
            self.prompt_token_ids.copy()
            if self.prompt_token_ids is not None
            else [0] * self.num_prompt_tokens
        )
        self.num_output_placeholders = 0  # Used in async scheduling.
        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: str | None = cache_salt

        # Multi-modal related
        self.mm_features = mm_features or []
        self.num_encoder_inputs = len(self.mm_features)
        self.has_encoder_inputs = self.num_encoder_inputs > 0

        # Read-only views.
        # Direct appends to these lists are prevented because output and
        # full-history token lists must always be updated together; use
        # append_output_token_ids() instead.
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

        # trace_headers
        self.trace_headers = trace_headers

        # State
        # The number of tokens with prefix cache hits.
        # NOTE(review): -1 appears to mean "not yet known" — confirm.
        self.num_cached_tokens = -1

        # The number of NaNs in logits. A value greater than 0
        # indicates that the output is corrupted.
        self.num_nans_in_logits = 0

        # The number of times this request has been preempted by the scheduler.
        self.num_preemptions = 0

        self.block_hashes: list["BlockHash"] = []
        self.get_hash_new_full_blocks: Callable[[], list["BlockHash"]] | None = None
        if block_hasher is not None:
            # Bind the hasher to this request (all state above must already
            # be set) and compute the initial hashes.
            self.get_hash_new_full_blocks = partial(block_hasher, self)
            self.block_hashes = self.get_hash_new_full_blocks()

    @classmethod
    def from_engine_core_request(
        cls,
        request: EngineCoreRequest,
        block_hasher: Callable[["Request"], list["BlockHash"]] | None,
    ) -> "Request":
        """Build a Request from an EngineCoreRequest.

        A StructuredOutputRequest is created whenever the incoming request
        carries sampling params; pooling-only requests get None.
        """
        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            eos_token_id=request.eos_token_id,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
            structured_output_request=StructuredOutputRequest(
                sampling_params=request.sampling_params
            )
            if request.sampling_params is not None
            else None,
            cache_salt=request.cache_salt,
            priority=request.priority,
            trace_headers=request.trace_headers,
            block_hasher=block_hasher,
        )

    def append_output_token_ids(
        self,
        token_ids: int | list[int],
    ) -> None:
        """Append generated token(s) to both the output and full histories.

        If a block hasher is bound, any newly returned block hashes are
        appended to ``block_hashes``.
        """
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

        if self.get_hash_new_full_blocks is not None:
            self.block_hashes.extend(self.get_hash_new_full_blocks())

    @property
    def is_output_corrupted(self) -> bool:
        """True if any NaN was recorded in this request's logits."""
        return self.num_nans_in_logits > 0

    @property
    def num_tokens(self) -> int:
        """Number of prompt plus output tokens."""
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        """Number of prompt plus output tokens, plus speculative tokens."""
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of output tokens appended so far."""
        return len(self._output_token_ids)

    def is_finished(self) -> bool:
        """Return True if the current status is a finished status."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> FinishReason | None:
        """Return the finish reason for the current status, or None."""
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_tokens(self, input_id: int) -> int:
        """Return the placeholder length of multi-modal input ``input_id``.

        ``input_id`` must be a valid non-negative index into ``mm_features``;
        negative indices are rejected rather than wrapping around.
        """
        assert 0 <= input_id < len(self.mm_features)
        num_tokens = self.mm_features[input_id].mm_position.length
        return num_tokens

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: float | None = None,
    ) -> None:
        """Record a lifecycle event of ``event_type`` at ``timestamp``."""
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> list[EngineCoreEvent] | None:
        """Return and clear the recorded events, or None if there are none."""
        if not self.events:
            return None
        events, self.events = self.events, []
        return events

213
214

class RequestStatus(enum.IntEnum):
    """Status of a request.

    Member order is significant: ``is_finished`` treats every member
    declared after ``PREEMPTED`` as a finished status, so new non-finished
    statuses must be added before ``PREEMPTED``'s successors.
    """

    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    def __str__(self):
        return self.name

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True if ``status`` is declared after ``PREEMPTED``."""
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(status: "RequestStatus") -> "FinishReason | None":
        """Return the finish reason for ``status``, or None if not finished."""
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons.
# NOTE: Ignored requests are those whose prompt length exceeds the model's
# length cap, so their finish reason is "length", as in the OpenAI API.
_FINISHED_REASON_MAP: dict[RequestStatus, FinishReason] = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
}