__init__.py 8.23 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from collections.abc import Mapping
7
from typing import Any, Literal
8
9

import msgspec
10
import numpy as np
11
import torch
12
from typing_extensions import deprecated
13

14
from vllm.lora.request import LoRARequest
15
from vllm.multimodal.inputs import MultiModalFeatureSpec
16
from vllm.pooling_params import PoolingParams
17
from vllm.sampling_params import SamplingParams
18
from vllm.v1.metrics.stats import SchedulerStats
19
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
20
from vllm.v1.serial_utils import UtilityResult
21

22
23
24
25
26
27
# Accepted values for the pause_generation `mode` parameter.
# - "abort": Abort all in-flight requests immediately (default).
# - "wait": Wait for in-flight requests to complete before pausing.
# - "keep": Freeze requests in queue; they resume on resume_generation().
PauseMode = Literal["abort", "wait", "keep"]

28
29
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
# FinishReason.__str__ indexes this tuple by the enum's integer value, so the
# order here must stay in sync with the FinishReason member values.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error", "repetition")
31

32
33
34
35
36
37
38
39
40
# NOTE(review): presumably a reserved call ID that marks EEP notification
# messages rather than replies to real utility calls — confirm with callers.
EEP_NOTIFICATION_CALL_ID = -1


class EEPNotificationType(enum.Enum):
    """Kinds of EEP notification.

    Every member's value is the member's own name as a string.
    """

    NEW_CORE_ENGINES_INIT_READY = "NEW_CORE_ENGINES_INIT_READY"
    NEW_CORE_ENGINES_WEIGHTS_INIT_READY = "NEW_CORE_ENGINES_WEIGHTS_INIT_READY"
    RECONFIGURE_FINISHED = "RECONFIGURE_FINISHED"
    SHUTDOWN_COMPLETE = "SHUTDOWN_COMPLETE"

41
42

class FinishReason(enum.IntEnum):
    """
    Reason a request finished - stop, length, abort, error, or repetition.

    Int rather than Str for more compact serialization.

    stop - a stop string was emitted
    length - max_tokens was consumed, or max_model_len was reached
    abort - aborted by client
    error - retryable request-level internal error (e.g., KV load failure).
            Invariant: always converted to 500 Internal Server Error.
    repetition - repetitive token pattern detected (hallucination)

    """

    # Each value doubles as an index into FINISH_REASON_STRINGS (see
    # __str__), so values must stay contiguous from 0 and in the same order
    # as that tuple.
    STOP = 0
    LENGTH = 1
    ABORT = 2
    ERROR = 3
    REPETITION = 4

    def __str__(self) -> str:
        """Return the external finish_reason string for this member."""
        return FINISH_REASON_STRINGS[self.value]
65
66


