__init__.py 38.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# ruff: noqa: F401
5
import ast
6
import copy
7
import hashlib
8
import inspect
9
import json
10
import os
11
import textwrap
12
from contextlib import contextmanager
13
from dataclasses import field, fields, is_dataclass, replace
14
from functools import cached_property, lru_cache
15
from pathlib import Path
16
17
from typing import (TYPE_CHECKING, Any, Literal, Optional, Protocol, TypeVar,
                    Union, cast)
18

19
import regex as re
20
import torch
21
from pydantic import ConfigDict, SkipValidation
22
from pydantic.dataclasses import dataclass
23
from typing_extensions import runtime_checkable
24

25
import vllm.envs as envs
26
from vllm import version
27
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
28
                               PrefixCachingHashAlgo)
29
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
30
                                     CUDAGraphMode, PassConfig)
31
from vllm.config.device import Device, DeviceConfig
32
from vllm.config.kv_events import KVEventsConfig
33
from vllm.config.kv_transfer import KVTransferConfig
34
from vllm.config.load import LoadConfig
35
from vllm.config.lora import LoRAConfig
36
37
38
39
40
from vllm.config.model import (ConvertOption, HfOverrides, LogprobsMode,
                               ModelConfig, ModelDType, ModelImpl,
                               RunnerOption, TaskOption, TokenizerMode,
                               iter_architecture_defaults,
                               try_match_architecture_defaults)
41
42
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
                                    MultiModalConfig)
43
from vllm.config.observability import DetailedTraceModules, ObservabilityConfig
44
45
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
                                  ParallelConfig)
46
from vllm.config.pooler import PoolerConfig
47
from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy
48
from vllm.config.speculative import SpeculativeConfig
49
from vllm.config.speech_to_text import SpeechToTextConfig
50
from vllm.config.structured_outputs import StructuredOutputsConfig
51
from vllm.config.utils import ConfigType, config, get_attr_docs, is_init_field
Woosuk Kwon's avatar
Woosuk Kwon committed
52
from vllm.logger import init_logger
53
from vllm.multimodal import MULTIMODAL_REGISTRY
54
55
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
56

57
if TYPE_CHECKING:
58
    from _typeshed import DataclassInstance
59
    from transformers.configuration_utils import PretrainedConfig
60

61
62
63
    from vllm.model_executor.layers.quantization.base_config import (
        QuantizationConfig)
else:
64
    DataclassInstance = Any
65
    PretrainedConfig = Any
66
    QuantizationConfig = Any
67
    QuantizationMethods = Any
68
    BaseModelLoader = Any
69
    LogitsProcessor = Any
70

71
logger = init_logger(__name__)
72
DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance)
73

74

75
@runtime_checkable
class SupportsHash(Protocol):
    """Structural type for objects that can contribute to a config hash.

    Any object exposing a ``compute_hash() -> str`` method satisfies this
    protocol. Because it is decorated with ``runtime_checkable``,
    ``isinstance`` checks against it are permitted (they only test that the
    method is present).
    """

    def compute_hash(self) -> str:
        """Return a string digest describing this object's contents."""
        ...


class SupportsMetricsInfo(Protocol):
    """Structural type for objects that can describe themselves as metrics.

    Implementers expose ``metrics_info()`` returning a string-to-string
    mapping. This protocol is not ``runtime_checkable``, so it is meant for
    static typing only (``isinstance`` against it raises ``TypeError``).
    """

    def metrics_info(self) -> dict[str, str]:
        """Return metrics metadata as a ``str -> str`` mapping."""
        ...


@config
89
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
90
91
class VllmConfig:
    """Dataclass which contains all vllm-related configuration. This
92
93
94
    simplifies passing around the distinct configurations in the codebase.
    """

95
96
97
    # TODO: use default_factory once default constructing ModelConfig doesn't
    # try to download a model
    model_config: ModelConfig = None  # type: ignore
98
99
100
101
102
103
104
105
106
107
108
    """Model configuration."""
    cache_config: CacheConfig = field(default_factory=CacheConfig)
    """Cache configuration."""
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig)
    """Parallel configuration."""
    scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)
    """Scheduler configuration."""
    device_config: DeviceConfig = field(default_factory=DeviceConfig)
    """Device configuration."""
    load_config: LoadConfig = field(default_factory=LoadConfig)
    """Load configuration."""
109
    lora_config: Optional[LoRAConfig] = None
110
111
112
    """LoRA configuration."""
    speculative_config: Optional[SpeculativeConfig] = None
    """Speculative decoding configuration."""
113
114
115
    structured_outputs_config: StructuredOutputsConfig = field(
        default_factory=StructuredOutputsConfig)
    """Structured outputs configuration."""
116
    observability_config: Optional[ObservabilityConfig] = None
117
    """Observability configuration."""
118
    quant_config: Optional[QuantizationConfig] = None
