request.py 8.91 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from collections.abc import Mapping
7
8
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
9

10
11
import torch

12
from vllm.multimodal.inputs import MultiModalFeatureSpec
13
from vllm.pooling_params import PoolingParams
14
from vllm.sampling_params import SamplingParams
15
from vllm.utils import length_from_prompt_token_ids_or_embeds
16
17
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                            EngineCoreRequest, FinishReason)
18
from vllm.v1.structured_output.request import StructuredOutputRequest
19
from vllm.v1.utils import ConstantList
20

21
if TYPE_CHECKING:
22
    from vllm.lora.request import LoRARequest
23
    from vllm.v1.core.kv_cache_utils import BlockHash
24

25
26
27
28
29
30

class Request:
    """Engine-core state for a single request.

    Holds the prompt, the generated output tokens, the sampling or pooling
    configuration, the scheduling status and per-request bookkeeping
    (computed / cached / speculative token counts, multi-modal features,
    KV block hashes and recorded engine-core events).
    """

    def __init__(
        self,
        request_id: str,
        prompt_token_ids: Optional[list[int]],
        sampling_params: Optional[SamplingParams],
        pooling_params: Optional[PoolingParams],
        eos_token_id: Optional[int],
        client_index: int = 0,
        arrival_time: Optional[float] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        mm_features: Optional[list[MultiModalFeatureSpec]] = None,
        lora_request: Optional["LoRARequest"] = None,
        structured_output_request: Optional["StructuredOutputRequest"] = None,
        cache_salt: Optional[str] = None,
        priority: int = 0,
        trace_headers: Optional[Mapping[str, str]] = None,
        block_hasher: Optional[Callable[["Request"],
                                        list["BlockHash"]]] = None,
    ) -> None:
        """Initialise the request.

        Args:
            request_id: Identifier of the request.
            prompt_token_ids: Prompt token IDs, or ``None`` when the prompt
                is supplied only through ``prompt_embeds``.
            sampling_params: Parameters for generative models.
            pooling_params: Parameters for pooling models. Checked first,
                so it takes precedence when both parameter sets are given.
            eos_token_id: EOS token ID for this request. Because of LoRA it
                can differ between requests.
            client_index: Client index stored on the request.
            arrival_time: Arrival timestamp; ``time.time()`` when ``None``.
            prompt_embeds: Optional prompt embeddings.
            mm_features: Multi-modal feature specs; an empty list when
                ``None`` or empty.
            lora_request: Optional LoRA request.
            structured_output_request: Optional structured-output state.
            cache_salt: Optional cache salt stored on the request.
            priority: Request priority.
            trace_headers: Optional tracing headers.
            block_hasher: Optional callable that receives this request and
                returns the hashes of its new full blocks. It is bound to
                the request as ``get_hash_new_full_blocks``.

        Raises:
            ValueError: If both ``sampling_params`` and ``pooling_params``
                are ``None``.
        """
        self.request_id = request_id
        self.client_index = client_index
        self.priority = priority
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params

        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.structured_output_request = structured_output_request

        self.arrival_time = (arrival_time
                             if arrival_time is not None else time.time())

        self.status = RequestStatus.WAITING
        self.use_structured_output = False
        self.events: list[EngineCoreEvent] = []
        self.stop_reason: Union[int, str, None] = None

        # P/D: Connector-specific KV transfer parameters.
        self.kv_transfer_params: Optional[dict[str, Any]] = None

        if pooling_params is not None:
            # Pooling models: fixed budget of a single token.
            self.max_tokens = 1
        elif sampling_params is not None:
            # Generative models.
            assert sampling_params.max_tokens is not None
            self.max_tokens = sampling_params.max_tokens
            if sampling_params.structured_outputs is not None:
                # Structured-output requests start in WAITING_FOR_FSM
                # rather than WAITING.
                self.status = RequestStatus.WAITING_FOR_FSM
                self.use_structured_output = True

            if sampling_params.extra_args is not None:
                self.kv_transfer_params = sampling_params.extra_args.get(
                    "kv_transfer_params")
        else:
            raise ValueError(
                "sampling_params and pooling_params can't both be unset")

        self.prompt_token_ids = prompt_token_ids
        self.prompt_embeds = prompt_embeds
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            prompt_token_ids, prompt_embeds)
        self._output_token_ids: list[int] = []
        # Without prompt token IDs (prompt given via embeddings), the prompt
        # part of the token list is filled with 0 placeholders so that its
        # length still equals num_prompt_tokens.
        self._all_token_ids: list[int] = (
            self.prompt_token_ids.copy()
            if self.prompt_token_ids is not None else [0] *
            self.num_prompt_tokens)
        self.num_output_placeholders = 0  # Used in async scheduling.
        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: Optional[str] = cache_salt

        # Multi-modal related
        self.mm_features = mm_features or []
        self.num_encoder_inputs = len(self.mm_features)
        self.has_encoder_inputs = self.num_encoder_inputs > 0

        # Read-only views. Direct appends are prevented because both lists
        # must be updated together (via append_output_token_ids).
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

        self.trace_headers = trace_headers

        # State
        # The number of tokens with prefix cache hits (initialised to -1).
        self.num_cached_tokens = -1

        # The number of NaNs in logits. A value greater than 0
        # indicates that the output is corrupted.
        self.num_nans_in_logits = 0

        # The hasher is bound to this request and called only after every
        # other attribute above has been initialised.
        self.block_hashes: list["BlockHash"] = []
        self.get_hash_new_full_blocks: Optional[Callable[
            [], list["BlockHash"]]] = None
        if block_hasher is not None:
            self.get_hash_new_full_blocks = partial(block_hasher, self)
            self.block_hashes = self.get_hash_new_full_blocks()

    @classmethod
    def from_engine_core_request(
        cls, request: EngineCoreRequest,
        block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
    ) -> "Request":
        """Build a ``Request`` from an ``EngineCoreRequest``.

        A ``StructuredOutputRequest`` is created whenever the incoming
        request carries truthy ``sampling_params``; otherwise it is ``None``.
        """
        return cls(
            request_id=request.request_id,
            client_index=request.client_index,
            prompt_token_ids=request.prompt_token_ids,
            prompt_embeds=request.prompt_embeds,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            eos_token_id=request.eos_token_id,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
            structured_output_request=(
                StructuredOutputRequest(
                    sampling_params=request.sampling_params)
                if request.sampling_params else None),
            cache_salt=request.cache_salt,
            priority=request.priority,
            trace_headers=request.trace_headers,
            block_hasher=block_hasher,
        )

    def append_output_token_ids(
        self,
        token_ids: Union[int, list[int]],
    ) -> None:
        """Append generated token ID(s) to the output and full token lists.

        Both internal lists are updated together. When a block hasher is
        bound, the hashes it returns for new full blocks are appended to
        ``block_hashes``.
        """
        if isinstance(token_ids, int):
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

        if self.get_hash_new_full_blocks is not None:
            self.block_hashes.extend(self.get_hash_new_full_blocks())

    @property
    def is_output_corrupted(self) -> bool:
        """True when at least one NaN was recorded in the logits."""
        return self.num_nans_in_logits > 0

    @property
    def num_tokens(self) -> int:
        """Number of prompt plus output tokens (excluding spec tokens)."""
        return len(self._all_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        """``num_tokens`` plus the number of speculative token IDs."""
        return len(self._all_token_ids) + len(self.spec_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of output tokens appended so far."""
        return len(self._output_token_ids)

    def is_finished(self) -> bool:
        """Return whether the current status is a finished status."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> Union[FinishReason, None]:
        """Return the finish reason for the current status, if finished."""
        return RequestStatus.get_finished_reason(self.status)

    def get_num_encoder_tokens(self, input_id: int) -> int:
        """Return ``mm_position.length`` of the ``input_id``-th mm feature."""
        assert input_id < len(self.mm_features)
        num_tokens = self.mm_features[input_id].mm_position.length
        return num_tokens

    def record_event(
        self,
        event_type: EngineCoreEventType,
        timestamp: Optional[float] = None,
    ) -> None:
        """Append an ``EngineCoreEvent`` of ``event_type`` to ``events``.

        ``timestamp`` is passed through unchanged to
        ``EngineCoreEvent.new_event``.
        """
        self.events.append(EngineCoreEvent.new_event(event_type, timestamp))

    def take_events(self) -> Optional[list[EngineCoreEvent]]:
        """Return the recorded events and clear them, or ``None`` if empty."""
        if not self.events:
            return None
        events, self.events = self.events, []
        return events

204
205

class RequestStatus(enum.IntEnum):
    """Status of a request.

    Members are ordered by declaration (``enum.auto``). Every member
    declared after ``PREEMPTED`` is a finished status; ``is_finished``
    relies on this ordering, so new non-finished statuses must be added
    before ``PREEMPTED``.
    """

    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    def __str__(self) -> str:
        # Render as the bare member name, e.g. "RUNNING".
        return self.name

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True if ``status`` is ordered after ``PREEMPTED``."""
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(
            status: "RequestStatus") -> Union["FinishReason", None]:
        """Return the finish reason mapped to ``status``.

        Returns ``None`` for statuses that have no entry in
        ``_FINISHED_REASON_MAP`` (i.e. any non-finished status).
        """
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons. Statuses that are
# not finished have no entry, so RequestStatus.get_finished_reason returns
# None for them.
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP: dict[RequestStatus, FinishReason] = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
}