# Per-request state for the vLLM v1 engine: the ``Request`` class and its
# ``RequestStatus`` lifecycle enum.
import enum
from typing import TYPE_CHECKING, List, Optional, Union

4
from vllm.inputs.data import DecoderOnlyInputs
5
from vllm.lora.request import LoRARequest
6
from vllm.multimodal import MultiModalKwargs
7
8
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
9
from vllm.v1.engine import EngineCoreRequest
10
from vllm.v1.utils import ConstantList
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

if TYPE_CHECKING:
    from vllm.inputs import DecoderOnlyInputs


class Request:
    """State kept for a single request handled by the v1 engine.

    Tracks the prompt and the tokens generated so far, the sampling
    configuration, the lifecycle status, timing metrics, and any multimodal
    inputs attached to the prompt.
    """

    def __init__(
        self,
        request_id: str,
        inputs: "DecoderOnlyInputs",
        sampling_params: SamplingParams,
        eos_token_id: Optional[int],
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        """Create a request in the ``RequestStatus.WAITING`` state.

        Args:
            request_id: Identifier of the request.
            inputs: Decoder-only inputs. ``"prompt_token_ids"`` is required;
                ``"prompt"``, ``"multi_modal_data"``, ``"mm_processor_kwargs"``
                and ``"multi_modal_placeholders"`` are optional.
            sampling_params: Sampling configuration. ``max_tokens`` must
                already be set (asserted below).
            eos_token_id: End-of-sequence token id for this request, or None.
            arrival_time: Arrival timestamp; seeds ``metrics.arrival_time``
                and ``metrics.last_token_time``.
            lora_request: Optional LoRA adapter request.
        """
        self.request_id = request_id
        self.inputs = inputs
        self.sampling_params = sampling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request

        self.status = RequestStatus.WAITING
        self.stop_reason: Union[int, str, None] = None
        assert sampling_params.max_tokens is not None
        self.max_tokens = sampling_params.max_tokens

        self.prompt = inputs.get("prompt")
        self.prompt_token_ids = inputs["prompt_token_ids"]
        self.num_prompt_tokens = len(self.prompt_token_ids)
        # Generated tokens are stored on their own AND appended to a copy of
        # the prompt; append_output_token_ids() keeps the two lists in sync.
        self._output_token_ids: List[int] = []
        self._all_token_ids: List[int] = self.prompt_token_ids.copy()
        self.num_computed_tokens = 0

        # Raw multimodal data before the mm input mapper (e.g., PIL images).
        self.mm_data = inputs.get("multi_modal_data")
        self.mm_processor_kwargs = inputs.get("mm_processor_kwargs")
        mm_positions = inputs.get("multi_modal_placeholders")
        if mm_positions:
            # FIXME(woosuk): Support other modalities.
            self.mm_positions = mm_positions.get("image", [])
        else:
            self.mm_positions = []
        # Output of the mm input mapper (e.g., image tensors).
        self.mm_inputs: List[MultiModalKwargs] = []

    @classmethod
    def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
        """Build a ``Request`` from an ``EngineCoreRequest``.

        The prompt and multimodal fields are repacked into a token-type
        ``DecoderOnlyInputs``. The remaining fields are passed through
        unchanged.
        """
        return cls(
            request_id=request.request_id,
            inputs=DecoderOnlyInputs(
                type="token",
                prompt_token_ids=request.prompt_token_ids,
                prompt=request.prompt,
                multi_modal_data=request.mm_data,
                multi_modal_placeholders=request.mm_placeholders,
                mm_processor_kwargs=request.mm_processor_kwargs,
            ),
            sampling_params=request.sampling_params,
            eos_token_id=request.eos_token_id,
            arrival_time=request.arrival_time,
            lora_request=request.lora_request,
        )

    @property
    def output_token_ids(self) -> ConstantList[int]:
        """Read-only view of the generated tokens."""
        # Prevent directly appending to the output_token_ids since
        # all_token_ids should also be updated simultaneously.
        return ConstantList(self._output_token_ids)

    @property
    def all_token_ids(self) -> ConstantList[int]:
        """Read-only view of prompt tokens followed by generated tokens."""
        # Prevent directly appending to the all_token_ids since
        # output_token_ids should also be updated simultaneously.
        return ConstantList(self._all_token_ids)

    def append_output_token_ids(
        self,
        token_ids: Union[int, List[int]],
    ) -> None:
        """Append one token id or a list of token ids to the output.

        Both the output-only list and the combined prompt+output list are
        extended, so the two stay consistent.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        self._output_token_ids.extend(token_ids)
        self._all_token_ids.extend(token_ids)

    @property
    def num_tokens(self) -> int:
        """Total number of tokens (prompt plus generated)."""
        return len(self._all_token_ids)

    @property
    def num_output_tokens(self) -> int:
        """Number of generated tokens."""
        return len(self._output_token_ids)

    def is_finished(self) -> bool:
        """Return True if the current status is a finished status."""
        return RequestStatus.is_finished(self.status)

    def get_finished_reason(self) -> Union[str, None]:
        """Return the finish reason for the current status, or None."""
        return RequestStatus.get_finished_reason(self.status)

    def has_encoder_inputs(self) -> bool:
        """Return True if the inputs carried ``multi_modal_data``."""
        return self.mm_data is not None

    @property
    def num_encoder_inputs(self) -> int:
        """Number of multimodal placeholder ranges (currently images only)."""
        return len(self.mm_positions)

    def get_num_encoder_tokens(self, input_id: int) -> int:
        """Return the placeholder length of the ``input_id``-th encoder input.

        Args:
            input_id: Index into ``mm_positions``. It must be in
                ``[0, num_encoder_inputs)``. A negative index is rejected
                rather than silently counting from the end of the list.

        Returns:
            The ``"length"`` entry of that placeholder.
        """
        assert 0 <= input_id < len(self.mm_positions)
        return self.mm_positions[input_id]["length"]

class RequestStatus(enum.IntEnum):
    """Status of a request.

    The numeric order matters: every value greater than ``PREEMPTED`` is a
    finished status, which is how ``is_finished`` decides.
    """
    WAITING = 0
    RUNNING = 1
    PREEMPTED = 2
    # Note: anything after PREEMPTED (2) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        """Return True if ``status`` is one of the ``FINISHED_*`` values."""
        return status > RequestStatus.PREEMPTED

    @staticmethod
    def get_finished_reason(status: "RequestStatus") -> Union[str, None]:
        """Return the finish-reason string for ``status``.

        The string is looked up in the module-level ``_FINISHED_REASON_MAP``.
        Returns None for statuses that are not in that map.
        """
        return _FINISHED_REASON_MAP.get(status)


# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored requests are the requests whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP = {
    RequestStatus.FINISHED_STOPPED: "stop",
    RequestStatus.FINISHED_LENGTH_CAPPED: "length",
    RequestStatus.FINISHED_ABORTED: "abort",
    RequestStatus.FINISHED_IGNORED: "length",
}