119
120
121
    """Quantization configuration."""
    compilation_config: CompilationConfig = field(
        default_factory=CompilationConfig)
122
    """`torch.compile` and cudagraph capture configuration for the model.
123

124
125
    As a shorthand, `-O<n>` can be used to directly specify the compilation
    level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
126
    Currently, -O <n> and -O=<n> are supported as well but this will likely be
127
    removed in favor of clearer -O<n> syntax in the future.
128
129
130

    NOTE: level 0 is the default level without any optimization. level 1 and 2
    are for internal testing only. level 3 is the recommended level for
131
    production, also default in V1.
132
133
134
135
136
137

    You can specify the full compilation config like so:
    `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
    """
    kv_transfer_config: Optional[KVTransferConfig] = None
    """The configurations for distributed KV cache transfer."""
138
    kv_events_config: Optional[KVEventsConfig] = None
139
    """The configurations for event publishing."""
140
    # some opaque config, only used to provide additional information
141
142
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
143
144
145
146
    additional_config: Union[dict, SupportsHash] = field(default_factory=dict)
    """Additional config for specified platform. Different platforms may
    support different configs. Make sure the configs are valid for the platform
    you are using. Contents must be hashable."""
147
    instance_id: str = ""
148
    """The ID of the vLLM instance."""
149

