__init__.py 7.13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from collections.abc import Mapping
7
from typing import Any, Literal
8
9

import msgspec
10
import numpy as np
11
import torch
12

13
from vllm.lora.request import LoRARequest
14
from vllm.multimodal.inputs import MultiModalFeatureSpec
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams
17
from vllm.v1.metrics.stats import SchedulerStats
18
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
19
from vllm.v1.serial_utils import UtilityResult
20

# Type for pause_generation mode parameter.
# - "abort": Abort all in-flight requests immediately (default).
# - "wait": Wait for in-flight requests to complete before pausing.
# - "keep": Freeze requests in queue; they resume on resume_generation().
PauseMode = Literal["abort", "wait", "keep"]

# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
# NOTE: FinishReason.__str__ indexes this tuple by the enum's integer value,
# so the order here must match FinishReason (STOP=0, LENGTH=1, ABORT=2,
# ERROR=3).
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")


class FinishReason(enum.IntEnum):
    """
    Reason a request finished - stop, length, abort, or error.

    Int rather than Str for more compact serialization.

    stop - a stop string was emitted
    length - max_tokens was consumed, or max_model_len was reached
    abort - aborted by client
    error - retryable request-level internal error (e.g., KV load failure).
            Invariant: always converted to 500 Internal Server Error.
    """

    # The integer values double as indices into FINISH_REASON_STRINGS
    # (see __str__), so they must stay contiguous and start at 0.
    STOP = 0
    LENGTH = 1
    ABORT = 2
    ERROR = 3

    def __str__(self):
        # Map to the external-API string form ("stop", "length", ...).
        return FINISH_REASON_STRINGS[self.value]


