__init__.py 6.63 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from collections.abc import Mapping
7
from typing import Any
8
9

import msgspec
10
import numpy as np
11
import torch
12

13
from vllm.lora.request import LoRARequest
14
from vllm.multimodal.inputs import MultiModalFeatureSpec
15
from vllm.pooling_params import PoolingParams
16
from vllm.sampling_params import SamplingParams
17
from vllm.v1.metrics.stats import SchedulerStats
18
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
19
from vllm.v1.serial_utils import UtilityResult
20

21
22
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
# NOTE: indexed by FinishReason value, so the order here must match the
# FinishReason member values (STOP=0, LENGTH=1, ABORT=2, ERROR=3).
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")


class FinishReason(enum.IntEnum):
    """
    Reason a request finished - stop, length, abort, or error.

    Int rather than Str for more compact serialization; ``str()`` maps a
    member back to its external string via ``FINISH_REASON_STRINGS``.

    stop - a stop string was emitted
    length - max_tokens was consumed, or max_model_len was reached
    abort - aborted by client
    error - retryable request-level internal error (e.g., KV load failure).
            Invariant: always converted to 500 Internal Server Error.
    """

    STOP = 0
    LENGTH = 1
    ABORT = 2
    ERROR = 3

    def __str__(self) -> str:
        # Member values double as indices into FINISH_REASON_STRINGS.
        return FINISH_REASON_STRINGS[self.value]
47
48