150
151
152
153
154
155
156
157
158
159
160
161
    def compute_hash(self) -> str:
        """Return a short digest of every setting that shapes the model graph.

        The digest identifies the configuration affecting the computation
        graph from input ids/embeddings up to the final hidden states;
        anything before the inputs or after the final hidden states is
        excluded.

        WARNING: whenever a new field is added to this config and it affects
        the computation graph, it must be folded into the factors below.

        Returns:
            The first 10 hex characters of an MD5 digest over the factors.
        """
        from vllm import __version__

        def _sub_hash(sub_config) -> str:
            # Absent (falsy) sub-configs contribute the literal "None".
            return sub_config.compute_hash() if sub_config else "None"

        vllm_factors: list[Any] = [__version__, envs.VLLM_USE_V1]
        vllm_factors.extend(
            _sub_hash(sub_config) for sub_config in (
                self.model_config,
                self.cache_config,
                self.parallel_config,
                self.scheduler_config,
                self.device_config,
                self.load_config,
            ))

        if not self.lora_config:
            vllm_factors.append("None")
        else:
            vllm_factors.append(self.lora_config.compute_hash())
            # LoRA creates static buffers sized by max_num_batched_tokens;
            # those tensor sizes and strides are captured explicitly in the
            # torch.compile graph.
            vllm_factors.append(
                str(self.scheduler_config.max_num_batched_tokens))

        # quant_config is intentionally not hashed here: it should already be
        # captured by model_config.quantization.
        vllm_factors.extend(
            _sub_hash(sub_config) for sub_config in (
                self.speculative_config,
                self.structured_outputs_config,
                self.observability_config,
                self.compilation_config,
                self.kv_transfer_config,
            ))

        extra = self.additional_config
        if not extra:
            vllm_factors.append("None")
        elif isinstance(extra, dict):
            serialized = json.dumps(extra, sort_keys=True).encode()
            vllm_factors.append(
                hashlib.md5(serialized, usedforsecurity=False).hexdigest())
        else:
            vllm_factors.append(extra.compute_hash())

        # The factor list is wrapped in an outer list before stringifying.
        digest_input = str([vllm_factors]).encode()
        full_digest = hashlib.md5(digest_input,
                                  usedforsecurity=False).hexdigest()
        return full_digest[:10]

    def pad_for_cudagraph(self, batch_size: int) -> int:
        """Return the cudagraph capture size that ``batch_size`` pads up to.

        The caller must keep ``batch_size`` within
        ``compilation_config.max_capture_size``. Larger values are not
        guarded here; the lookup table is expected to raise ``IndexError``.
        """
        padded_sizes = self.compilation_config.bs_to_padded_graph_size
        return padded_sizes[batch_size]

    @staticmethod
    def _get_quantization_config(
            model_config: ModelConfig,
            load_config: LoadConfig) -> Optional[QuantizationConfig]:
        """Build and validate the quantization config for ``model_config``.

        Returns ``None`` when ``model_config.quantization`` is ``None``.
        Otherwise it loads the config via ``get_quant_config`` and checks it
        against the platform's device capability and the model dtype, then
        calls ``maybe_update_config`` with the model name before returning it.

        Raises:
            ValueError: if the device capability is below the method's
                minimum, or the model dtype is not one of the method's
                supported activation dtypes.
        """
        from vllm.platforms import current_platform

        if model_config.quantization is None:
            return None

        from vllm.model_executor.model_loader.weight_utils import (
            get_quant_config)
        quant_config = get_quant_config(model_config, load_config)

        device_capability = current_platform.get_device_capability()
        # Platforms without a capability notion report None; skip the check.
        if device_capability is not None:
            capability = device_capability.to_int()
            if capability < quant_config.get_min_capability():
                raise ValueError(
                    f"The quantization method {model_config.quantization} "
                    "is not supported for the current GPU. Minimum "
                    f"capability: {quant_config.get_min_capability()}. "
                    f"Current capability: {capability}.")

        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")

        quant_config.maybe_update_config(model_config.model)
        return quant_config

    @staticmethod
    def get_quantization_config(
            model_config: ModelConfig,
            load_config: LoadConfig) -> Optional[QuantizationConfig]:
        """Public, non-mutating wrapper around ``_get_quantization_config``.

        The private helper has been observed to modify the ``model_config``
        it receives, so it is given a deep copy and the caller's object is
        left untouched.
        """
        isolated_model_config = copy.deepcopy(model_config)
        return VllmConfig._get_quantization_config(isolated_model_config,
                                                   load_config)

    def with_hf_config(
        self,
        hf_config: PretrainedConfig,
        architectures: Optional[list[str]] = None,
    ) -> "VllmConfig":
        """Return a copy of this config whose model config uses ``hf_config``.

        Args:
            hf_config: HF config to install on the copied model config.
            architectures: if given, written onto a deep copy of
                ``hf_config`` so the caller's object is not modified.

        Returns:
            A new ``VllmConfig`` built with ``dataclasses.replace``. Only
            ``model_config`` is a fresh (deep-copied) object; the other
            sub-configs are shared with ``self``. ``replace`` constructs a
            new instance, so ``__post_init__`` runs again on it.
        """
        new_hf_config = hf_config
        if architectures is not None:
            new_hf_config = copy.deepcopy(hf_config)
            new_hf_config.architectures = architectures

        new_model_config = copy.deepcopy(self.model_config)
        new_model_config.hf_config = new_hf_config
        return replace(self, model_config=new_model_config)

    def __post_init__(self):
        """Verify configs are valid & consistent with each other.

        Runs after dataclass construction, including construction through
        ``dataclasses.replace``. Besides cross-config validation it fills in
        derived settings: quantization config, compilation level, cudagraph
        mode and capture sizes, chunked-prefill/prefix-caching disablement,
        hybrid KV cache manager enablement, the instance id and the debug
        dump path.

        Raises:
            NotImplementedError: if KV-sharing fast prefill is combined with
                EAGLE speculative decoding.
            ValueError: propagated from quantization config validation.
            AssertionError: on inconsistent cudagraph/compilation settings,
                or an unsupported all2all backend when DBO is enabled.
        """

        self.try_verify_and_update_config()

        if self.model_config is not None:
            self.model_config.verify_with_parallel_config(self.parallel_config)
            self.model_config.verify_dual_chunk_attention_config(
                self.load_config)

        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config is not None:
            self.lora_config.verify_with_cache_config(self.cache_config)
            self.lora_config.verify_with_model_config(self.model_config)

        if self.quant_config is None and self.model_config is not None:
            self.quant_config = VllmConfig._get_quantization_config(
                self.model_config, self.load_config)

        from vllm.platforms import current_platform
        if (self.model_config is not None
                and self.scheduler_config.chunked_prefill_enabled
                and self.model_config.dtype == torch.float32
                and current_platform.get_device_capability() == (7, 5)):
            logger.warning_once(
                "Turing devices tensor cores do not support float32 matmul. "
                "To workaround this limitation, vLLM will set 'ieee' input "
                "precision for chunked prefill triton kernels.")

        # If the user does not explicitly set a compilation level, then
        # we use the default level. The default level depends on other
        # settings (see the below code).
        if self.compilation_config.level is None:
            if envs.VLLM_USE_V1:
                if (self.model_config is not None
                        and not self.model_config.enforce_eager):
                    self.compilation_config.level = CompilationLevel.PIECEWISE
                else:
                    self.compilation_config.level = \
                            CompilationLevel.NO_COMPILATION
            else:
                # NB: Passing both --enforce-eager and a compilation level
                # in V0 means the compilation level wins out.
                self.compilation_config.level = CompilationLevel.NO_COMPILATION

        # async tp is built on top of sequence parallelism
        # and requires it to be enabled.
        if self.compilation_config.pass_config.enable_async_tp:
            self.compilation_config.pass_config.enable_sequence_parallelism = \
                True
        if self.compilation_config.pass_config.enable_sequence_parallelism:
            # Append only once: __post_init__ can run again on a shared
            # compilation_config (e.g. through dataclasses.replace in
            # with_hf_config), which would otherwise duplicate the entry.
            if "+rms_norm" not in self.compilation_config.custom_ops:
                self.compilation_config.custom_ops.append("+rms_norm")

        if current_platform.support_static_graph_mode():
            # if cudagraph_mode is not explicitly set by users, set default
            # value
            if self.compilation_config.cudagraph_mode is None:
                if envs.VLLM_USE_V1 and self.compilation_config.level \
                    == CompilationLevel.PIECEWISE:
                    # default to full and piecewise for most models
                    self.compilation_config.cudagraph_mode = \
                        CUDAGraphMode.FULL_AND_PIECEWISE

                    # pooling models and encoder-decoder models
                    # do not support full cudagraphs
                    if self.model_config is not None and \
                        (self.model_config.pooler_config is not None
                         or self.model_config.is_encoder_decoder):
                        self.compilation_config.cudagraph_mode = \
                            CUDAGraphMode.PIECEWISE
                else:
                    self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

            # disable cudagraph when enforce eager execution
            if self.model_config is not None and \
                    self.model_config.enforce_eager:
                logger.info("Cudagraph is disabled under eager mode")
                self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            elif envs.VLLM_USE_V1:
                self.compilation_config.cudagraph_num_of_warmups = 1

            self._set_cudagraph_sizes()
        else:
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if self.cache_config.kv_sharing_fast_prefill:

            if self.speculative_config is not None and \
                self.speculative_config.use_eagle():
                raise NotImplementedError(
                    "Fast prefill optimization for KV sharing is not "
                    "compatible with EAGLE as EAGLE requires correct logits "
                    "for all tokens while fast prefill gives incorrect logits "
                    "for prompt tokens.")

            logger.warning_once(
                "--kv-sharing-fast-prefill requires changes on model side for "
                "correctness and to realize prefill savings. ")

        # Collect every reason to turn off chunked prefill / prefix caching,
        # then apply the disablement once below.
        disable_chunked_prefill_reasons: list[str] = []

        if self.model_config:
            if self.model_config.pooler_config:
                pooling_type = self.model_config.pooler_config.pooling_type
                if pooling_type is None or pooling_type.lower() != "last":
                    disable_chunked_prefill_reasons.append(
                        "Only \"last\" pooling supports chunked "
                        "prefill and prefix caching; disabling both.")
                if not getattr(self.model_config.hf_config, "is_causal", True):
                    disable_chunked_prefill_reasons.append(
                        "Only models using causal attention supports chunked "
                        "prefill and prefix caching; disabling both.")
            elif self.model_config.is_encoder_decoder:
                self.scheduler_config.max_num_encoder_input_tokens = \
                    MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
                logger.debug(
                    "Encoder-decoder model detected: setting "
                    "`max_num_encoder_input_tokens` to encoder length (%s)",
                    self.scheduler_config.max_num_encoder_input_tokens)
                self.scheduler_config.disable_chunked_mm_input = True
                disable_chunked_prefill_reasons.append(
                    "Encoder-decoder models do not support chunked prefill nor"
                    " prefix caching; disabling both.")
                if (self.model_config.architecture
                        == "WhisperForConditionalGeneration"
                        and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
                        != "spawn"):
                    logger.warning(
                        "Whisper is known to have issues with "
                        "forked workers. If startup is hanging, "
                        "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                        "to 'spawn'.")

        if disable_chunked_prefill_reasons:
            for reason in disable_chunked_prefill_reasons:
                logger.info(reason)
            self.scheduler_config.chunked_prefill_enabled = False
            self.scheduler_config.long_prefill_token_threshold = 0

            if self.cache_config is not None:
                self.cache_config.enable_prefix_caching = False

        if (self.kv_events_config is not None
                and self.kv_events_config.enable_kv_cache_events
                and not self.cache_config.enable_prefix_caching):
            logger.warning(
                "KV cache events are on, but prefix caching is not enabled. "
                "Use --enable-prefix-caching to enable.")
        if (self.kv_events_config is not None
                and self.kv_events_config.publisher != "null"
                and not self.kv_events_config.enable_kv_cache_events):
            logger.warning("KV cache events are disabled, "
                           "but the scheduler is configured to publish them. "
                           "Modify KVEventsConfig.enable_kv_cache_events "
                           "to True to enable.")
        current_platform.check_and_update_config(self)

        # Do this after all the updates to compilation_config.level
        if envs.VLLM_USE_V1 and \
            self.compilation_config.level == CompilationLevel.PIECEWISE:
            self.compilation_config.set_splitting_ops_for_v1()

        # final check of cudagraph mode after all possible updates
        if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
            if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\
                and self.model_config is not None and \
                not self.model_config.disable_cascade_attn and\
                not self.compilation_config.cudagraph_mode.\
                has_piecewise_cudagraphs():
                logger.warning_once(
                    "No piecewise cudagraph for executing cascade attention."
                    " Will fall back to eager execution if a batch runs "
                    "into cascade attentions")

            if self.compilation_config.cudagraph_mode\
                .requires_piecewise_compilation():
                assert self.compilation_config.level == \
                    CompilationLevel.PIECEWISE, \
                    "Compilation level should be CompilationLevel.PIECEWISE "\
                    "when cudagraph_mode piecewise cudagraphs is used, "\
                    f"cudagraph_mode={self.compilation_config.cudagraph_mode}"

            # final migrate the deprecated flags
            self.compilation_config.use_cudagraph = self.compilation_config.\
                cudagraph_mode != CUDAGraphMode.NONE
            self.compilation_config.full_cuda_graph = self.compilation_config.\
                cudagraph_mode.has_full_cudagraphs()

        if self.parallel_config.enable_dbo:
            a2a_backend = envs.VLLM_ALL2ALL_BACKEND
            assert a2a_backend in \
                ["deepep_low_latency", "deepep_high_throughput"], \
            "Microbatching currently only supports the deepep_low_latency and "\
            f"deepep_high_throughput all2all backend. {a2a_backend} is not "\
            "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\
            "variable to deepep_low_latency or deepep_high_throughput and "\
            "install the DeepEP kernels."

            # model_config may be None (it defaults to None); guard like the
            # other model_config accesses in this method.
            if (self.model_config is not None
                    and not self.model_config.disable_cascade_attn):
                self.model_config.disable_cascade_attn = True
                logger.warning_once(
                    "Disabling cascade attention when DBO is enabled.")

        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

        if (envs.VLLM_USE_V1
                and not self.scheduler_config.disable_hybrid_kv_cache_manager):
            # logger should only print warning message for hybrid models. As we
            # can't know whether the model is hybrid or not now, so we don't log
            # warning message here and will log it later.
            if not current_platform.support_hybrid_kv_cache():
                # Hybrid KV cache manager is not supported on non-GPU platforms.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
            if self.kv_transfer_config is not None:
                # Hybrid KV cache manager is not compatible with KV transfer.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
            if self.kv_events_config is not None:
                # Hybrid KV cache manager is not compatible with KV events.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
            if self.model_config is not None and \
                self.model_config.attention_chunk_size is not None:
                if self.speculative_config is not None and \
                    self.speculative_config.use_eagle():
                    # Hybrid KV cache manager is not yet supported with chunked
                    # local attention + eagle.
                    self.scheduler_config.disable_hybrid_kv_cache_manager = True
                elif \
                    not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
                    logger.warning(
                        "There is a latency regression when using chunked local"
                        " attention with the hybrid KV cache manager. Disabling"
                        " it, by default. To enable it, set the environment "
                        "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
                    )
                    # Hybrid KV cache manager is not yet supported with chunked
                    # local attention.
                    self.scheduler_config.disable_hybrid_kv_cache_manager = True

        # Expand "~" BEFORE making the path absolute: Path("~/x").absolute()
        # yields "<cwd>/~/x", which expanduser() can no longer expand.
        if self.compilation_config.debug_dump_path:
            self.compilation_config.debug_dump_path = \
                self.compilation_config.debug_dump_path.expanduser().absolute()
        if envs.VLLM_DEBUG_DUMP_PATH is not None:
            env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).expanduser().absolute()
            if self.compilation_config.debug_dump_path:
                logger.warning(
                    "Config-specified debug dump path is overridden"
                    " by VLLM_DEBUG_DUMP_PATH to %s", env_path)
            self.compilation_config.debug_dump_path = env_path

    def update_sizes_for_sequence_parallelism(self,
                                              possible_sizes: list) -> list:
        """Keep only the batch sizes divisible by the tensor-parallel size.

        Used when sequence parallelism is enabled. Any dropped sizes are
        reported in a single warning.

        Args:
            possible_sizes: candidate cudagraph batch sizes.

        Returns:
            The entries of ``possible_sizes`` that are multiples of
            ``parallel_config.tensor_parallel_size``, in their original order.
        """
        tp_size = self.parallel_config.tensor_parallel_size
        kept: list = []
        dropped: list = []
        for size in possible_sizes:
            if size % tp_size == 0:
                kept.append(size)
            else:
                dropped.append(size)

        if dropped:
            logger.warning(
                "Batch sizes %s are removed because they are not "
                "multiple of tp_size %d when "
                "sequence parallelism is enabled", dropped, tp_size)

        return kept

    def _set_cudagraph_sizes(self):
        """
        Compute the candidate CUDA graph capture batch sizes and pass them to
        ``compilation_config.init_with_cudagraph_sizes``.

        vLLM defines the default candidate list of batch sizes for CUDA graph
        capture as:

        ```python
        max_graph_size = min(max_num_seqs * 2, 512)
        # 1, 2, 4 (each only if <= max_graph_size), then multiples of 8 up
        # to max_graph_size
        cuda_graph_sizes = [1, 2, 4, 8, 16, 24, 32, 40, ..., max_graph_size]
        ```

        In this method, a single-element ``scheduler_config.cuda_graph_sizes``
        is treated as ``max_graph_size``. A list with more than one element
        is used as the explicit set of sizes, sorted ascending. An empty list
        raises ``TypeError``. When sequence parallelism is enabled with a
        tensor-parallel size > 1, sizes not divisible by that TP size are
        dropped. Finally, every size above
        ``scheduler_config.max_num_batched_tokens`` is dropped. If there is
        no model config or eager execution is enforced, the list is empty.

        In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
        will be the final sizes to capture cudagraph (in descending order).

        These sizes are used to capture and reuse CUDA graphs for
        performance-critical paths (e.g., decoding). Capturing enables
        significantly faster kernel dispatch by avoiding Python overhead. The
        list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
        most GPUs), which controls the total allowed number of tokens in a
        batch. Since each sequence may have a variable number of tokens, the
        maximum usable batch size will depend on actual sequence lengths.

        Example:
            With `max_num_batched_tokens = 8192`, and typical sequences
            averaging ~32 tokens, most practical batch sizes fall below 256.
            However, the system will still allow capture sizes up to 512 if
            shape and memory permit.

        Note:
            If users explicitly specify cudagraph capture sizes in the
            compilation config, those will override this default logic.
            At runtime:

            - If batch size <= one of the `cudagraph_capture_sizes`, the closest
            padded CUDA graph will be used.
            - If batch size > largest `cudagraph_capture_sizes`, cudagraph will
            not be used.

        Raises:
            TypeError: if ``scheduler_config.cuda_graph_sizes`` is empty.
        """

        # calculate the default `batch_size_capture_list`
        batch_size_capture_list = []
        if self.model_config is not None and \
            not self.model_config.enforce_eager:
            cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
            if len(cuda_graph_sizes) == 1:
                max_graph_size = cuda_graph_sizes[0]
                # Never exceed the requested maximum, including the small
                # 1/2/4 sizes (e.g. max_graph_size=2 must not capture 4).
                batch_size_capture_list = [
                    i for i in [1, 2, 4] if i <= max_graph_size
                ] + list(range(8, max_graph_size + 1, 8))
            elif len(cuda_graph_sizes) > 1:
                batch_size_capture_list = sorted(cuda_graph_sizes)
            else:
                raise TypeError(f"Invalid value for {cuda_graph_sizes=}.")
            if self.parallel_config.tensor_parallel_size > 1 and \
                self.compilation_config.pass_config.enable_sequence_parallelism:
                batch_size_capture_list = \
                    self.update_sizes_for_sequence_parallelism(batch_size_capture_list)
            max_num_tokens = self.scheduler_config.max_num_batched_tokens
            batch_size_capture_list = [
                size for size in batch_size_capture_list
                if size <= max_num_tokens
            ]

        self.compilation_config.init_with_cudagraph_sizes(
            batch_size_capture_list)

    def recalculate_max_model_len(self, max_model_len: int):
        """Re-verify ``max_model_len`` and propagate it to dependent configs.

        The requested length is passed through
        ``model_config.get_and_verify_max_len`` and the verified result is
        stored on both ``model_config`` and ``scheduler_config``.
        """
        # Can only be called in try_verify_and_update_config
        verified_len = self.model_config.get_and_verify_max_len(max_model_len)
        for cfg in (self.model_config, self.scheduler_config):
            cfg.max_model_len = verified_len

    def try_verify_and_update_config(self):
        """Apply model-specific config adjustments, at most once per model.

        Does nothing when there is no model config, when a previous call has
        already set ``model_config.config_updated``, or when
        ``model_config.architecture`` is None. Otherwise it runs, in order:
        the architecture hook registered in ``MODELS_CONFIG_MAP`` (if any),
        ``HybridAttentionMambaModelConfig`` for hybrid models, and
        ``SequenceClassificationConfig`` when ``convert_type == "classify"``.
        Finally, if ``model_weights`` is a Run:ai object URI, an "auto"
        ``load_format`` is switched to "runai_streamer".

        Raises:
            ValueError: if the weights are a Run:ai object URI and
                ``load_format`` is neither "auto" nor "runai_streamer".
        """
        if self.model_config is None:
            return

        # Avoid running try_verify_and_update_config multiple times; the flag
        # is set before any hook runs, so later calls return immediately.
        if getattr(self.model_config, "config_updated", False):
            return
        self.model_config.config_updated = True

        arch = self.model_config.architecture
        if arch is None:
            return

        from vllm.model_executor.models.config import (
            MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig)

        arch_hook = MODELS_CONFIG_MAP.get(arch)
        if arch_hook is not None:
            arch_hook.verify_and_update_config(self)

        # Hooks may modify self.model_config, so it is re-read from self.
        if self.model_config.is_hybrid:
            HybridAttentionMambaModelConfig.verify_and_update_config(self)

        if self.model_config.convert_type == "classify":
            # Maybe convert ForCausalLM into ForSequenceClassification model.
            from vllm.model_executor.models.adapters import (
                SequenceClassificationConfig)
            SequenceClassificationConfig.verify_and_update_config(self)

        uses_runai_uri = hasattr(self.model_config, "model_weights") and \
            is_runai_obj_uri(self.model_config.model_weights)
        if not uses_runai_uri:
            return

        load_format = self.load_config.load_format
        if load_format == "auto":
            logger.info("Detected Run:ai model config. "
                        "Overriding `load_format` to 'runai_streamer'")
            self.load_config.load_format = "runai_streamer"
        elif load_format != "runai_streamer":
            raise ValueError(f"To load a model from S3, 'load_format' "
                             f"must be 'runai_streamer', "
                             f"but got '{load_format}'. "
                             f"Model: {self.model_config.model}")

    def compile_debug_dump_path(self) -> Optional[Path]:
        """Return the rank-specific directory for torch.compile debug dumps.

        Returns None when ``compilation_config.debug_dump_path`` is unset.
        Otherwise returns ``debug_dump_path / "rank_<rank>"``; when the data
        parallel size is not 1, ``"_dp_<data_parallel_rank>"`` is appended
        to the directory name so DP replicas do not collide.
        """
        base_dir = self.compilation_config.debug_dump_path
        if base_dir is None:
            return None
        parallel = self.parallel_config
        subdir = f"rank_{parallel.rank}"
        if parallel.data_parallel_size != 1:
            subdir += f"_dp_{parallel.data_parallel_rank}"
        return base_dir / subdir

    def __str__(self):
        """Summarize the main settings as comma-separated ``key=value``."""
        model_cfg = self.model_config
        load_cfg = self.load_config
        parallel_cfg = self.parallel_config
        cache_cfg = self.cache_config
        sched_cfg = self.scheduler_config
        parts = [
            f"model={model_cfg.model!r}",
            f"speculative_config={self.speculative_config!r}",
            f"tokenizer={model_cfg.tokenizer!r}",
            f"skip_tokenizer_init={model_cfg.skip_tokenizer_init}",
            f"tokenizer_mode={model_cfg.tokenizer_mode}",
            f"revision={model_cfg.revision}",
            f"tokenizer_revision={model_cfg.tokenizer_revision}",
            f"trust_remote_code={model_cfg.trust_remote_code}",
            f"dtype={model_cfg.dtype}",
            f"max_seq_len={model_cfg.max_model_len}",
            f"download_dir={load_cfg.download_dir!r}",
            f"load_format={load_cfg.load_format}",
            f"tensor_parallel_size={parallel_cfg.tensor_parallel_size}",
            f"pipeline_parallel_size={parallel_cfg.pipeline_parallel_size}",
            f"data_parallel_size={parallel_cfg.data_parallel_size}",
            f"disable_custom_all_reduce={parallel_cfg.disable_custom_all_reduce}",  # noqa: E501
            f"quantization={model_cfg.quantization}",
            f"enforce_eager={model_cfg.enforce_eager}",
            f"kv_cache_dtype={cache_cfg.cache_dtype}",
            f"device_config={self.device_config.device}",
            f"structured_outputs_config={self.structured_outputs_config!r}",
            f"observability_config={self.observability_config!r}",
            f"seed={model_cfg.seed}",
            f"served_model_name={model_cfg.served_model_name}",
            f"enable_prefix_caching={cache_cfg.enable_prefix_caching}",
            f"chunked_prefill_enabled={sched_cfg.chunked_prefill_enabled}",
            f"pooler_config={model_cfg.pooler_config!r}",
            f"compilation_config={self.compilation_config!r}",
        ]
        return ", ".join(parts)


