1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from collections.abc import Mapping
7
from typing import Any, Optional, Union
8
9

import msgspec
10
import torch
11

12
from vllm.lora.request import LoRARequest
13
from vllm.multimodal.inputs import MultiModalFeatureSpec
14
from vllm.pooling_params import PoolingParams
15
from vllm.sampling_params import SamplingParams
16
from vllm.v1.metrics.stats import SchedulerStats
17
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
18

19
20
21
# Public string forms of FinishReason. They surface as
# RequestOutput.finish_reason, so they are part of the external API.
# The position of each string must equal the matching enum value.
FINISH_REASON_STRINGS = ("stop", "length", "abort")


class FinishReason(enum.IntEnum):
    """Why a request stopped: stop, length, or abort.

    Stored as an int rather than a str for a more compact serialization.

    Members:
        STOP: a stop string was emitted.
        LENGTH: max_tokens was consumed, or max_model_len was reached.
        ABORT: the request was aborted for another reason.
    """

    STOP = 0
    LENGTH = 1
    ABORT = 2

    def __str__(self) -> str:
        # Translate the compact integer code back into its public string.
        code = int(self)
        return FINISH_REASON_STRINGS[code]


class EngineCoreRequest(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A request message for the engine core.

    Because the struct is ``array_like``, it is encoded as a positional
    array. The declaration order of the fields below therefore fixes where
    each value sits in the encoded form. With ``omit_defaults``, trailing
    fields still at their default values are left out of the array.
    """

    # ---- Required fields (no defaults) ----
    request_id: str
    prompt_token_ids: Optional[list[int]]
    mm_features: Optional[list[MultiModalFeatureSpec]]
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]
    eos_token_id: Optional[int]
    arrival_time: float
    lora_request: Optional[LoRARequest]
    cache_salt: Optional[str]
    data_parallel_rank: Optional[int]

    # ---- Optional fields ----
    prompt_embeds: Optional[torch.Tensor] = None

    # Identifies the frontend client that sent the request, so its outputs
    # can be routed back to that same client when the frontend is scaled out.
    client_index: int = 0

    # Data-parallel only: the wave of requests this one is expected to join.
    # Guards against the race where a request is sent before the
    # "wave finished" notification has been received.
    current_wave: int = 0
    priority: int = 0

    trace_headers: Optional[Mapping[str, str]] = None


class EngineCoreEventType(enum.IntEnum):
    """Kinds of per-request event recorded by the engine core."""

    # Explicit integer codes, starting at 1.
    QUEUED = 1
    SCHEDULED = 2
    PREEMPTED = 3


class EngineCoreEvent(msgspec.Struct):
    """A single timestamped engine-core event belonging to one request.

    ``timestamp`` is a monotonic clock reading (``new_event`` defaults it to
    ``time.monotonic()``). The engine frontend uses these readings to compute
    intervals between engine-core events. They must not be compared with
    timestamps taken in other processes.
    """

    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(
        cls, event_type: EngineCoreEventType, timestamp: Optional[float] = None
    ) -> "EngineCoreEvent":
        """Build an event of kind ``event_type``.

        Args:
            event_type: The kind of event being recorded.
            timestamp: Monotonic time of the event. When omitted, the
                current ``time.monotonic()`` reading is used.

        Returns:
            The new event.
        """
        if timestamp is None:
            timestamp = time.monotonic()
        return cls(event_type, timestamp)


102
class EngineCoreOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Output for a single request emitted by the engine core.

    Encoded as a positional array (``array_like``). The field order below
    defines the encoded layout, and trailing fields left at their defaults
    are omitted (``omit_defaults``).
    """

    request_id: str
    new_token_ids: list[int]

    new_logprobs: Optional[LogprobsLists] = None
    new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None

    pooling_output: Optional[torch.Tensor] = None

    # A non-None finish_reason marks the request as finished (see `finished`).
    finish_reason: Optional[FinishReason] = None
    stop_reason: Union[int, str, None] = None
    events: Optional[list[EngineCoreEvent]] = None
    kv_transfer_params: Optional[dict[str, Any]] = None

    trace_headers: Optional[Mapping[str, str]] = None

    # How many tokens were hits in the prefix cache.
    num_cached_tokens: int = 0

    @property
    def finished(self) -> bool:
        """True once a finish reason has been recorded."""
        # Must compare with None explicitly: FinishReason.STOP has value 0,
        # so a plain truthiness test would treat a stopped request as
        # unfinished.
        return self.finish_reason is not None


class UtilityResult:
    """Box around a utility call's return value.

    The wrapper type lets serialization and deserialization code single out
    the payload for special handling.
    """

    def __init__(self, r: Any = None):
        # The wrapped value is exposed as ``result``.
        self.result = r


class UtilityOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Outcome of a utility call.

    ``call_id`` presumably ties the outcome back to the utility request that
    produced it; confirm against the caller.
    """

    call_id: int

    # When failure_message is set, the call failed and result should be None.
    failure_message: Optional[str] = None
    result: Optional[UtilityResult] = None


class EngineCoreOutputs(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A batch of per-request outputs plus engine-level signals.

    Every field has a default. ``__post_init__`` replaces a ``timestamp`` left
    at 0.0 with the current ``time.monotonic()`` reading.
    """

    # NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    engine_index: int = 0

    # One entry per request: [num_reqs]. msgspec copies an empty-list default
    # for each instance, so this [] is not shared between structs.
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: Optional[SchedulerStats] = None
    timestamp: float = 0.0

    utility_output: Optional[UtilityOutput] = None
    finished_requests: Optional[set[str]] = None

    # Data-parallel only: signals that the current wave of requests has
    # finished and the engines are paused.
    wave_complete: Optional[int] = None
    # Data-parallel only: signals that a request arrived for an "old" wave,
    # so the other engines need to start the next wave.
    start_wave: Optional[int] = None

    def __post_init__(self):
        # 0.0 means "not set"; any other value is left untouched.
        if self.timestamp != 0.0:
            return
        self.timestamp = time.monotonic()


class EngineCoreRequestType(enum.Enum):
    """Kinds of request sent to the engine core.

    Each value is a single raw byte, so a request type can be written to a
    socket directly, with no separate encoding step.
    """

    ADD = b"\x00"
    ABORT = b"\x01"
    START_DP_WAVE = b"\x02"
    UTILITY = b"\x03"

    # Sentinel used within EngineCoreProc.
    EXECUTOR_FAILED = b"\x04"


class ReconfigureDistributedRequest(msgspec.Struct):
    """New data-parallel layout to apply when reconfiguring.

    NOTE(review): the rank fields may also carry one of the negative
    ReconfigureRankType sentinels; confirm against callers.
    """

    # Data-parallel topology.
    new_data_parallel_size: int
    new_data_parallel_rank: int
    new_data_parallel_rank_local: int
    # Address of the data-parallel master.
    new_data_parallel_master_ip: str
    new_data_parallel_master_port: int


class ReconfigureRankType(enum.IntEnum):
    """Special rank values used when reconfiguring a distributed request.

    Both values are negative, presumably so they cannot be mistaken for a
    real rank.
    """

    KEEP_CURRENT_RANK = -1
    SHUTDOWN_CURRENT_RANK = -2