# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from collections.abc import Callable
from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, Literal, cast

from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import (
    DEFAULT_MAX_NUM_BATCHED_TOKENS,
    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
from vllm.utils.import_utils import resolve_obj_by_qualname

if TYPE_CHECKING:
    from vllm.v1.core.sched.interface import SchedulerInterface

logger = init_logger(__name__)

# Runner kinds accepted by `SchedulerConfig.runner_type`.
RunnerType = Literal["generate", "pooling", "draft"]

# Scheduling policies accepted by `SchedulerConfig.policy`.
SchedulerPolicy = Literal["fcfs", "priority"]


@config
@dataclass
class SchedulerConfig:
    """Scheduler configuration."""

    runner_type: RunnerType = "generate"
    """The runner type to launch for the model."""

    max_num_batched_tokens: int = Field(default=None, ge=1)
    """Maximum number of tokens to be processed in a single iteration.

    This config has no static default. If left unspecified by the user, it will
    be set in `EngineArgs.create_engine_config` based on the usage context."""

    max_num_seqs: int = Field(default=None, ge=1)
    """Maximum number of sequences to be processed in a single iteration.

    This config has no static default. If left unspecified by the user, it will
    be set in `EngineArgs.create_engine_config` based on the usage context."""

    max_model_len: int = Field(default=None, ge=1)
    """Maximum length of a sequence (including prompt and generated text). This
    is primarily set in `ModelConfig` and that value should be manually
    duplicated here."""

    max_num_partial_prefills: int = Field(default=1, ge=1)
    """For chunked prefill, the maximum number of sequences that can be
    partially prefilled concurrently."""

    max_long_partial_prefills: int = Field(default=1, ge=1)
    """For chunked prefill, the maximum number of prompts longer than
    long_prefill_token_threshold that will be prefilled concurrently. Setting
    this less than max_num_partial_prefills will allow shorter prompts to jump
    the queue in front of longer prompts in some cases, improving latency."""

    long_prefill_token_threshold: int = 0
    """For chunked prefill, a request is considered long if the prompt is
    longer than this number of tokens."""

    num_lookahead_slots: int = Field(default=0, ge=0)
    """The number of slots to allocate per sequence per
    step, beyond the known token ids. This is used in speculative
    decoding to store KV activations of tokens which may or may not be
    accepted.

    NOTE: This will be replaced by speculative config in the future; it is
    present to enable correctness tests until then."""

    enable_chunked_prefill: bool = Field(default=None)
    """If True, prefill requests can be chunked based
    on the remaining max_num_batched_tokens."""

    is_multimodal_model: bool = False
    """True if the model is multimodal."""

    is_encoder_decoder: InitVar[bool] = False
    """True if the model is an encoder-decoder model.

    Note: This is stored in the ModelConfig, and is used only here to
    disable chunked prefill and prefix caching for encoder-decoder models.
    """

    # TODO (ywang96): Make this configurable.
    max_num_encoder_input_tokens: int = Field(init=False)
    """Multimodal encoder compute budget, only used in V1.

    NOTE: This is not currently configurable. It will be overridden by
    max_num_batched_tokens in case max multimodal embedding size is larger."""

    # TODO (ywang96): Make this configurable.
    encoder_cache_size: int = Field(init=False)
    """Multimodal encoder cache size, only used in V1.

    NOTE: This is not currently configurable. It will be overridden by
    max_num_batched_tokens in case max multimodal embedding size is larger."""

    policy: SchedulerPolicy = "fcfs"
    """The scheduling policy to use:\n
    - "fcfs" means first come first served, i.e. requests are handled in order
    of arrival.\n
    - "priority" means requests are handled based on given priority (lower
    value means earlier handling) and time of arrival deciding any ties)."""

    chunked_prefill_enabled: bool = Field(init=False)
    """True if chunked prefill is enabled."""

    disable_chunked_mm_input: bool = False
    """If set to true and chunked prefill is enabled, we do not want to
    partially schedule a multimodal item. Only used in V1
    This ensures that if a request has a mixed prompt
    (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
    some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
    it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""

    # scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler"
    # (default) or "mod.custom_class".
    scheduler_cls: str | type[object] = Field(default=None)
    """The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
    the default scheduler. Can be a class directly or the path to a class of
    form "mod.custom_class"."""

    disable_hybrid_kv_cache_manager: bool = False
    """If set to True, KV cache manager will allocate the same size of KV cache
    for all attention layers even if there are multiple type of attention layers
    like full attention and sliding window attention.
    """

    async_scheduling: bool = False
    """If set to True, perform async scheduling. This helps to avoid gaps in
    GPU utilization, leading to better latency and throughput.
    Async scheduling is currently not supported with some features such as
    speculative decoding and pipeline parallelism.
    """

    def get_scheduler_cls(self) -> type["SchedulerInterface"]:
        """Return the scheduler class to instantiate.

        If `scheduler_cls` is unset, returns `AsyncScheduler` when
        `async_scheduling` is enabled and `Scheduler` otherwise. A custom
        `scheduler_cls` is returned as-is when it is already a class, or
        resolved via `resolve_obj_by_qualname` when it is a string. In the
        custom case, a warning is emitted through `logger.warning_once`.
        """
        if self.scheduler_cls is None:
            if self.async_scheduling:
                from vllm.v1.core.sched.async_scheduler import AsyncScheduler

                return AsyncScheduler
            from vllm.v1.core.sched.scheduler import Scheduler

            return Scheduler

        # This warning can be removed once the Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        logger.warning_once(
            "Using custom scheduler class %s. This scheduler interface is "
            "not public and compatibility may not be maintained.",
            self.scheduler_cls,
        )
        if not isinstance(self.scheduler_cls, str):
            return cast(type["SchedulerInterface"], self.scheduler_cls)
        return resolve_obj_by_qualname(self.scheduler_cls)

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @field_validator(
        "max_num_batched_tokens",
        "max_num_seqs",
        "max_model_len",
        "enable_chunked_prefill",
        "scheduler_cls",
        "async_scheduling",
        mode="wrap",
    )
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Skip validation if the value is `None` when initialisation is delayed."""
        if value is None:
            return value
        return handler(value)

    def __post_init__(self, is_encoder_decoder: bool) -> None:
        """Fill in defaulted and derived fields after construction.

        Args:
            is_encoder_decoder: Init-only flag. When True, chunked prefill is
                turned off, `disable_chunked_mm_input` is forced on, and
                `long_prefill_token_threshold` is reset to 0.
        """
        # Fallbacks in case these are still unset at construction time.
        if self.max_model_len is None:
            self.max_model_len = 8192

        if self.max_num_seqs is None:
            self.max_num_seqs = 128

        if is_encoder_decoder:
            # Chunked prefill should be disabled for encoder-decoder models.
            self.disable_chunked_mm_input = True
            self.chunked_prefill_enabled = False
            self.enable_chunked_prefill = False
            self.long_prefill_token_threshold = 0
            logger.info(
                "Encoder-decoder models do not support chunked prefill nor"
                " prefix caching; disabling both."
            )

        if self.max_num_batched_tokens is None:
            if self.enable_chunked_prefill:
                self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
            else:
                # If max_model_len is too short, use
                # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
                # for higher throughput.
                self.max_num_batched_tokens = max(
                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
                )

            if self.runner_type == "pooling":
                # Choose specific value for higher throughput
                self.max_num_batched_tokens = max(
                    self.max_num_batched_tokens,
                    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
                )
            if self.is_multimodal_model:
                # The value needs to be at least the number of multimodal tokens
                self.max_num_batched_tokens = max(
                    self.max_num_batched_tokens,
                    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                )

            # When using default settings,
            # Ensure max_num_batched_tokens does not exceed model limit.
            # Some models (e.g., Whisper) have embeddings tied to max length.
            self.max_num_batched_tokens = min(
                self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
            )

        # The encoder budget and cache size track the token budget.
        self.max_num_encoder_input_tokens = self.max_num_batched_tokens
        self.encoder_cache_size = self.max_num_batched_tokens

        if self.enable_chunked_prefill:
            logger.info(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                self.max_num_batched_tokens,
            )

        # Mirrors enable_chunked_prefill. NOTE: this stays None when
        # enable_chunked_prefill was never set (its field default is None).
        self.chunked_prefill_enabled = self.enable_chunked_prefill
        if self.max_num_partial_prefills > 1:
            if self.long_prefill_token_threshold == 0:
                # Default the "long prompt" threshold to 4% of max_model_len.
                self.long_prefill_token_threshold = int(self.max_model_len * 0.04)

            logger.info(
                "Concurrent partial prefills enabled with "
                "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
                "long_prefill_token_threshold=%d",
                self.max_num_partial_prefills,
                self.max_long_partial_prefills,
                self.long_prefill_token_threshold,
            )

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Check cross-field constraints on the populated config.

        Reads `chunked_prefill_enabled`, which is only assigned in
        `__post_init__`, so this assumes `__post_init__` has already run.

        Returns:
            `self`, as expected from a pydantic `mode="after"` model validator.

        Raises:
            ValueError: If `max_num_batched_tokens` is below `max_model_len`
                without chunked prefill, or below `max_num_seqs`. Also raised
                when `max_num_partial_prefills > 1` without chunked prefill or
                with `long_prefill_token_threshold` above `max_model_len`, or
                when `max_long_partial_prefills` exceeds
                `max_num_partial_prefills`.
        """
        if (
            self.max_num_batched_tokens < self.max_model_len
            and not self.chunked_prefill_enabled
        ):
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len."
            )

        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs})."
            )

        # Not fatal: only warn when the token budget exceeds what
        # max_num_seqs full-length sequences could ever use.
        if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
            logger.warning(
                "max_num_batched_tokens (%d) exceeds max_num_seqs "
                "* max_model_len (%d). This may lead to unexpected behavior.",
                self.max_num_batched_tokens,
                self.max_num_seqs * self.max_model_len,
            )

        if self.max_num_partial_prefills > 1:
            if not self.chunked_prefill_enabled:
                raise ValueError(
                    "Chunked prefill must be enabled to set "
                    "max_num_partial_prefills > 1."
                )

            if self.long_prefill_token_threshold > self.max_model_len:
                raise ValueError(
                    "long_prefill_token_threshold "
                    f"({self.long_prefill_token_threshold}) cannot be greater "
                    f"than the max_model_len ({self.max_model_len})."
                )

        if self.max_long_partial_prefills > self.max_num_partial_prefills:
            raise ValueError(
                f"{self.max_long_partial_prefills=} must be less than or equal to "
                f"{self.max_num_partial_prefills=}."
            )

        return self