# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
from collections.abc import Callable
from dataclasses import InitVar
6
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
7

8
from pydantic import Field, field_validator
9
from typing_extensions import Self
10
11
12

from vllm.config.utils import config
from vllm.logger import init_logger
13
from vllm.utils.hashing import safe_hash
14
15
16
17
from vllm.utils.import_utils import resolve_obj_by_qualname

if TYPE_CHECKING:
    from vllm.v1.core.sched.interface import SchedulerInterface
18
19
20

# Module-level logger used for scheduler-configuration diagnostics.
logger = init_logger(__name__)

# Which kind of model runner to launch; see `SchedulerConfig.runner_type`.
RunnerType = Literal["generate", "pooling", "draft"]

# Request-ordering policy; see `SchedulerConfig.policy` for the semantics of
# each value.
SchedulerPolicy = Literal["fcfs", "priority"]


@config
class SchedulerConfig:
    """Scheduler configuration.

    Groups the settings that control how requests are batched and scheduled
    each iteration: token/sequence budgets, chunked prefill, concurrent
    partial prefills, the scheduling policy, and which scheduler class to use.
    Cross-field validation against the model's maximum sequence length runs in
    `__post_init__` (via `verify_max_model_len`).
    """

    max_model_len: InitVar[int]
    """Maximum length of a sequence (including prompt and generated text).

    Note: This is stored in the ModelConfig, and is used only here to
    provide fallbacks and validate other attributes."""

    is_encoder_decoder: InitVar[bool]
    """True if the model is an encoder-decoder model.

    Note: This is stored in the ModelConfig, and is used only here to
    disable chunked prefill and prefix caching for encoder-decoder models.
    """

    DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
    DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP: ClassVar[int] = 256
    DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128

    runner_type: RunnerType = "generate"
    """The runner type to launch for the model."""

    max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
    """Maximum number of tokens that can be processed in a single iteration.

    The default value here is mainly for convenience when testing.
    In real usage, this should be set in `EngineArgs.create_engine_config`.
    """

    max_num_scheduled_tokens: int | None = None
    """Maximum number of tokens that the scheduler may issue in a single iteration.

    This is usually equal to max_num_batched_tokens, but can be smaller in cases
    when the model might append tokens into the batch (such as speculative decoding).
    Defaults to max_num_batched_tokens."""

    max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
    """Maximum number of sequences to be processed in a single iteration.

    The default value here is mainly for convenience when testing.
    In real usage, this should be set in `EngineArgs.create_engine_config`.
    """

    max_num_partial_prefills: int = Field(default=1, ge=1)
    """For chunked prefill, the maximum number of sequences that can be
    partially prefilled concurrently."""

    max_long_partial_prefills: int = Field(default=1, ge=1)
    """For chunked prefill, the maximum number of prompts longer than
    long_prefill_token_threshold that will be prefilled concurrently. Setting
    this less than max_num_partial_prefills will allow shorter prompts to jump
    the queue in front of longer prompts in some cases, improving latency."""

    long_prefill_token_threshold: int = 0
    """For chunked prefill, a request is considered long if the prompt is
    longer than this number of tokens. If left at 0 while
    max_num_partial_prefills > 1, it defaults to 4% of max_model_len."""

    enable_chunked_prefill: bool = True
    """If True, prefill requests can be chunked based
    on the remaining `max_num_batched_tokens`.

    The default value here is mainly for convenience when testing.
    In real usage, this should be set in `EngineArgs.create_engine_config`.
    """

    is_multimodal_model: bool = False
    """True if the model is multimodal."""

    # TODO (ywang96): Make this configurable.
    max_num_encoder_input_tokens: int = Field(init=False)
    """Multimodal encoder compute budget, only used in V1.

    NOTE: This is not currently configurable. It will be overridden by
    max_num_batched_tokens in case max multimodal embedding size is larger."""

    # TODO (ywang96): Make this configurable.
    encoder_cache_size: int = Field(init=False)
    """Multimodal encoder cache size, only used in V1.

    NOTE: This is not currently configurable. It will be overridden by
    max_num_batched_tokens in case max multimodal embedding size is larger."""

    policy: SchedulerPolicy = "fcfs"
    """The scheduling policy to use:

    - "fcfs" means first come first served, i.e. requests are handled in order
      of arrival.
    - "priority" means requests are handled based on given priority (lower
      value means earlier handling), with time of arrival deciding any ties."""

    disable_chunked_mm_input: bool = False
    """If set to true and chunked prefill is enabled, we do not want to
    partially schedule a multimodal item. Only used in V1.
    This ensures that if a request has a mixed prompt
    (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
    some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
    it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""

    scheduler_cls: str | type[object] | None = None
    """The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
    the default scheduler. Can be a class directly or the path to a class of
    form "mod.custom_class"."""

    disable_hybrid_kv_cache_manager: bool | None = None
    """If set to True, KV cache manager will allocate the same size of KV cache
    for all attention layers even if there are multiple types of attention layers
    like full attention and sliding window attention.
    If set to None, the default value will be determined based on the environment
    and starting configuration.
    """

    scheduler_reserve_full_isl: bool = True
    """If True, the scheduler checks whether the full input sequence length
    fits in the KV cache before admitting a new request, rather than only
    checking the first chunk. Prevents over-admission and KV cache thrashing
    with chunked prefill."""

    async_scheduling: bool | None = None
    """If set to False, disable async scheduling. Async scheduling helps to
    avoid gaps in GPU utilization, leading to better latency and throughput.
    """

    stream_interval: int = Field(default=1, ge=1)
    """The interval (or buffer size) for streaming in terms of token length.
    A smaller value (1) makes streaming smoother by sending each token immediately,
    while a larger value (e.g., 10) reduces host overhead and may increase throughput
    by batching multiple tokens before sending."""

    @staticmethod
    def default_factory(**kwargs):
        """
        Factory method to create `SchedulerConfig` with default values for `InitVar`s.

        Fills in `max_model_len=8192` and `is_encoder_decoder=False` when the
        caller does not supply them, then forwards every keyword argument to
        the `SchedulerConfig` constructor.
        """
        if "max_model_len" not in kwargs:
            kwargs["max_model_len"] = 8192
        if "is_encoder_decoder" not in kwargs:
            kwargs["is_encoder_decoder"] = False
        return SchedulerConfig(**kwargs)

    def get_scheduler_cls(self) -> type["SchedulerInterface"]:
        """Return the scheduler class to instantiate.

        When `scheduler_cls` is unset, returns `AsyncScheduler` if
        `async_scheduling` is truthy and the default `Scheduler` otherwise
        (a `None` value for `async_scheduling` counts as False here).
        A custom `scheduler_cls` may be a class object, returned as-is, or a
        fully qualified name string, resolved via `resolve_obj_by_qualname`.
        """
        if self.scheduler_cls is None:
            # Imported lazily, presumably to keep config import lightweight and
            # avoid import cycles with the v1 engine.
            if self.async_scheduling:
                from vllm.v1.core.sched.async_scheduler import AsyncScheduler

                return AsyncScheduler
            from vllm.v1.core.sched.scheduler import Scheduler

            return Scheduler

        # This warning can be removed once the Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        logger.warning_once(
            "Using custom scheduler class %s. This scheduler interface is "
            "not public and compatibility may not be maintained.",
            self.scheduler_cls,  # type: ignore[arg-type]
        )
        if not isinstance(self.scheduler_cls, str):
            return cast(type["SchedulerInterface"], self.scheduler_cls)
        return resolve_obj_by_qualname(self.scheduler_cls)

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = []

        # max_num_batched_tokens need to be included in the hash due
        # to two reasons:
        # 1. LoRA creates static buffers based on max_num_batched_tokens.
        #   The tensor sizes and strides get captured in the torch.compile
        #   graph explicitly.
        # 2. Inductor decides whether using 32-bit or 64-bit indexing integer
        #   based on the data sizes. `max_num_batched_tokens` has an
        #   impact on that. For more details, please check
        #   https://github.com/vllm-project/vllm/issues/29585
        factors.append(self.max_num_batched_tokens)

        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @field_validator("scheduler_cls", "async_scheduling", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Skip validation if the value is `None` when initialisation is delayed."""
        return None if value is None else handler(value)

    def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None:
        """Derive dependent fields and validate the configuration.

        Args:
            max_model_len: Maximum sequence length from the model config. Used
                only for fallbacks and validation; it is an `InitVar` and is
                not stored on the instance.
            is_encoder_decoder: Whether the model is encoder-decoder. If so,
                chunked prefill and chunked multimodal input are forced off.

        Raises:
            ValueError: Propagated from `verify_max_model_len` when the
                resulting settings are inconsistent.
        """
        if is_encoder_decoder:
            # Chunked prefill should be disabled for encoder-decoder models.
            self.disable_chunked_mm_input = True
            self.enable_chunked_prefill = False
            self.long_prefill_token_threshold = 0
            logger.info(
                "Encoder-decoder models do not support chunked prefill nor"
                " prefix caching; disabling both."
            )

        # The encoder budgets are not user-configurable yet (init=False), so
        # they mirror the per-iteration token budget.
        self.max_num_encoder_input_tokens = self.max_num_batched_tokens
        self.encoder_cache_size = self.max_num_batched_tokens

        if self.enable_chunked_prefill:
            logger.info_once(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                self.max_num_batched_tokens,
            )

        if self.max_num_partial_prefills > 1:
            if self.long_prefill_token_threshold == 0:
                # Unset threshold: treat prompts longer than 4% of the
                # maximum model length as "long".
                self.long_prefill_token_threshold = int(max_model_len * 0.04)

            logger.info(
                "Concurrent partial prefills enabled with "
                "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
                "long_prefill_token_threshold=%d",
                self.max_num_partial_prefills,
                self.max_long_partial_prefills,
                self.long_prefill_token_threshold,
            )

        self.verify_max_model_len(max_model_len)

    def verify_max_model_len(self, max_model_len: int) -> Self:
        """Validate the scheduling budgets against `max_model_len`.

        Args:
            max_model_len: Maximum sequence length of the model.

        Returns:
            This config, unchanged, so calls can be chained.

        Raises:
            ValueError: If, without chunked prefill, the token budget is below
                `max_model_len`; if the token budget is below `max_num_seqs`;
                if `max_num_partial_prefills > 1` without chunked prefill or
                with a long-prefill threshold above `max_model_len`; or if
                `max_long_partial_prefills > max_num_partial_prefills`.
        """
        # Without chunking, a whole prompt must fit in one iteration's budget.
        if (
            self.max_num_batched_tokens < max_model_len
            and not self.enable_chunked_prefill
        ):
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len."
            )

        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs})."
            )

        # A budget larger than every sequence slot could ever consume is only
        # suspicious, not invalid, so warn instead of raising.
        if self.max_num_batched_tokens > self.max_num_seqs * max_model_len:
            logger.warning(
                "max_num_batched_tokens (%d) exceeds max_num_seqs "
                "* max_model_len (%d). This may lead to unexpected behavior.",
                self.max_num_batched_tokens,
                self.max_num_seqs * max_model_len,
            )

        if self.max_num_partial_prefills > 1:
            if not self.enable_chunked_prefill:
                raise ValueError(
                    "Chunked prefill must be enabled to set "
                    "max_num_partial_prefills > 1."
                )

            if self.long_prefill_token_threshold > max_model_len:
                raise ValueError(
                    "long_prefill_token_threshold "
                    f"({self.long_prefill_token_threshold}) cannot be greater "
                    f"than the max_model_len ({max_model_len})."
                )

        if self.max_long_partial_prefills > self.max_num_partial_prefills:
            raise ValueError(
                f"{self.max_long_partial_prefills=} must be less than or equal to "
                f"{self.max_num_partial_prefills=}."
            )

        return self