# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import enum
5
import time
6
from collections.abc import Mapping
7
from typing import Any
8
9

import msgspec
10
import torch
11

12
from vllm.lora.request import LoRARequest
13
from vllm.multimodal.inputs import MultiModalFeatureSpec
14
from vllm.pooling_params import PoolingParams
15
from vllm.sampling_params import SamplingParams
16
from vllm.v1.metrics.stats import SchedulerStats
17
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
18
from vllm.v1.serial_utils import UtilityResult
19

20
21
# Every value RequestOutput.finish_reason can take; these strings are part of
# the external API. FinishReason.__str__ indexes this tuple by member value,
# so its order must stay aligned with the enum below.
FINISH_REASON_STRINGS = (
    "stop",  # FinishReason.STOP
    "length",  # FinishReason.LENGTH
    "abort",  # FinishReason.ABORT
    "error",  # FinishReason.ERROR
)


class FinishReason(enum.IntEnum):
    """
    Why a request finished: stop, length, abort, or error.

    Kept as an int (not a str) so that it serializes more compactly.

    stop   - a stop string was emitted
    length - max_tokens was consumed, or max_model_len was reached
    abort  - aborted by client
    error  - retryable request-level internal error (e.g., KV load failure).
             Invariant: always converted to 500 Internal Server Error.
    """

    STOP = 0
    LENGTH = 1
    ABORT = 2
    ERROR = 3

    def __str__(self) -> str:
        # Render as the external-API string instead of the member name;
        # an IntEnum member's int() is its value, i.e. the tuple index.
        return FINISH_REASON_STRINGS[int(self)]


48
class EngineCoreRequest(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A single request as handed to the engine core.

    With ``array_like=True`` msgspec encodes this struct as a positional
    array, so the field order below is part of the wire format: do not
    reorder, and only append new defaulted fields at the end.
    ``omit_defaults=True`` drops fields that still hold their default.
    """

    request_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec] | None
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    eos_token_id: int | None
    arrival_time: float
    lora_request: LoRARequest | None
    cache_salt: str | None
    data_parallel_rank: int | None
    prompt_embeds: torch.Tensor | None = None

    # Which front-end client sent this request; outputs are routed back to
    # that same client when the front-end is scaled out.
    client_index: int = 0

    # Data-parallel only: the request wave this request is expected to join.
    # Guards against the race where a request is sent before the
    # "wave finished" notification has been received.
    current_wave: int = 0
    priority: int = 0

    trace_headers: Mapping[str, str] | None = None

    @property
    def params(self) -> SamplingParams | PoolingParams:
        """The processed params for this request.

        Returns ``sampling_params`` when it is set (it wins if both are
        set); otherwise ``pooling_params``, which is asserted to be set.
        """
        if self.sampling_params is None:
            assert self.pooling_params is not None
            return self.pooling_params
        return self.sampling_params

86

87
88
class EngineCoreEventType(enum.IntEnum):
    """Kind of per-request event recorded by the engine core.

    Values are spelled out explicitly (rather than via ``enum.auto``) so
    they stay stable if members are ever reordered.
    """

    QUEUED = 1
    SCHEDULED = 2
    PREEMPTED = 3


class EngineCoreEvent(msgspec.Struct):
    """A timestamped engine core event associated with a request.

    The timestamp is a monotonic clock reading, which the engine frontend
    uses to compute intervals between engine core events. It should not be
    compared with timestamps taken in other processes.
    """

    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(
        cls, event_type: EngineCoreEventType, timestamp: float | None = None
    ) -> "EngineCoreEvent":
        """Create an event of ``event_type``.

        Uses ``timestamp`` if given, otherwise the current
        ``time.monotonic()`` reading.
        """
        if timestamp is None:
            timestamp = time.monotonic()
        return cls(event_type, timestamp)


114
class EngineCoreOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Per-request output produced by the engine core.

    Encoded positionally (``array_like=True``), so field order is part of
    the wire format; fields equal to their defaults are omitted
    (``omit_defaults=True``).
    """

    request_id: str
    new_token_ids: list[int]

    new_logprobs: LogprobsLists | None = None
    new_prompt_logprobs_tensors: LogprobsTensors | None = None

    pooling_output: torch.Tensor | None = None

    finish_reason: FinishReason | None = None
    stop_reason: int | str | None = None
    events: list[EngineCoreEvent] | None = None
    kv_transfer_params: dict[str, Any] | None = None

    trace_headers: Mapping[str, str] | None = None
    # How many of the request's tokens were prefix-cache hits.
    num_cached_tokens: int = 0

    # How many NaNs were seen in the logits; any value above 0 means the
    # output is corrupted.
    num_nans_in_logits: int = 0

    @property
    def finished(self) -> bool:
        """Whether a finish reason has been set for this request."""
        # Must compare with None explicitly: FinishReason.STOP has value 0
        # and is therefore falsy, so a truthiness test would miss it.
        return self.finish_reason is not None

145

146
class UtilityOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Outcome of the utility call identified by ``call_id``.

    Encoded positionally (``array_like=True``), so field order matters.
    """

    call_id: int

    # When this is not None the call failed, and ``result`` should be None.
    failure_message: str | None = None
    result: UtilityResult | None = None


158
class EngineCoreOutputs(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A batch of engine core outputs plus engine-level signals.

    Encoded positionally (``array_like=True``), so field order is part of
    the wire format; defaulted fields are omitted when unchanged.
    """

    # NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    engine_index: int = 0

    # One entry per request: [num_reqs]. msgspec copies empty-collection
    # defaults for each instance, so this list is not shared between structs.
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: SchedulerStats | None = None
    # 0.0 means "unset"; __post_init__ replaces it with time.monotonic().
    timestamp: float = 0.0

    utility_output: UtilityOutput | None = None
    finished_requests: set[str] | None = None

    # Data-parallel only: signals that the current wave of requests has
    # finished and the engines are paused.
    wave_complete: int | None = None
    # Data-parallel only: signals that a request arrived for an "old" wave,
    # so the other engines need to start the next wave.
    start_wave: int | None = None

    def __post_init__(self):
        # Keep any explicitly provided timestamp; only stamp unset ones.
        if self.timestamp != 0.0:
            return
        self.timestamp = time.monotonic()


class EngineCoreRequestType(enum.Enum):
    """
    Engine core request kinds, each encoded as a single-byte ``bytes`` value
    so it can be written to a socket as-is, with no separate encoding step.
    """

    ADD = b"\x00"
    ABORT = b"\x01"
    START_DP_WAVE = b"\x02"
    UTILITY = b"\x03"
    # Sentinel used within EngineCoreProc.
    EXECUTOR_FAILED = b"\x04"


class ReconfigureDistributedRequest(msgspec.Struct):
    """New data-parallel settings to apply when reconfiguring.

    Carries the new DP size, this engine's new global and local rank, and
    the new DP master address (ip and port).
    """

    new_data_parallel_size: int
    new_data_parallel_rank: int
    new_data_parallel_rank_local: int
    new_data_parallel_master_ip: str
    new_data_parallel_master_port: int


class ReconfigureRankType(enum.IntEnum):
    """
    Special rank values used when reconfiguring a distributed request.

    Both are negative, so presumably they are meant to be distinguishable
    from any real rank (>= 0) wherever a rank is expected — confirm against
    callers.
    """

    KEEP_CURRENT_RANK = -1
    SHUTDOWN_CURRENT_RANK = -2