class EngineCoreRequest(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A request submitted to the engine core.

    Serialized with ``array_like=True``, i.e. positionally, so the declaration
    order of the fields below is part of the encoded format: do not reorder
    them. ``gc=False`` disables garbage-collector tracking of instances, so
    they must not take part in reference cycles.

    :attr:`params` requires at least one of ``sampling_params`` /
    ``pooling_params`` to be set; ``sampling_params`` wins if both are.
    """

    request_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec] | None
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    eos_token_id: int | None
    arrival_time: float
    lora_request: LoRARequest | None
    cache_salt: str | None
    data_parallel_rank: int | None
    prompt_embeds: torch.Tensor | None = None

    # Index of the client, used to ensure outputs are sent back to the same
    # client for this request when scaling out the front-end.
    client_index: int = 0

    # Used in DP case to indicate which wave of requests this is expected to
    # belong to, to cover a race condition where the request is sent before
    # a wave finished notification is received.
    current_wave: int = 0
    priority: int = 0

    trace_headers: Mapping[str, str] | None = None
    resumable: bool = False

    # The user-provided request ID. This field is set internally,
    # copied from the provided request_id that's originally assigned
    # to the request_id field, see InputProcessor.assign_request_id().
    # Used in outputs and to support abort(req_id, internal=False).
    external_req_id: str | None = None

    reasoning_ended: bool | None = None

    @property
    def params(self) -> SamplingParams | PoolingParams:
        """Return the processed params (sampling or pooling).

        Sampling params take precedence; otherwise pooling params must be
        present (checked with an ``assert``).
        """
        if self.sampling_params is not None:
            return self.sampling_params
        assert self.pooling_params is not None
        return self.pooling_params


class EngineCoreEventType(enum.IntEnum):
    """The type of engine core request event."""

    QUEUED = 1
    SCHEDULED = 2
    PREEMPTED = 3


class EngineCoreEvent(msgspec.Struct):
    """A timestamped engine core event associated with a request.

    The timestamp is a monotonic timestamp and is used by the engine
    frontend to calculate intervals between engine core events. These
    timestamps should not be compared with timestamps from other processes.
    """

    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(
        cls, event_type: EngineCoreEventType, timestamp: float | None = None
    ) -> "EngineCoreEvent":
        """Create an event of ``event_type``.

        Args:
            event_type: The kind of event being recorded.
            timestamp: Monotonic timestamp to record; defaults to
                ``time.monotonic()`` taken at call time.

        Returns:
            The new event.
        """
        timestamp = time.monotonic() if timestamp is None else timestamp
        return cls(event_type, timestamp)


class EngineCoreOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Output for a single request, carried in ``EngineCoreOutputs.outputs``.

    Serialized positionally (``array_like=True``), so field order is part of
    the encoded format: do not reorder fields. :attr:`finished` is true once
    ``finish_reason`` has been set.
    """

    request_id: str
    new_token_ids: list[int]

    new_logprobs: LogprobsLists | None = None
    new_prompt_logprobs_tensors: LogprobsTensors | None = None

    pooling_output: torch.Tensor | None = None

    finish_reason: FinishReason | None = None
    stop_reason: int | str | None = None
    events: list[EngineCoreEvent] | None = None
    kv_transfer_params: dict[str, Any] | None = None

    trace_headers: Mapping[str, str] | None = None
    # The number of tokens with prefix cache hits (local + external).
    num_cached_tokens: int = 0
    # The number of tokens computed remotely (original count from connector).
    num_external_computed_tokens: int = 0
    routed_experts: np.ndarray | None = None
    # The number of NaNs in logits.
    # A value greater than 0 indicates that the output is corrupted.
    num_nans_in_logits: int = 0

    @property
    def finished(self) -> bool:
        """Whether the request has finished (a finish reason is present)."""
        return self.finish_reason is not None


class UtilityOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Outcome of the utility call identified by ``call_id``.

    Exactly one of the two outcomes is meaningful: a non-None
    ``failure_message`` signals failure (``result`` should then be None);
    otherwise ``result`` holds the call's return value.
    """

    call_id: int

    # Non-None implies the call failed, result should be None.
    failure_message: str | None = None
    result: UtilityResult | None = None


class EngineCoreOutputs(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A batch of per-request outputs plus side-channel signals from an engine.

    Serialized positionally (``array_like=True``), so field order is part of
    the encoded format: do not reorder fields.
    """

    # NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    engine_index: int = 0

    # [num_reqs]
    # msgspec special-cases empty-collection defaults and gives each instance
    # its own fresh list, so this `[]` is not shared between instances.
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: SchedulerStats | None = None
    # 0.0 means "unset"; __post_init__ replaces it with time.monotonic().
    timestamp: float = 0.0

    utility_output: UtilityOutput | None = None
    finished_requests: set[str] | None = None

    # In DP case, used to signal that the current wave of requests
    # has finished and the engines are paused.
    wave_complete: int | None = None
    # In DP case, used to signal that a request was received for an
    # "old" wave, so the next wave needs to be started in other engines.
    start_wave: int | None = None

    def __post_init__(self):
        """Stamp the batch with the current monotonic time if none was given."""
        if self.timestamp == 0.0:
            self.timestamp = time.monotonic()


class EngineCoreRequestType(enum.Enum):
    """
    Request types defined as single-byte strings (written as hex escapes), so
    they can be sent over sockets without a separate encoding step.
    """

    ADD = b"\x00"
    ABORT = b"\x01"
    START_DP_WAVE = b"\x02"
    UTILITY = b"\x03"
    # Sentinel used within EngineCoreProc.
    EXECUTOR_FAILED = b"\x04"


class ReconfigureDistributedRequest(msgspec.Struct):
    """New data-parallel settings for a distributed reconfiguration.

    A plain msgspec struct with no defaults, so every field is required and
    the constructor accepts them positionally in the order declared below.

    NOTE(review): presumably sent to engine processes when the DP group is
    resized, and ``new_data_parallel_rank`` may carry one of the negative
    ``ReconfigureRankType`` sentinels instead of a real rank. Confirm against
    callers.
    """

    new_data_parallel_size: int
    new_data_parallel_rank: int
    new_data_parallel_rank_local: int
    new_data_parallel_master_ip: str
    new_data_parallel_master_port: int


class ReconfigureRankType(enum.IntEnum):
    """
    Rank type for reconfiguring distributed request.

    Special negative rank values. NOTE(review): presumably negative so they
    cannot collide with a real (non-negative) rank; confirm against callers.
    """

    KEEP_CURRENT_RANK = -1
    SHUTDOWN_CURRENT_RANK = -2