# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast

from pydantic import Field, field_validator
from typing_extensions import Self

from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import resolve_obj_by_qualname

if TYPE_CHECKING:
    from vllm.v1.core.sched.interface import SchedulerInterface

logger = init_logger(__name__)

# Kind of model runner to launch; used as the type of `SchedulerConfig.runner_type`.
RunnerType = Literal["generate", "pooling", "draft"]

# Request-ordering policy; used as the type of `SchedulerConfig.policy`.
SchedulerPolicy = Literal["fcfs", "priority"]


@config
class SchedulerConfig:
    """Scheduler configuration."""

    max_model_len: InitVar[int]
    """Maximum length of a sequence (including prompt and generated text).

    Note: This is stored in the ModelConfig, and is used only here to
    provide fallbacks and validate other attributes."""

    is_encoder_decoder: InitVar[bool]
    """True if the model is an encoder-decoder model.

    Note: This is stored in the ModelConfig, and is used only here to
    disable chunked prefill and prefix caching for encoder-decoder models.
    """

    DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
    DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128

    runner_type: RunnerType = "generate"
    """The runner type to launch for the model."""

    max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
    """Maximum number of tokens that can be processed in a single iteration.

    The default value here is mainly for convenience when testing.
    In real usage, this should be set in `EngineArgs.create_engine_config`.
    """

    # NOTE(review): nothing in this class replaces the `None` default with
    # max_num_batched_tokens; presumably the consumer does that — confirm.
    max_num_scheduled_tokens: int | None = Field(default=None)
    """Maximum number of tokens that the scheduler may issue in a single iteration.

    This is usually equal to max_num_batched_tokens, but can be smaller in cases
    when the model might append tokens into the batch (such as speculative decoding).
    Defaults to max_num_batched_tokens."""

    max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
    """Maximum number of sequences to be processed in a single iteration.

    The default value here is mainly for convenience when testing.
    In real usage, this should be set in `EngineArgs.create_engine_config`.
    """

    max_num_partial_prefills: int = Field(default=1, ge=1)
    """For chunked prefill, the maximum number of sequences that can be
    partially prefilled concurrently."""

    max_long_partial_prefills: int = Field(default=1, ge=1)
    """For chunked prefill, the maximum number of prompts longer than
    long_prefill_token_threshold that will be prefilled concurrently. Setting
    this less than max_num_partial_prefills will allow shorter prompts to jump
    the queue in front of longer prompts in some cases, improving latency."""

    long_prefill_token_threshold: int = 0
    """For chunked prefill, a request is considered long if the prompt is
    longer than this number of tokens."""

    enable_chunked_prefill: bool = True
    """If True, prefill requests can be chunked based
    on the remaining `max_num_batched_tokens`.

    The default value here is mainly for convenience when testing.
    In real usage, this should be set in `EngineArgs.create_engine_config`.
    """

    is_multimodal_model: bool = False
    """True if the model is multimodal."""

    # TODO (ywang96): Make this configurable.
    max_num_encoder_input_tokens: int = Field(init=False)
    """Multimodal encoder compute budget, only used in V1.

    NOTE: This is not currently configurable. It will be overridden by
    max_num_batched_tokens in case max multimodal embedding size is larger."""

    # TODO (ywang96): Make this configurable.
    encoder_cache_size: int = Field(init=False)
    """Multimodal encoder cache size, only used in V1.

    NOTE: This is not currently configurable. It will be overridden by
    max_num_batched_tokens in case max multimodal embedding size is larger."""

    policy: SchedulerPolicy = "fcfs"
    """The scheduling policy to use:\n
    - "fcfs" means first come first served, i.e. requests are handled in order
    of arrival.\n
    - "priority" means requests are handled based on given priority (lower
    value means earlier handling) and time of arrival deciding any ties)."""

    disable_chunked_mm_input: bool = False
    """If set to true and chunked prefill is enabled, we do not want to
    partially schedule a multimodal item. Only used in V1
    This ensures that if a request has a mixed prompt
    (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
    some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
    it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""

    # scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler"
    # (default) or "mod.custom_class".
    scheduler_cls: str | type[object] | None = Field(default=None)
    """The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
    the default scheduler. Can be a class directly or the path to a class of
    form "mod.custom_class"."""

    disable_hybrid_kv_cache_manager: bool | None = None
    """If set to True, KV cache manager will allocate the same size of KV cache
    for all attention layers even if there are multiple type of attention layers
    like full attention and sliding window attention.
    If set to None, the default value will be determined based on the environment
    and starting configuration.
    """

    async_scheduling: bool | None = Field(default=None)
    """If set to False, disable async scheduling. Async scheduling helps to
    avoid gaps in GPU utilization, leading to better latency and throughput.
    """

    stream_interval: int = Field(default=1, ge=1)
    """The interval (or buffer size) for streaming in terms of token length.
    A smaller value (1) makes streaming smoother by sending each token immediately,
    while a larger value (e.g., 10) reduces host overhead and may increase throughput
    by batching multiple tokens before sending."""

    @staticmethod
    def default_factory(**kwargs):
        """
        Factory method to create `SchedulerConfig` with default values for `InitVar`s.
        """
        if "max_model_len" not in kwargs:
            kwargs["max_model_len"] = 8192
        if "is_encoder_decoder" not in kwargs:
            kwargs["is_encoder_decoder"] = False
        return SchedulerConfig(**kwargs)

    def get_scheduler_cls(self) -> type["SchedulerInterface"]:
        """Resolve the scheduler class to instantiate.

        If `scheduler_cls` is unset, the choice depends on `async_scheduling`:
        a truthy value gives `AsyncScheduler`, otherwise `Scheduler`.

        If `scheduler_cls` is set, a one-time warning is logged first. A class
        object is returned as-is. A string is resolved as a fully qualified
        name via `resolve_obj_by_qualname`.
        """
        if self.scheduler_cls is None:
            if self.async_scheduling:
                from vllm.v1.core.sched.async_scheduler import AsyncScheduler

                return AsyncScheduler
            from vllm.v1.core.sched.scheduler import Scheduler

            return Scheduler

        # This warning can be removed once the Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        logger.warning_once(
            "Using custom scheduler class %s. This scheduler interface is "
            "not public and compatibility may not be maintained.",
            self.scheduler_cls,
        )
        if not isinstance(self.scheduler_cls, str):
            return cast(type["SchedulerInterface"], self.scheduler_cls)
        return resolve_obj_by_qualname(self.scheduler_cls)

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = []

        # max_num_batched_tokens need to be included in the hash due
        # to two reasons:
        # 1. LoRA creates static buffers based on max_num_batched_tokens.
        #   The tensor sizes and strides get captured in the torch.compile
        #   graph explicitly.
        # 2. Inductor decides whether using 32-bit or 64-bit indexing integer
        #   based on the data sizes. `max_num_batched_tokens` has an
        #   impact on that. For more details, please check
        #   https://github.com/vllm-project/vllm/issues/29585
        factors.append(self.max_num_batched_tokens)

        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @field_validator("scheduler_cls", "async_scheduling", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Skip validation if the value is `None` when initialisation is delayed."""
        return None if value is None else handler(value)

    def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None:
        """Apply model-dependent adjustments, then validate the config.

        Args:
            max_model_len: Maximum sequence length of the model. It is used to
                derive `long_prefill_token_threshold` and for validation.
            is_encoder_decoder: Whether the model is encoder-decoder. If so,
                chunked prefill and chunked multimodal input are turned off.

        Raises:
            ValueError: Propagated from `verify_max_model_len`.
        """
        if is_encoder_decoder:
            # Chunked prefill should be disabled for encoder-decoder models.
            # NOTE(review): prefix caching is not changed in this class even
            # though the message says so; presumably it is handled elsewhere.
            self.disable_chunked_mm_input = True
            self.enable_chunked_prefill = False
            self.long_prefill_token_threshold = 0
            logger.info(
                "Encoder-decoder models do not support chunked prefill nor"
                " prefix caching; disabling both."
            )

        # The encoder budget and cache size are not user-configurable yet.
        # Both follow the per-iteration token budget.
        self.max_num_encoder_input_tokens = self.max_num_batched_tokens
        self.encoder_cache_size = self.max_num_batched_tokens

        if self.enable_chunked_prefill:
            logger.info(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                self.max_num_batched_tokens,
            )

        if self.max_num_partial_prefills > 1:
            if self.long_prefill_token_threshold == 0:
                # When no threshold is given, a prompt counts as "long" if it
                # exceeds 4% of max_model_len.
                self.long_prefill_token_threshold = int(max_model_len * 0.04)

            logger.info(
                "Concurrent partial prefills enabled with "
                "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
                "long_prefill_token_threshold=%d",
                self.max_num_partial_prefills,
                self.max_long_partial_prefills,
                self.long_prefill_token_threshold,
            )

        self.verify_max_model_len(max_model_len)

    def verify_max_model_len(self, max_model_len: int) -> Self:
        """Check the scheduler limits for consistency with `max_model_len`.

        If `max_num_batched_tokens` exceeds `max_num_seqs * max_model_len`,
        a warning is logged but no error is raised.

        Args:
            max_model_len: Maximum sequence length of the model.

        Returns:
            `self`, so the call can be chained.

        Raises:
            ValueError: Raised in any of these cases:
                - chunked prefill is disabled and `max_num_batched_tokens` is
                  smaller than `max_model_len`;
                - `max_num_batched_tokens` is smaller than `max_num_seqs`;
                - `max_num_partial_prefills > 1` and either chunked prefill is
                  disabled or `long_prefill_token_threshold` exceeds
                  `max_model_len`;
                - `max_long_partial_prefills` exceeds
                  `max_num_partial_prefills`.
        """
        # Without chunking, a prompt must fit in one iteration's token budget.
        if (
            self.max_num_batched_tokens < max_model_len
            and not self.enable_chunked_prefill
        ):
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len."
            )

        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs})."
            )

        if self.max_num_batched_tokens > self.max_num_seqs * max_model_len:
            logger.warning(
                "max_num_batched_tokens (%d) exceeds max_num_seqs "
                "* max_model_len (%d). This may lead to unexpected behavior.",
                self.max_num_batched_tokens,
                self.max_num_seqs * max_model_len,
            )

        if self.max_num_partial_prefills > 1:
            if not self.enable_chunked_prefill:
                raise ValueError(
                    "Chunked prefill must be enabled to set "
                    "max_num_partial_prefills > 1."
                )

            if self.long_prefill_token_threshold > max_model_len:
                raise ValueError(
                    "long_prefill_token_threshold "
                    f"({self.long_prefill_token_threshold}) cannot be greater "
                    f"than the max_model_len ({max_model_len})."
                )

        if self.max_long_partial_prefills > self.max_num_partial_prefills:
            raise ValueError(
                f"{self.max_long_partial_prefills=} must be less than or equal to "
                f"{self.max_num_partial_prefills=}."
            )

        return self