49
class EngineCoreRequest(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A request submitted to the engine core.

    Struct options (see the msgspec docs): ``array_like=True`` encodes the
    struct positionally, so field order is part of the serialized layout;
    ``omit_defaults=True`` lets msgspec omit fields still equal to their
    default (for array-like structs, only trailing ones); ``gc=False``
    disables GC tracking of instances, so they must not form reference
    cycles.
    """

    request_id: str
    # NOTE(review): may be None, presumably when the prompt is supplied via
    # prompt_embeds instead — confirm against InputProcessor.
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec] | None
    # `params` expects at least one of sampling_params / pooling_params to
    # be set, and prefers sampling_params when both are present.
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    eos_token_id: int | None
    arrival_time: float
    lora_request: LoRARequest | None
    cache_salt: str | None
    data_parallel_rank: int | None
    prompt_embeds: torch.Tensor | None = None

    # Index of the client, used to ensure outputs are sent back to the same
    # client for this request when scaling out the front-end.
    client_index: int = 0

    # Used in DP case to indicate which wave of requests this is expected to
    # belong to, to cover a race condition where the request is sent before
    # a wave finished notification is received.
    current_wave: int = 0
    priority: int = 0

    trace_headers: Mapping[str, str] | None = None

    # The user-provided request ID. This field is set internally,
    # copied from the provided request_id that's originally assigned
    # to the request_id field, see InputProcessor.assign_request_id().
    # Used in outputs and to support abort(req_id, internal=False).
    external_req_id: str | None = None

    @property
    def params(self) -> SamplingParams | PoolingParams:
        """Return the processed params (sampling or pooling).

        Returns ``sampling_params`` when it is set, otherwise
        ``pooling_params``; asserts that the latter is not None in that case.
        """
        if self.sampling_params is not None:
            return self.sampling_params
        assert self.pooling_params is not None
        return self.pooling_params

93

94
95
class EngineCoreEventType(enum.IntEnum):
    """The type of engine core request event."""

    QUEUED = 1
    SCHEDULED = 2
    PREEMPTED = 3
100
101
102
103
104
105
106
107
108


class EngineCoreEvent(msgspec.Struct):
    """A timestamped engine core event associated with a request.

    The timestamp is a monotonic timestamp used by the engine frontend to
    calculate intervals between engine core events. These timestamps
    should not be compared with timestamps from other processes.
    """

    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(
        cls, event_type: EngineCoreEventType, timestamp: float | None = None
    ) -> "EngineCoreEvent":
        """Create an event of ``event_type``.

        Args:
            event_type: The kind of event being recorded.
            timestamp: Explicit timestamp to use; when None, the current
                ``time.monotonic()`` reading is taken.

        Returns:
            A new EngineCoreEvent.
        """
        timestamp = time.monotonic() if timestamp is None else timestamp
        return cls(event_type, timestamp)


121
class EngineCoreOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Incremental output for a single request from the engine core.

    Encoded positionally (``array_like=True``), so field order is part of
    the serialized layout.
    """

    request_id: str
    new_token_ids: list[int]

    new_logprobs: LogprobsLists | None = None
    new_prompt_logprobs_tensors: LogprobsTensors | None = None

    pooling_output: torch.Tensor | None = None

    # None while the request has not finished; see `finished`.
    finish_reason: FinishReason | None = None
    # NOTE(review): presumably the matched stop token id (int) or stop
    # string (str) — confirm against the scheduler.
    stop_reason: int | str | None = None
    events: list[EngineCoreEvent] | None = None
    kv_transfer_params: dict[str, Any] | None = None

    trace_headers: Mapping[str, str] | None = None
    # The number of tokens with prefix cache hits.
    num_cached_tokens: int = 0
    routed_experts: np.ndarray | None = None

    # The number of NaNs in logits.
    # A value greater than 0 indicates that the output is corrupted.
    num_nans_in_logits: int = 0

    @property
    def finished(self) -> bool:
        """Whether the request has finished, i.e. ``finish_reason`` is set."""
        return self.finish_reason is not None

152

153
class UtilityOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Outcome of a utility call, identified by ``call_id``."""

    call_id: int

    # Non-None implies the call failed, result should be None.
    failure_message: str | None = None
    result: UtilityResult | None = None
163
164


165
class EngineCoreOutputs(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A batch of per-request outputs plus engine-level signals.

    ``timestamp`` is filled in with ``time.monotonic()`` by
    ``__post_init__`` when it is left at its 0.0 default.
    """

    # NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    engine_index: int = 0

    # [num_reqs]
    # An empty-list default is acceptable for a msgspec Struct: msgspec
    # creates a fresh list per instance for empty mutable defaults.
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: SchedulerStats | None = None
    timestamp: float = 0.0

    utility_output: UtilityOutput | None = None
    finished_requests: set[str] | None = None

    # In DP case, used to signal that the current wave of requests
    # has finished and the engines are paused.
    wave_complete: int | None = None
    # In DP case, used to signal that a request was received for an
    # "old" wave, so the next wave needs to be started in other engines.
    start_wave: int | None = None

    def __post_init__(self):
        # 0.0 acts as the "unset" default; stamp with this process's
        # monotonic clock.
        if self.timestamp == 0.0:
            self.timestamp = time.monotonic()
194
195
196
197
198
199
200


class EngineCoreRequestType(enum.Enum):
    """
    Request types encoded as single-byte ``bytes`` values, so the type can
    be sent over sockets without a separate encoding step.
    """

    ADD = b"\x00"
    ABORT = b"\x01"
    START_DP_WAVE = b"\x02"
    UTILITY = b"\x03"
    # Sentinel used within EngineCoreProc.
    EXECUTOR_FAILED = b"\x04"
208
209
210
211
212
213
214
215
216
217
218
219
220
221


class ReconfigureDistributedRequest(msgspec.Struct):
    """New data-parallel settings requested for a distributed reconfiguration.

    Field order defines the positional constructor/serialization order of
    this msgspec Struct, so keep it stable.

    NOTE(review): the rank fields presumably may also carry a
    ReconfigureRankType sentinel (a negative value) — confirm against callers.
    """

    new_data_parallel_size: int
    new_data_parallel_rank: int
    new_data_parallel_rank_local: int
    new_data_parallel_master_ip: str
    new_data_parallel_master_port: int


class ReconfigureRankType(enum.IntEnum):
    """
    Rank type for reconfiguring distributed request.

    The sentinel values are negative, presumably so they cannot be mistaken
    for a real (non-negative) rank — confirm against callers.
    """

    KEEP_CURRENT_RANK = -1
    SHUTDOWN_CURRENT_RANK = -2