# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
5
6
from collections.abc import Callable
from dataclasses import InitVar
7
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
8

9
from pydantic import Field, field_validator, model_validator
10
from pydantic.dataclasses import dataclass
11
from typing_extensions import Self, deprecated
12
13
14

from vllm.config.utils import config
from vllm.logger import init_logger
15
16
17
18
from vllm.utils.import_utils import resolve_obj_by_qualname

if TYPE_CHECKING:
    from vllm.v1.core.sched.interface import SchedulerInterface
19
20
21

logger = init_logger(__name__)

# Kind of model runner the engine launches (consumed via
# `SchedulerConfig.runner_type`).
RunnerType = Literal["generate", "pooling", "draft"]

# Request ordering policy (consumed via `SchedulerConfig.policy`):
# "fcfs" = arrival order, "priority" = user-supplied priority then arrival.
SchedulerPolicy = Literal["fcfs", "priority"]


@config
@dataclass
class SchedulerConfig:
    """Scheduler configuration.

    Most fields are plain knobs read by the scheduler. Some values are derived
    in `__post_init__` (`max_num_encoder_input_tokens`, `encoder_cache_size`,
    and a default `long_prefill_token_threshold` when concurrent partial
    prefills are enabled). Cross-field constraints are checked afterwards in
    `_verify_args`.
    """

    DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
    DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128

    runner_type: RunnerType = "generate"
    """The runner type to launch for the model."""

    max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
    """Maximum number of tokens to be processed in a single iteration.

    The default value here is mainly for convenience when testing.
    In real usage, this should be set in `EngineArgs.create_engine_config`.
    """

    max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
    """Maximum number of sequences to be processed in a single iteration.

    The default value here is mainly for convenience when testing.
    In real usage, this should be set in `EngineArgs.create_engine_config`.
    """

    max_model_len: int = Field(default=8192, ge=1)
    """Maximum length of a sequence (including prompt and generated text).

    The default value here is mainly for convenience when testing.
    In real usage, this should duplicate `ModelConfig.max_model_len` via
    `EngineArgs`."""

    max_num_partial_prefills: int = Field(default=1, ge=1)
    """For chunked prefill, the maximum number of sequences that can be
    partially prefilled concurrently."""

    max_long_partial_prefills: int = Field(default=1, ge=1)
    """For chunked prefill, the maximum number of prompts longer than
    long_prefill_token_threshold that will be prefilled concurrently. Setting
    this less than max_num_partial_prefills will allow shorter prompts to jump
    the queue in front of longer prompts in some cases, improving latency."""

    long_prefill_token_threshold: int = Field(default=0, ge=0)
    """For chunked prefill, a request is considered long if the prompt is
    longer than this number of tokens. A value of 0 means "unset": when
    max_num_partial_prefills > 1 it defaults to 4% of max_model_len."""

    num_lookahead_slots: int = Field(default=0, ge=0)
    """The number of slots to allocate per sequence per
    step, beyond the known token ids. This is used in speculative
    decoding to store KV activations of tokens which may or may not be
    accepted.

    NOTE: This will be replaced by speculative config in the future; it is
    present to enable correctness tests until then."""

    enable_chunked_prefill: bool = True
    """If True, prefill requests can be chunked based
    on the remaining `max_num_batched_tokens`.

    The default value here is mainly for convenience when testing.
    In real usage, this should be set in `EngineArgs.create_engine_config`.
    """

    is_multimodal_model: bool = False
    """True if the model is multimodal."""

    is_encoder_decoder: InitVar[bool] = False
    """True if the model is an encoder-decoder model.

    Note: This is stored in the ModelConfig, and is used only here to
    disable chunked prefill and prefix caching for encoder-decoder models.
    """

    # TODO (ywang96): Make this configurable.
    max_num_encoder_input_tokens: int = Field(init=False)
    """Multimodal encoder compute budget, only used in V1.

    NOTE: This is not currently configurable. It will be overridden by
    max_num_batched_tokens in case max multimodal embedding size is larger."""

    # TODO (ywang96): Make this configurable.
    encoder_cache_size: int = Field(init=False)
    """Multimodal encoder cache size, only used in V1.

    NOTE: This is not currently configurable. It will be overridden by
    max_num_batched_tokens in case max multimodal embedding size is larger."""

    policy: SchedulerPolicy = "fcfs"
    """The scheduling policy to use:\n
    - "fcfs" means first come first served, i.e. requests are handled in order
    of arrival.\n
    - "priority" means requests are handled based on given priority (lower
    value means earlier handling), with time of arrival deciding any ties."""

    disable_chunked_mm_input: bool = False
    """If set to true and chunked prefill is enabled, we do not want to
    partially schedule a multimodal item. Only used in V1.
    This ensures that if a request has a mixed prompt
    (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
    some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
    it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""

    # scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler"
    # (default) or "mod.custom_class".
    scheduler_cls: str | type[object] = Field(default=None)
    """The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
    the default scheduler. Can be a class directly or the path to a class of
    form "mod.custom_class"."""

    disable_hybrid_kv_cache_manager: bool = False
    """If set to True, KV cache manager will allocate the same size of KV cache
    for all attention layers even if there are multiple types of attention
    layers like full attention and sliding window attention.
    """

    async_scheduling: bool = False
    """If set to True, perform async scheduling. This helps to avoid gaps in
    GPU utilization, leading to better latency and throughput.
    Async scheduling is currently not supported with some features such as
    speculative decoding and pipeline parallelism.
    """

    stream_interval: int = Field(default=1, ge=1)
    """The interval (or buffer size) for streaming in terms of token length.
    A smaller value (1) makes streaming smoother by sending each token immediately,
    while a larger value (e.g., 10) reduces host overhead and may increase throughput
    by batching multiple tokens before sending."""

    def get_scheduler_cls(self) -> type["SchedulerInterface"]:
        """Return the scheduler class to instantiate.

        With no `scheduler_cls` set, this picks `AsyncScheduler` when
        `async_scheduling` is enabled and `Scheduler` otherwise. A custom
        `scheduler_cls` is returned as-is when it is already a class, or
        resolved by qualified name when it is a string.
        """
        if self.scheduler_cls is None:
            # Imported lazily to keep this config module free of heavy
            # scheduler imports until a scheduler is actually needed.
            if self.async_scheduling:
                from vllm.v1.core.sched.async_scheduler import AsyncScheduler

                return AsyncScheduler
            from vllm.v1.core.sched.scheduler import Scheduler

            return Scheduler

        # This warning can be removed once the Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        logger.warning_once(
            "Using custom scheduler class %s. This scheduler interface is "
            "not public and compatibility may not be maintained.",
            self.scheduler_cls,
        )
        if not isinstance(self.scheduler_cls, str):
            return cast(type["SchedulerInterface"], self.scheduler_cls)
        return resolve_obj_by_qualname(self.scheduler_cls)

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @field_validator("scheduler_cls", "async_scheduling", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Skip validation if the value is `None` when initialisation is delayed."""
        if value is None:
            return value
        return handler(value)

    def __post_init__(self, is_encoder_decoder: bool) -> None:
        """Apply encoder-decoder overrides and derive dependent fields.

        Args:
            is_encoder_decoder: The `InitVar` of the same name. When True,
                chunked prefill (including chunked multimodal input) is
                turned off and `long_prefill_token_threshold` is reset to 0.
        """
        if is_encoder_decoder:
            # Chunked prefill should be disabled for encoder-decoder models.
            # NOTE(review): only chunked-prefill settings change here; the log
            # message below also mentions prefix caching, which is presumably
            # disabled by another config — confirm.
            self.disable_chunked_mm_input = True
            self.enable_chunked_prefill = False
            self.long_prefill_token_threshold = 0
            logger.info(
                "Encoder-decoder models do not support chunked prefill nor"
                " prefix caching; disabling both."
            )

        # Both encoder budgets are init=False fields that mirror the token
        # budget; they are not independently configurable yet.
        self.max_num_encoder_input_tokens = self.max_num_batched_tokens
        self.encoder_cache_size = self.max_num_batched_tokens

        if self.enable_chunked_prefill:
            logger.info(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                self.max_num_batched_tokens,
            )

        if self.max_num_partial_prefills > 1:
            if self.long_prefill_token_threshold == 0:
                # 0 means "unset": default to 4% of the context length.
                self.long_prefill_token_threshold = int(self.max_model_len * 0.04)

            logger.info(
                "Concurrent partial prefills enabled with "
                "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
                "long_prefill_token_threshold=%d",
                self.max_num_partial_prefills,
                self.max_long_partial_prefills,
                self.long_prefill_token_threshold,
            )

    @property
    @deprecated(
        "`SchedulerConfig.chunked_prefill_enabled` has been renamed to "
        "`SchedulerConfig.enable_chunked_prefill`. "
        "The old name will be removed in v0.12."
    )
    def chunked_prefill_enabled(self) -> bool:
        """Deprecated alias for `enable_chunked_prefill`."""
        return self.enable_chunked_prefill

    @chunked_prefill_enabled.setter
    def chunked_prefill_enabled(self, value: bool) -> None:
        self.enable_chunked_prefill = value

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Check cross-field constraints after all fields are populated.

        Returns:
            `self`, as required of a pydantic "after" model validator.

        Raises:
            ValueError: if, without chunked prefill, `max_num_batched_tokens`
                is smaller than `max_model_len`; if `max_num_batched_tokens`
                is smaller than `max_num_seqs`; if `max_num_partial_prefills`
                > 1 is set without chunked prefill or with
                `long_prefill_token_threshold` above `max_model_len`; or if
                `max_long_partial_prefills` exceeds `max_num_partial_prefills`.
        """
        # Without chunking, a whole prompt must fit in one step's token budget.
        if (
            self.max_num_batched_tokens < self.max_model_len
            and not self.enable_chunked_prefill
        ):
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len."
            )

        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs})."
            )

        # A budget beyond what max_num_seqs full-length sequences could use
        # is suspicious, but only worth a warning.
        if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
            logger.warning(
                "max_num_batched_tokens (%d) exceeds max_num_seqs "
                "* max_model_len (%d). This may lead to unexpected behavior.",
                self.max_num_batched_tokens,
                self.max_num_seqs * self.max_model_len,
            )

        if self.max_num_partial_prefills > 1:
            if not self.enable_chunked_prefill:
                raise ValueError(
                    "Chunked prefill must be enabled to set "
                    "max_num_partial_prefills > 1."
                )

            if self.long_prefill_token_threshold > self.max_model_len:
                raise ValueError(
                    "long_prefill_token_threshold "
                    f"({self.long_prefill_token_threshold}) cannot be greater "
                    f"than the max_model_len ({self.max_model_len})."
                )

        if self.max_long_partial_prefills > self.max_num_partial_prefills:
            raise ValueError(
                f"{self.max_long_partial_prefills=} must be less than or equal to "
                f"{self.max_num_partial_prefills=}."
            )

        return self