__init__.py 6.09 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from collections.abc import Mapping
7
from typing import Any
8
9

import msgspec
10
import torch
11

12
from vllm.lora.request import LoRARequest
13
from vllm.multimodal.inputs import MultiModalFeatureSpec
14
from vllm.pooling_params import PoolingParams
15
from vllm.sampling_params import SamplingParams
16
from vllm.v1.metrics.stats import SchedulerStats
17
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
18
from vllm.v1.serial_utils import UtilityResult
19

20
21
22
# Every value that RequestOutput.finish_reason can take. These strings are
# visible to callers, which makes them part of the external API;
# FinishReason.__str__ indexes into this tuple with the member's int value.
FINISH_REASON_STRINGS: tuple[str, ...] = ("stop", "length", "abort")
23

24
25

class FinishReason(enum.IntEnum):
    """
    Reason a request finished - stop, length, or abort.

    Int rather than Str for more compact serialization.

    stop - a stop string was emitted
    length - max_tokens was consumed, or max_model_len was reached
    abort - aborted for another reason

    ``str(member)`` returns the matching entry of FINISH_REASON_STRINGS,
    which is the externally visible finish_reason value.
    """

    STOP = 0
    LENGTH = 1
    ABORT = 2

    def __str__(self):
        # Member values double as indices into FINISH_REASON_STRINGS, so the
        # enum values and that tuple must stay in the same order.
        return FINISH_REASON_STRINGS[self.value]
43
44


