request.py 5.93 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import enum
4
from typing import TYPE_CHECKING, List, Optional, Union
5
6
7
8

from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
9
from vllm.v1.engine import EngineCoreRequest, RequestFinishedReason
10
from vllm.v1.utils import ConstantList
11

12
if TYPE_CHECKING:
13
14
    from vllm.multimodal import MultiModalKwargs
    from vllm.multimodal.inputs import PlaceholderRange
15
16
    from vllm.v1.core.kv_cache_utils import BlockHashType

17
18
19
20
21
22

class Request:
    """State of a single request: prompt, generated tokens, multimodal
    inputs, lifecycle status, metrics, and cached KV block hashes."""

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: List[int],
        multi_modal_inputs: Optional[List["MultiModalKwargs"]],
        multi_modal_hashes: Optional[List[str]],
        multi_modal_placeholders: Optional[List["PlaceholderRange"]],
        sampling_params: SamplingParams,
        eos_token_id: Optional[int],
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        """Create a request in the ``WAITING`` state.

        Args:
            request_id: Identifier of the request.
            prompt: Prompt text, if available.
            prompt_token_ids: Tokenized prompt. It is copied to seed
                ``all_token_ids``, so later appends do not mutate it.
            multi_modal_inputs: Multimodal inputs, or None for none.
            multi_modal_hashes: Optional hashes, one per multimodal input.
            multi_modal_placeholders: Placeholder ranges, one per
                multimodal input.
            sampling_params: Sampling parameters. ``max_tokens`` must be set.
            eos_token_id: EOS token id for this request.
            arrival_time: Arrival timestamp, used to seed the metrics.
            lora_request: Optional LoRA adapter request.

        Raises:
            AssertionError: if ``sampling_params.max_tokens`` is None, if the
                number of multimodal inputs and placeholders differ, or if
                hashes are given but their count differs from the inputs.
        """
        self.request_id = request_id
        self.sampling_params = sampling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request

        self.status = RequestStatus.WAITING
        self.stop_reason: Union[int, str, None] = None
        # NOTE(review): presumably max_tokens is resolved before the request
        # reaches here; confirm with the caller.
        assert sampling_params.max_tokens is not None
        self.max_tokens = sampling_params.max_tokens

        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.num_prompt_tokens = len(self.prompt_token_ids)
        self._output_token_ids: List[int] = []
        # Prompt tokens followed by output tokens. Copied so that appending
        # outputs never mutates the caller's prompt_token_ids list.
        self._all_token_ids: List[int] = self.prompt_token_ids.copy()
        self.num_computed_tokens = 0

        # Multi-modal related
        self.mm_positions = multi_modal_placeholders or []
        self.mm_inputs = multi_modal_inputs or []
        self.mm_hashes: List[str] = multi_modal_hashes or []

        # Sanity check: one placeholder (and one hash, when hashes are given)
        # per multimodal input.
        assert len(self.mm_inputs) == len(self.mm_positions)
        if self.mm_hashes:
            assert len(self.mm_inputs) == len(self.mm_hashes)

        # Cache the computed kv block hashes of the request to avoid
        # recomputing.
        self._kv_block_hashes: List["BlockHashType"] = []
        self.kv_block_hashes = ConstantList(self._kv_block_hashes)

        # Read-only views
        # Prevent directly appending to these lists since
        # they should also be updated simultaneously.
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

    @classmethod
    def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
        """Build a ``Request`` by copying the fields of an engine-core
        request."""
        return cls(
            request_id=request.request_id,
            prompt=request.prompt,
            prompt_token_ids=request.prompt_token_ids,
            multi_modal_inputs=request.mm_inputs,
            multi_modal_hashes=request.mm_hashes,
            multi_modal_placeholders=request.mm_placeholders,
            sampling_params=request.sampling_params,
            eos_token_id=request.eos_token_id,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
        )

    def append_output_token_ids(
        self,
        token_ids: Union[int, List[int]],
    ) -> None:
        """Append generated token(s) to both the output and all-token lists.

        Args:
            token_ids: A single token id or a list of token ids.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        # Both lists are updated together so the read-only views stay in sync.
        self._output_token_ids.extend(token_ids)
        self._all_token_ids.extend(token_ids)

    @property
    def num_tokens(self) -> int:
        """Total number of tokens: prompt plus generated outputs."""
        return len(self._all_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of generated output tokens."""
        return len(self._output_token_ids)

    def is_finished(self) -> bool:
        """Return whether the current status is a finished status."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> Optional[RequestFinishedReason]:
        """Return the finish reason for the current status, or None."""
        return RequestStatus.get_finished_reason(self.status)

    def has_encoder_inputs(self) -> bool:
        """Return whether the request carries any multimodal inputs."""
        return len(self.mm_inputs) > 0

    @property
    def num_encoder_inputs(self) -> int:
        """Number of multimodal placeholders (equal to the number of inputs)."""
        return len(self.mm_positions)

    def get_num_encoder_tokens(self, input_id: int) -> int:
        """Return the placeholder length of multimodal input ``input_id``.

        Raises:
            AssertionError: if ``input_id`` is not a valid, non-negative
                index into the placeholders.
        """
        # Reject negative ids explicitly. Otherwise Python would silently
        # index from the end of the list.
        assert 0 <= input_id < len(self.mm_positions)
        return self.mm_positions[input_id]["length"]

    def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
        """Replace the cached KV block hashes with ``value``.

        ``value`` is stored as-is (not copied), so later calls to
        ``append_kv_block_hashes`` also extend the caller's list.
        """
        self._kv_block_hashes = value
        # Rebuild the read-only view so it wraps the new list.
        self.kv_block_hashes = ConstantList(self._kv_block_hashes)

    def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
        """Append one KV block hash to the cached hashes."""
        self._kv_block_hashes.append(block_hash)

class RequestStatus(enum.IntEnum):
    """Status of a request.

    Members are ordered by value. Every member whose value is greater than
    ``PREEMPTED`` is a finished state, and ``is_finished`` relies on that
    ordering.
    """
    # In-flight states.
    WAITING = 0
    RUNNING = 1
    PREEMPTED = 2
    # Finished states: anything after PREEMPTED (2) counts as finished.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True when ``status`` is ordered after ``PREEMPTED``."""
        return RequestStatus.PREEMPTED < status

    @staticmethod
    def get_finished_reason(
            status: "RequestStatus") -> Optional[RequestFinishedReason]:
        """Look up the finish reason for ``status``.

        Returns None when ``status`` has no entry in the reason mapping.
        """
        reason = _FINISHED_REASON_MAP.get(status)
        return reason


# Finish reason reported for each finished RequestStatus. Statuses without
# an entry here have no finish reason.
# NOTE: FINISHED_IGNORED marks requests whose prompt is longer than the
# model's length cap, so, as in the OpenAI API, their reason is "length".
_FINISHED_REASON_MAP = dict([
    (RequestStatus.FINISHED_STOPPED, RequestFinishedReason.STOP),
    (RequestStatus.FINISHED_LENGTH_CAPPED, RequestFinishedReason.LENGTH),
    (RequestStatus.FINISHED_ABORTED, RequestFinishedReason.ABORT),
    (RequestStatus.FINISHED_IGNORED, RequestFinishedReason.LENGTH),
])