# Config returned by `get_current_vllm_config()`; installed and restored by
# `set_current_vllm_config`.
_current_vllm_config: Optional[VllmConfig] = None
# Prefix returned by `get_current_model_prefix()`; None unless a prefix was
# passed to `set_current_vllm_config`.
_current_prefix: Optional[str] = None


@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig,
                            check_compile=False,
                            prefix: Optional[str] = None):
    """
    Temporarily set the current vLLM config.
    Used during model initialization.
    We save the current vLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the vLLM config to determine how to dispatch.

    The previous config and prefix are restored on exit (also when the body
    raises), and the cached compilation config is invalidated both on entry
    and on exit.

    Args:
        vllm_config: Config to install for the duration of the `with` body.
        check_compile: If True and the body completes without raising, call
            `compilation_config.custom_op_log_check()` and warn when
            piecewise compilation is enabled but the compilation counter's
            `num_models_seen` did not increase during the body.
        prefix: Model prefix to expose via `get_current_model_prefix()`.
    """
    global _current_vllm_config, _current_prefix
    old_vllm_config = _current_vllm_config
    old_prefix = _current_prefix
    from vllm.compilation.counter import compilation_counter
    num_models_seen = compilation_counter.num_models_seen
    try:
        _current_vllm_config = vllm_config
        _current_prefix = prefix
        # Entering a context is also a context change: drop any compilation
        # config cached for a previous (or default) vLLM config so lookups
        # inside this context see `vllm_config.compilation_config`.
        get_cached_compilation_config.cache_clear()
        yield

        # Only reached when the body finished without raising.
        if check_compile:
            compilation_config = vllm_config.compilation_config
            compilation_config.custom_op_log_check()
            if (compilation_config.level == CompilationLevel.PIECEWISE
                    and compilation_counter.num_models_seen
                    == num_models_seen):
                # If the model supports compilation,
                # compilation_counter.num_models_seen should be increased
                # by at least 1.
                # If it is not increased, it means the model does not support
                # compilation (does not have @support_torch_compile decorator).
                logger.warning(
                    "`torch.compile` is turned on, but the model %s"
                    " does not support it. Please open an issue on GitHub"
                    " if you want it to be supported.",
                    vllm_config.model_config.model)
    finally:
        _current_vllm_config = old_vllm_config
        _current_prefix = old_prefix
        # Clear the compilation config cache when context changes
        get_cached_compilation_config.cache_clear()