45
class EngineCoreRequest(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A request submitted to the engine core.

    The struct is encoded by msgspec as an array (``array_like=True``), so
    fields are serialized positionally and their declaration order is part
    of the wire format. ``omit_defaults=True`` lets trailing fields that are
    still at their default values be dropped from the encoding. With
    ``gc=False`` instances are not tracked by the cyclic garbage collector,
    so they must not take part in reference cycles.
    """

    request_id: str

    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec] | None
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    eos_token_id: int | None
    arrival_time: float
    lora_request: LoRARequest | None
    cache_salt: str | None
    data_parallel_rank: int | None
    prompt_embeds: torch.Tensor | None = None

    # Index of the client, used to ensure outputs are sent back to the same
    # client for this request when scaling out the front-end.
    client_index: int = 0

    # Used in DP case to indicate which wave of requests this is expected to
    # belong to, to cover a race condition where the request is sent before
    # a wave finished notification is received.
    current_wave: int = 0
    priority: int = 0

    trace_headers: Mapping[str, str] | None = None

    @property
    def params(self) -> SamplingParams | PoolingParams:
        """Return the processed params (sampling or pooling).

        ``sampling_params`` wins when it is set. Otherwise
        ``pooling_params`` is expected to be present (asserted).
        """
        if self.sampling_params is not None:
            return self.sampling_params
        assert self.pooling_params is not None
        return self.pooling_params

83

84
85
class EngineCoreEventType(enum.IntEnum):
    """The type of engine core request event.

    Stored in EngineCoreEvent.type; the int values are what get serialized.
    """

    QUEUED = 1
    SCHEDULED = 2
    PREEMPTED = 3
90
91
92
93
94
95
96
97
98


class EngineCoreEvent(msgspec.Struct):
    """A timestamped engine core event associated with a request.

    The timestamp is a monotonic timestamp and is used by the engine
    frontend to calculate intervals between engine core events. These
    timestamps should not be compared with timestamps from other processes.
    """

    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(
        cls, event_type: EngineCoreEventType, timestamp: float | None = None
    ) -> "EngineCoreEvent":
        """Create an event of ``event_type``.

        Args:
            event_type: The kind of event being recorded.
            timestamp: Monotonic time of the event. When omitted (``None``),
                ``time.monotonic()`` is sampled at call time.

        Returns:
            A new EngineCoreEvent (or subclass instance when called on one).
        """
        timestamp = time.monotonic() if timestamp is None else timestamp
        return cls(event_type, timestamp)


111
class EngineCoreOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Output for a single request, sent from the engine core.

    Encoded positionally by msgspec (``array_like=True``), so the field
    declaration order is part of the wire format.
    """

    request_id: str
    new_token_ids: list[int]

    new_logprobs: LogprobsLists | None = None
    new_prompt_logprobs_tensors: LogprobsTensors | None = None

    pooling_output: torch.Tensor | None = None

    # Non-None once the request has finished (see ``finished``).
    finish_reason: FinishReason | None = None
    stop_reason: int | str | None = None
    events: list[EngineCoreEvent] | None = None
    kv_transfer_params: dict[str, Any] | None = None

    trace_headers: Mapping[str, str] | None = None

    # The number of tokens with prefix cache hits.
    num_cached_tokens: int = 0

    # The number of NaNs in logits.
    # A value greater than 0 indicates that the output is corrupted.
    num_nans_in_logits: int = 0

    @property
    def finished(self) -> bool:
        """Whether the request has finished, i.e. a finish reason is set."""
        return self.finish_reason is not None

142

143
class UtilityOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Outcome of a utility call.

    ``call_id`` presumably identifies the originating utility request —
    verify against the sender. Unlike the other output structs here, this one
    does not set ``omit_defaults``.
    """

    call_id: int

    # Non-None implies the call failed, result should be None.
    failure_message: str | None = None
    result: UtilityResult | None = None
153
154


155
class EngineCoreOutputs(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A batch of per-request outputs plus engine-level metadata.

    When ``timestamp`` is left at its default of 0.0, ``__post_init__``
    replaces it with ``time.monotonic()``.
    """

    # NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    engine_index: int = 0

    # [num_reqs]
    # msgspec treats an empty-list default as a per-instance factory, so
    # instances do not share this list.
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: SchedulerStats | None = None
    timestamp: float = 0.0

    utility_output: UtilityOutput | None = None
    finished_requests: set[str] | None = None

    # In DP case, used to signal that the current wave of requests
    # has finished and the engines are paused.
    wave_complete: int | None = None
    # In DP case, used to signal that a request was received for an
    # "old" wave, so the next wave needs to be started in other engines.
    start_wave: int | None = None

    def __post_init__(self):
        # 0.0 is the "unset" sentinel: stamp with the local monotonic clock.
        if self.timestamp == 0.0:
            self.timestamp = time.monotonic()
184
185
186
187
188
189
190


class EngineCoreRequestType(enum.Enum):
    """
    Request types defined as single-byte ``bytes`` values, so they can be
    sent over sockets without a separate encoding step.
    """

    ADD = b"\x00"
    ABORT = b"\x01"
    START_DP_WAVE = b"\x02"
    UTILITY = b"\x03"
    # Sentinel used within EngineCoreProc.
    EXECUTOR_FAILED = b"\x04"
198
199
200
201
202
203
204
205
206
207
208
209
210
211


class ReconfigureDistributedRequest(msgspec.Struct):
    """New data-parallel settings used to reconfigure distributed execution.

    NOTE(review): ``new_data_parallel_rank`` may carry a ReconfigureRankType
    sentinel (a negative value) instead of a real rank — confirm against
    callers.
    """

    # New data-parallel size.
    new_data_parallel_size: int
    # New data-parallel rank for the receiving engine.
    new_data_parallel_rank: int
    # New node-local data-parallel rank.
    new_data_parallel_rank_local: int
    # Address and port of the data-parallel master.
    new_data_parallel_master_ip: str
    new_data_parallel_master_port: int


class ReconfigureRankType(enum.IntEnum):
    """
    Rank type for reconfiguring distributed request.

    The values are negative, presumably so they can be told apart from real
    (non-negative) ranks — verify against callers.
    """

    KEEP_CURRENT_RANK = -1
    SHUTDOWN_CURRENT_RANK = -2