67
class EngineCoreRequest(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A single engine-core request, serialized with msgspec.

    The struct is declared with ``array_like=True``, so it is encoded as a
    positional array rather than a keyed object: the order of the fields
    below is part of the serialized format and must not be rearranged.
    ``gc=False`` opts instances out of cyclic garbage-collector tracking.
    """

    request_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec] | None
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    arrival_time: float
    lora_request: LoRARequest | None
    cache_salt: str | None
    data_parallel_rank: int | None
    prompt_embeds: torch.Tensor | None = None

    # Index of the client, used to ensure outputs are sent back to the same
    # client for this request when scaling out the front-end.
    client_index: int = 0

    # Used in DP case to indicate which wave of requests this is expected to
    # belong to, to cover a race condition where the request is sent before
    # a wave finished notification is received.
    current_wave: int = 0
    priority: int = 0

    trace_headers: Mapping[str, str] | None = None
    resumable: bool = False

    # The user-provided request ID. This field is set internally,
    # copied from the provided request_id that's originally assigned
    # to the request_id field, see InputProcessor.assign_request_id().
    # Used in outputs and to support abort(req_id, internal=False).
    external_req_id: str | None = None

    reasoning_ended: bool | None = None

    @property
    def params(self) -> SamplingParams | PoolingParams:
        """Return the processed params (sampling or pooling)."""
        # Sampling params take precedence when both happen to be set.
        if self.sampling_params is not None:
            return self.sampling_params
        # Narrows the type for checkers; a request is expected to carry one
        # of the two param objects.
        assert self.pooling_params is not None
        return self.pooling_params

    @property
    @deprecated(
        "EngineCoreRequest.eos_token_id will be removed in v0.18. "
        "Please use EngineCoreRequest.sampling_params.eos_token_id instead."
    )
    def eos_token_id(self) -> int | None:
        """Deprecated alias for ``sampling_params.eos_token_id``.

        Returns None when the request has no sampling params.
        """
        if self.sampling_params is None:
            return None

        return self.sampling_params.eos_token_id

124

125
126
class EngineCoreEventType(enum.IntEnum):
    """The type of engine core request event."""

    QUEUED = 1
    SCHEDULED = 2
    PREEMPTED = 3
131
132
133
134
135
136
137
138
139


class EngineCoreEvent(msgspec.Struct):
    """A timestamped engine core event associated with a request.

    The timestamp is a monotonic timestamp and is used by the engine
    frontend to calculate intervals between engine core events. These
    timestamps should not be compared with timestamps from other processes.
    """

    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(
        cls, event_type: EngineCoreEventType, timestamp: float | None = None
    ) -> "EngineCoreEvent":
        """Build an event of ``event_type``.

        When ``timestamp`` is None the current ``time.monotonic()`` reading
        is used; otherwise the supplied value is stored unchanged.
        """
        timestamp = time.monotonic() if timestamp is None else timestamp
        return cls(event_type, timestamp)


152
class EngineCoreOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Per-request output record, serialized with msgspec.

    Encoded positionally (``array_like=True``), so the field order below is
    part of the serialized format and must not be rearranged.
    """

    request_id: str
    new_token_ids: list[int]

    new_logprobs: LogprobsLists | None = None
    new_prompt_logprobs_tensors: LogprobsTensors | None = None

    pooling_output: torch.Tensor | None = None

    finish_reason: FinishReason | None = None
    stop_reason: int | str | None = None
    events: list[EngineCoreEvent] | None = None
    kv_transfer_params: dict[str, Any] | None = None

    trace_headers: Mapping[str, str] | None = None
    # The number of tokens with prefix cache hits (local + external).
    num_cached_tokens: int = 0
    # The number of tokens computed remotely (original count from connector).
    num_external_computed_tokens: int = 0
    routed_experts: np.ndarray | None = None
    # The number of NaNs in logits.
    # A value greater than 0 indicates that the output is corrupted.
    num_nans_in_logits: int = 0

    @property
    def finished(self) -> bool:
        """Whether a finish reason has been recorded for this request."""
        return self.finish_reason is not None

185

186
class UtilityOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Outcome of a utility call, identified by ``call_id``.

    Encoded positionally (``array_like=True``), so field order is part of
    the serialized format.
    """

    call_id: int

    # Non-None implies the call failed, result should be None.
    failure_message: str | None = None
    result: UtilityResult | None = None
196
197


198
class EngineCoreOutputs(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Batch of per-request outputs plus engine-level signals.

    Encoded positionally (``array_like=True``), so the field order below is
    part of the serialized format and must not be rearranged.
    """

    # NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    engine_index: int = 0

    # [num_reqs]
    # An empty-list default is safe on a msgspec Struct: msgspec gives each
    # instance its own copy rather than sharing one list object.
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: SchedulerStats | None = None
    timestamp: float = 0.0

    utility_output: UtilityOutput | None = None
    finished_requests: set[str] | None = None

    # In DP case, used to signal that the current wave of requests
    # has finished and the engines are paused.
    wave_complete: int | None = None
    # In DP case, used to signal that a request was received for an
    # "old" wave, so the next wave needs to be started in other engines.
    start_wave: int | None = None

    def __post_init__(self) -> None:
        """Stamp the batch with ``time.monotonic()`` if no timestamp was given.

        A ``timestamp`` of 0.0 (the field default) is treated as "unset".
        """
        if self.timestamp == 0.0:
            self.timestamp = time.monotonic()
227
228
229
230
231
232
233


class EngineCoreRequestType(enum.Enum):
    """
    Request types defined as single-byte strings, so they can be sent over
    sockets without a separate encoding step.
    """

    ADD = b"\x00"
    ABORT = b"\x01"
    START_DP_WAVE = b"\x02"
    UTILITY = b"\x03"
    # Sentinel used within EngineCoreProc.
    EXECUTOR_FAILED = b"\x04"
    # Sentinel to wake up input_queue.get() during shutdown.
    WAKEUP = b"\x05"
243
244
245
246
247
248
249
250


class ReconfigureDistributedRequest(msgspec.Struct):
    """Target data-parallel layout for a distributed reconfiguration.

    All fields are required. The ``*_port_list`` fields carry the ports for
    the correspondingly named groups.
    """

    new_data_parallel_size: int
    new_data_parallel_rank: int
    new_data_parallel_rank_local: int
    new_data_parallel_master_ip: str
    new_data_parallel_master_port: int
    new_data_parallel_master_port_list: list[int]
    new_stateless_world_group_port_list: list[list[int]]
    new_stateless_dp_group_port_list: list[list[int]]
    new_stateless_ep_group_port_list: list[list[int]]
    new_stateless_eplb_group_port_list: list[list[int]]
256
257
258
259
260
261


class ReconfigureRankType(enum.IntEnum):
    """
    Rank type for reconfiguring distributed request.
    """

    # Both members are negative.
    # NOTE(review): presumably so they cannot collide with a real rank index
    # when passed in a rank field — confirm against callers.
    KEEP_CURRENT_RANK = -1
    SHUTDOWN_CURRENT_RANK = -2