@lru_cache(maxsize=1)
def get_cached_compilation_config():
    """Return the current vLLM config's compilation config, memoized.

    Caching avoids repeated calls to `get_current_vllm_config()`.
    `set_current_vllm_config` calls `cache_clear()` on this function when a
    context exits, so the memoized value follows the active config.
    """
    current_config = get_current_vllm_config()
    return current_config.compilation_config


def get_current_vllm_config() -> VllmConfig:
    """Return the config installed by `set_current_vllm_config`.

    If no config is installed, log a warning and return a freshly
    constructed default `VllmConfig()` (this fallback is not stored).
    """
    if _current_vllm_config is not None:
        return _current_vllm_config
    # In CI, custom ops/modules are often tested directly without setting a
    # vLLM config; fall back to a default config in that case.
    logger.warning("Current vLLM config is not set.")
    from vllm.config import VllmConfig
    return VllmConfig()


def get_current_model_prefix() -> str:
    """Return the prefix of the model that is currently being initialized.

    Raises:
        AssertionError: if no prefix is currently set. Note this check is an
            `assert`, so it is skipped when Python runs with `-O`.
    """
    model_prefix = _current_prefix
    assert model_prefix is not None, "Current model prefix is not set. "
    return model_prefix


# Layer type selected by `get_layers_from_vllm_config`.
T = TypeVar("T")


def get_layers_from_vllm_config(
        vllm_config: VllmConfig,
        layer_type: type[T],
        layer_names: Optional[list[str]] = None) -> dict[str, T]:
    """
    Collect layers of a given type from the static forward context.

    Args:
        vllm_config: The vLLM config whose
            `compilation_config.static_forward_context` is searched.
        layer_type: Only entries that are instances of this type are kept.
        layer_names: The names of the layers to get. If None, every name in
            the static forward context is considered.

    Returns:
        A dict from layer name to layer, in the order of `layer_names`.

    Raises:
        KeyError: if a name in `layer_names` is not in the forward context.
    """
    forward_context = vllm_config.compilation_config.static_forward_context
    names = list(forward_context) if layer_names is None else layer_names

    selected: dict[str, T] = {}
    for name in names:
        layer = forward_context[name]
        if isinstance(layer, layer_type):
            selected[name] = layer
    return selected


def update_config(config: DataclassInstanceT,
                  overrides: dict[str, Any]) -> DataclassInstanceT:
    """Return a copy of dataclass `config` with `overrides` applied.

    When a field currently holds a dataclass and the override is a plain
    dict, the dict is applied recursively to that nested dataclass. Any other
    override value (including a dataclass instance) replaces the field as-is.
    The result is built with `dataclasses.replace` rather than by assigning
    to `config`'s fields in place.

    Args:
        config: Dataclass instance to derive the new config from.
        overrides: Mapping from field name to the new value.

    Returns:
        A new instance of `type(config)` with the overrides applied.

    Raises:
        AssertionError: if `config` has no attribute named by a key of
            `overrides`, or a nested dataclass field is overridden with
            something other than a dict or a dataclass. These checks are
            `assert`s and are skipped when Python runs with `-O`.
    """
    processed_overrides = {}
    for field_name, value in overrides.items():
        assert hasattr(
            config, field_name), f"{type(config)} has no field `{field_name}`"
        current_value = getattr(config, field_name)
        if is_dataclass(current_value) and not is_dataclass(value):
            # Partial update of a nested dataclass given as a dict.
            assert isinstance(value, dict), (
                f"Overrides to {type(config)}.{field_name} must be a dict"
                f" or {type(current_value)}, but got {type(value)}")
            value = update_config(
                current_value,  # type: ignore[type-var]
                value)
        processed_overrides[field_name] = value
    return replace(config, **processed_overrides)