# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
5
import getpass
6
7
import json
import os
8
9
import tempfile
import threading
10
import time
11
from contextlib import contextmanager
12
from dataclasses import is_dataclass
13
from datetime import datetime
14
from enum import IntEnum
15
16
from functools import lru_cache
from pathlib import Path
17
from typing import TYPE_CHECKING, Any, Literal, TypeVar, get_args
18
19

import torch
20
from pydantic import ConfigDict, Field, model_validator
21
22

import vllm.envs as envs
23
from vllm.logger import enable_trace_function_call, init_logger
24
25
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
26
from vllm.utils.hashing import safe_hash
27

28
from .attention import AttentionConfig
29
from .cache import CacheConfig
30
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
31
from .device import DeviceConfig
32
from .ec_transfer import ECTransferConfig
33
from .kernel import KernelConfig
34
35
36
37
38
39
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
40
from .offload import OffloadConfig
41
from .parallel import ParallelConfig
42
from .profiler import ProfilerConfig
43
from .scheduler import SchedulerConfig
44
from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig
45
from .structured_outputs import StructuredOutputsConfig
46
from .utils import SupportsHash, config, replace
47
from .weight_transfer import WeightTransferConfig
48
49
50
51

# These names are only needed for annotations. Importing them for type
# checkers only keeps the heavy modules out of the runtime import path; at
# runtime each name falls back to `Any`.
if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
    from vllm.v1.kv_cache_interface import KVCacheConfig
else:
    PretrainedConfig = Any

    QuantizationConfig = Any

    KVCacheConfig = Any

61
62
63
logger = init_logger(__name__)


64
65
66
67
68
69
70
class OptimizationLevel(IntEnum):
    """Optimization level enum.

    Each level maps to a table of default settings (compilation passes,
    cudagraph mode, kernel autotuning); see the per-member descriptions.
    """

    O0 = 0
    """O0: No optimization. No compilation, no cudagraphs, no other
    optimization, just starting up immediately."""
    O1 = 1
    """O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
    cudagraphs."""
    O2 = 2
    """O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""
    O3 = 3
    """O3: Currently the same as -O2."""


79
80
PerformanceMode = Literal["balanced", "interactivity", "throughput"]

# The optimizations that depend on these properties are currently set to
# False in all cases. The intended model-dependent definitions are:
# if model_config is not None:
#     IS_QUANTIZED = lambda c: c.model_config.is_quantized()
#     IS_DENSE = lambda c: not c.model_config.is_model_moe()
# See https://github.com/vllm-project/vllm/issues/25689.
IS_QUANTIZED = False
IS_DENSE = False


91
92
93
94
def enable_norm_fusion(cfg: "VllmConfig") -> bool:
    """Decide whether the norm+quant fusion pass should be enabled.

    Enable if either the RMS norm or the quant FP8 custom op is active;
    otherwise Inductor handles fusion.

    Args:
        cfg: Root config; only `cfg.compilation_config` is consulted.

    Returns:
        True if the "rms_norm" or the "quant_fp8" custom op is enabled.
    """
    compilation_config = cfg.compilation_config
    return compilation_config.is_custom_op_enabled(
        "rms_norm"
    ) or compilation_config.is_custom_op_enabled("quant_fp8")


100
def enable_act_fusion(cfg: "VllmConfig") -> bool:
    """Decide whether the activation+quant fusion pass should be enabled.

    Enable if either the SiLU+Mul or the quant FP8 custom op is active;
    otherwise Inductor handles fusion.
    Also enable for FP4 models as FP4 quant is always custom so Inductor cannot
    fuse it.

    Args:
        cfg: Root config; `cfg.compilation_config` and (if set)
            `cfg.model_config` are consulted.

    Returns:
        True if the "silu_and_mul" or "quant_fp8" custom op is enabled, or if
        a model config is present and reports NVFP4 quantization.
    """
    compilation_config = cfg.compilation_config
    model_config = cfg.model_config
    return (
        compilation_config.is_custom_op_enabled("silu_and_mul")
        or compilation_config.is_custom_op_enabled("quant_fp8")
        # model_config may be None (e.g. default-constructed root config)
        or (model_config is not None and model_config.is_nvfp4_quantized())
    )
111
112


113
def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
    """Decide whether the allreduce+RMS norm fusion pass should be enabled.

    Enable if TP > 1 (with DP == 1 and PP == 1), running on CUDA on a
    Hopper/Blackwell device, and flashinfer is installed.

    Args:
        cfg: Root config; only `cfg.parallel_config` is consulted.

    Returns:
        True only when every condition above holds.
    """
    parallel_config = cfg.parallel_config
    # Pure config checks come first so the platform / flashinfer probes are
    # skipped entirely when the parallel layout already rules the fusion out.
    if (
        parallel_config.tensor_parallel_size <= 1
        # tp-dp combination broken:
        # https://github.com/vllm-project/vllm/issues/34458
        or parallel_config.data_parallel_size != 1
        # tp-pp combination broken:
        # https://github.com/vllm-project/vllm/issues/35426
        or parallel_config.pipeline_parallel_size != 1
    ):
        return False

    from vllm.platforms import current_platform
    from vllm.utils.flashinfer import has_flashinfer

    return (
        current_platform.is_cuda()
        and has_flashinfer()
        and (
            current_platform.is_device_capability_family(100)
            or current_platform.is_device_capability(90)
        )
    )


135
136
137
138
139
140
141
142
143
144
145
146
147
def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
    """Decide whether the RoPE + KV-cache fusion pass should be enabled.

    All of the following must hold: the ROCm AITER ops are enabled, the
    "rotary_embedding" custom op is active, and
    `use_inductor_graph_partition` is set on the compilation config.
    """
    from vllm._aiter_ops import rocm_aiter_ops

    if not rocm_aiter_ops.is_enabled():
        return False

    compilation_config = cfg.compilation_config
    if not compilation_config.is_custom_op_enabled("rotary_embedding"):
        return False

    return compilation_config.use_inductor_graph_partition


148
149
150
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
    """Decide whether the norm + padding fusion pass should be enabled.

    Enable if AITER RMSNorm is enabled, AITER Triton GEMMs are NOT enabled,
    and the model's hidden size is 2880 (i.e. gpt-oss); otherwise Inductor
    handles fusion.

    NOTE(review): the previous docstring said "using AITER RMSNorm and AITER
    Triton GEMMs", which contradicts the `not is_triton_gemm_enabled()`
    condition below. The code is left as-is; confirm which one is intended.

    Args:
        cfg: Root config; only `cfg.model_config` is consulted.

    Returns:
        True only when every condition above holds; always False when no
        model config is set.
    """
    model_config = cfg.model_config
    # Without a model there is no hidden size to match, so the result is
    # False regardless of the AITER settings; skip the import in that case.
    if model_config is None:
        return False

    from vllm._aiter_ops import rocm_aiter_ops

    return (
        rocm_aiter_ops.is_rmsnorm_enabled()
        and not rocm_aiter_ops.is_triton_gemm_enabled()
        and model_config.get_hidden_size() == 2880
    )


161
162
163
# Default settings applied per optimization level. Each table mirrors the
# attribute layout of the root config: a nested dict addresses a nested config
# object. A value is either a constant or a callable taking the root config
# and returning the value to use.

# -O0: every fusion pass off, no cudagraphs, no flashinfer autotuning.
OPTIMIZATION_LEVEL_00 = {
    "compilation_config": {
        "pass_config": {
            "fuse_norm_quant": False,
            "fuse_act_quant": False,
            "fuse_allreduce_rms": False,
            "fuse_attn_quant": False,
            "enable_sp": False,
            "fuse_gemm_comms": False,
            "fuse_act_padding": False,
            "fuse_rope_kvcache": False,
        },
        "cudagraph_mode": CUDAGraphMode.NONE,
        "use_inductor_graph_partition": False,
    },
    "kernel_config": {
        "enable_flashinfer_autotune": False,
    },
}
# -O1: config-dependent norm/act/padding/rope fusions, piecewise cudagraphs.
OPTIMIZATION_LEVEL_01 = {
    "compilation_config": {
        "pass_config": {
            "fuse_norm_quant": enable_norm_fusion,
            "fuse_act_quant": enable_act_fusion,
            "fuse_allreduce_rms": False,
            "fuse_attn_quant": False,
            "enable_sp": False,
            "fuse_gemm_comms": False,
            "fuse_act_padding": enable_norm_pad_fusion,
            "fuse_rope_kvcache": enable_rope_kvcache_fusion,
        },
        "cudagraph_mode": CUDAGraphMode.PIECEWISE,
        "use_inductor_graph_partition": False,
    },
    "kernel_config": {
        "enable_flashinfer_autotune": True,
    },
}
# -O2: as -O1, plus allreduce+RMS fusion, the IS_QUANTIZED / IS_DENSE gated
# passes, and full + piecewise cudagraphs.
OPTIMIZATION_LEVEL_02 = {
    "compilation_config": {
        "pass_config": {
            "fuse_norm_quant": enable_norm_fusion,
            "fuse_act_quant": enable_act_fusion,
            "fuse_allreduce_rms": enable_allreduce_rms_fusion,
            "fuse_attn_quant": IS_QUANTIZED,
            "enable_sp": IS_DENSE,
            "fuse_gemm_comms": IS_DENSE,
            "fuse_act_padding": enable_norm_pad_fusion,
            "fuse_rope_kvcache": enable_rope_kvcache_fusion,
        },
        "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
        "use_inductor_graph_partition": False,
    },
    "kernel_config": {
        "enable_flashinfer_autotune": True,
    },
}
# -O3: currently identical to -O2.
OPTIMIZATION_LEVEL_03 = {
    "compilation_config": {
        "pass_config": {
            "fuse_norm_quant": enable_norm_fusion,
            "fuse_act_quant": enable_act_fusion,
            "fuse_allreduce_rms": enable_allreduce_rms_fusion,
            "fuse_attn_quant": IS_QUANTIZED,
            "enable_sp": IS_DENSE,
            "fuse_gemm_comms": IS_DENSE,
            "fuse_act_padding": enable_norm_pad_fusion,
            "fuse_rope_kvcache": enable_rope_kvcache_fusion,
        },
        "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
        "use_inductor_graph_partition": False,
    },
    "kernel_config": {
        "enable_flashinfer_autotune": True,
    },
}

# Lookup from optimization level to its table of defaults.
OPTIMIZATION_LEVEL_TO_CONFIG = {
    OptimizationLevel.O0: OPTIMIZATION_LEVEL_00,
    OptimizationLevel.O1: OPTIMIZATION_LEVEL_01,
    OptimizationLevel.O2: OPTIMIZATION_LEVEL_02,
    OptimizationLevel.O3: OPTIMIZATION_LEVEL_03,
}


246
247
@config(config=ConfigDict(arbitrary_types_allowed=True))  # type: ignore[arg-type,misc]
class VllmConfig:  # type: ignore[misc]
    """Dataclass which contains all vllm-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    # TODO: use default_factory once default constructing ModelConfig doesn't
    # try to download a model
    model_config: ModelConfig = Field(default=None)  # type: ignore[assignment]
    """Model configuration."""
    cache_config: CacheConfig = Field(default_factory=CacheConfig)
    """Cache configuration."""
    parallel_config: ParallelConfig = Field(default_factory=ParallelConfig)
    """Parallel configuration."""
    scheduler_config: SchedulerConfig = Field(
        default_factory=SchedulerConfig.default_factory,
    )
    """Scheduler configuration."""
    device_config: DeviceConfig = Field(default_factory=DeviceConfig)
    """Device configuration."""
    load_config: LoadConfig = Field(default_factory=LoadConfig)
    """Load configuration."""
    offload_config: OffloadConfig = Field(default_factory=OffloadConfig)
    """Model weight offloading configuration."""
    attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
    """Attention configuration."""
    kernel_config: KernelConfig = Field(default_factory=KernelConfig)
    """Kernel configuration."""
    lora_config: LoRAConfig | None = None
    """LoRA configuration."""
    speculative_config: SpeculativeConfig | None = None
    """Speculative decoding configuration."""
    structured_outputs_config: StructuredOutputsConfig = Field(
        default_factory=StructuredOutputsConfig
    )
    """Structured outputs configuration."""
    observability_config: ObservabilityConfig = Field(
        default_factory=ObservabilityConfig
    )
    """Observability configuration."""
    quant_config: QuantizationConfig | None = None
    """Quantization configuration."""
    compilation_config: CompilationConfig = Field(default_factory=CompilationConfig)
    """`torch.compile` and cudagraph capture configuration for the model.

    As a shorthand, one can append compilation arguments via
    -cc.parameter=argument such as `-cc.mode=3` (same as `-cc='{"mode":3}'`).

    You can specify the full compilation config like so:
    `{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
    """
    profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig)
    """Profiling configuration."""
    kv_transfer_config: KVTransferConfig | None = None
    """The configurations for distributed KV cache transfer."""
    kv_events_config: KVEventsConfig | None = None
    """The configurations for event publishing."""
    ec_transfer_config: ECTransferConfig | None = None
    """The configurations for distributed EC cache transfer."""
    # some opaque config, only used to provide additional information
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
    additional_config: dict | SupportsHash = Field(default_factory=dict)
    """Additional config for specified platform. Different platforms may
    support different configs. Make sure the configs are valid for the platform
    you are using. Contents must be hashable."""
    instance_id: str = ""
    """The ID of the vLLM instance."""
    optimization_level: OptimizationLevel = OptimizationLevel.O2
    """The optimization level. These levels trade startup time cost for
    performance, with -O0 having the best startup time and -O3 having the best
    performance. -O2 is used by default. See OptimizationLevel for full
    description."""

    performance_mode: PerformanceMode = "balanced"
    """Performance mode for runtime behavior, 'balanced' is the default.
    'interactivity' favors low end-to-end per-request latency at small batch
    sizes (fine-grained CUDA graphs, latency-oriented kernels).
    'throughput' favors aggregate tokens/sec at high concurrency (larger CUDA
    graphs, more aggressive batching, throughput-oriented kernels)."""

    weight_transfer_config: WeightTransferConfig | None = None
    """The configurations for weight transfer during RL training."""

    shutdown_timeout: int = Field(default=0, ge=0)
    """Shutdown grace period for in-flight requests. Shutdown will be delayed for
    up to this amount of time to allow already-running requests to complete. Any
    remaining requests are aborted once the timeout is reached.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        An unset (falsy) sub-config contributes the placeholder string
        "None" so that the position of every later factor stays stable.

        Returns:
            The first 10 hex characters of the hash over all factors.
        """
        from vllm import __version__

        def hash_or_none(sub_config: Any) -> str:
            """Return the sub-config's hash, or "None" when it is unset."""
            return sub_config.compute_hash() if sub_config else "None"

        # summarize vllm config
        vllm_factors: list[Any] = [__version__]

        if self.model_config:
            vllm_factors.append(self.model_config.compute_hash())
            # The multimodal config only becomes a factor when the multimodal
            # encoder is compiled as well.
            if (
                self.compilation_config
                and getattr(self.compilation_config, "compile_mm_encoder", False)
                and self.model_config.multimodal_config
            ):
                vllm_factors.append(self.model_config.multimodal_config.compute_hash())
        else:
            vllm_factors.append("None")

        # The order below is part of the hash; do not reorder.
        for sub_config in (
            self.cache_config,
            self.parallel_config,
            self.scheduler_config,
            self.device_config,
            self.load_config,
            self.offload_config,
            self.attention_config,
            self.lora_config,
            self.speculative_config,
            # Previously this entry had no "None" placeholder when unset,
            # unlike every other optional factor; it now follows the same rule.
            self.structured_outputs_config,
            self.profiler_config,
        ):
            vllm_factors.append(hash_or_none(sub_config))

        # Always hashed, with no placeholder branch.
        vllm_factors.append(self.observability_config.compute_hash())

        # quant_config is intentionally not a factor: it should be captured
        # by model_config.quantization.
        for sub_config in (
            self.compilation_config,
            self.kv_transfer_config,
            self.ec_transfer_config,
        ):
            vllm_factors.append(hash_or_none(sub_config))

        if self.additional_config:
            if isinstance(additional_config := self.additional_config, dict):
                # Plain dicts are hashed via their key-sorted JSON encoding.
                additional_config_hash = safe_hash(
                    json.dumps(additional_config, sort_keys=True).encode(),
                    usedforsecurity=False,
                ).hexdigest()
            else:
                additional_config_hash = additional_config.compute_hash()
            vllm_factors.append(additional_config_hash)
        else:
            vllm_factors.append("None")

        factors: list[Any] = [vllm_factors]
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
            :10
        ]
        return hash_str

    @property
    def num_speculative_tokens(self) -> int:
        if (
            self.speculative_config is not None
            and self.speculative_config.num_speculative_tokens is not None
        ):
            return self.speculative_config.num_speculative_tokens
        return 0

    @property
    def needs_dp_coordinator(self) -> bool:
        """
        Determine if the DPCoordinator process is needed.

        The DPCoordinator is needed in two cases:
        1. For MoE models with DP > 1: to handle wave coordination
           (even in external LB mode, since wave coordination runs in the coordinator)
        2. For non-MoE models in internal/hybrid LB mode: to collect and publish
           queue stats for load balancing across DP ranks

        Returns:
            True if DPCoordinator process is needed, False otherwise.
        """

        # For non-MoE models, only need coordinator in internal/hybrid LB mode
        # (for stats collection).
        return self.parallel_config.data_parallel_size > 1 and (
            self.model_config is None
            or self.model_config.is_moe
            or not self.parallel_config.data_parallel_external_lb
        )

    def enable_trace_function_call_for_thread(self) -> None:
        """
        Turn on function-call tracing for the calling thread.

        Does nothing unless the `VLLM_TRACE_FUNCTION` environment variable
        is set. The log file is placed under
        `<tmp>/<user>/vllm/vllm-instance-<instance_id>/`, and its name carries
        the process id, thread id and current timestamp.
        """
        if not envs.VLLM_TRACE_FUNCTION:
            return

        # add username to tmp_dir to avoid permission issues
        user_tmp_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())
        raw_name = (
            f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
            f"_thread_{threading.get_ident()}_at_{datetime.now()}.log"
        )
        trace_dir = os.path.join(
            user_tmp_dir,
            "vllm",
            f"vllm-instance-{self.instance_id}",
        )
        os.makedirs(trace_dir, exist_ok=True)
        # The timestamp contains a space; keep the file name space-free.
        enable_trace_function_call(os.path.join(trace_dir, raw_name.replace(" ", "_")))

    @staticmethod
    def _get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> QuantizationConfig | None:
        """Get the quantization config.

        Args:
            model_config: Model config; its `quantization`, `dtype` and
                `model` attributes are read.
            load_config: Load config forwarded to `get_quant_config`.

        Returns:
            The validated quantization config, or None when the model does
            not use quantization.

        Raises:
            ValueError: If the current device capability is below the
                method's minimum, or the model dtype is not among the
                method's supported activation dtypes.
        """
        from vllm.platforms import current_platform

        if model_config.quantization is None:
            return None

        from vllm.model_executor.model_loader.weight_utils import get_quant_config

        quant_config = get_quant_config(model_config, load_config)
        capability_tuple = current_platform.get_device_capability()

        # The capability check is skipped when the platform reports none.
        if capability_tuple is not None:
            capability = capability_tuple.to_int()
            min_capability = quant_config.get_min_capability()
            if capability < min_capability:
                raise ValueError(
                    f"The quantization method {model_config.quantization} "
                    "is not supported for the current GPU. Minimum "
                    f"capability: {min_capability}. "
                    f"Current capability: {capability}."
                )

        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}"
            )

        quant_config.maybe_update_config(model_config.model)
        return quant_config

    @staticmethod
    def get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> QuantizationConfig | None:
        """Get the quantization config without touching the caller's config.

        Args:
            model_config: Model config; a deep copy is passed on, so the
                caller's object is left unmodified.
            load_config: Load config, passed through unchanged.

        Returns:
            Whatever `_get_quantization_config` returns for the copied
            model config.
        """
        # For some reason, the _ version of this modifies the model_config
        # object, so using deepcopy to avoid this problem.
        return VllmConfig._get_quantization_config(
            copy.deepcopy(model_config), load_config
        )

    def with_hf_config(
        self,
        hf_config: PretrainedConfig,
        architectures: list[str] | None = None,
    ) -> "VllmConfig":
        """Return a copy of this config whose model config uses `hf_config`.

        Args:
            hf_config: The HF config to install on the copied model config.
            architectures: If given, `hf_config` is deep-copied and its
                `architectures` attribute is replaced by this list.

        Returns:
            A new config produced by `replace`, with a deep-copied model
            config carrying `hf_config`.

        Note:
            When `architectures` is None, `hf_config` is not copied, so the
            `tie_word_embeddings` propagation below writes into the caller's
            object.
        """
        if architectures is not None:
            hf_config = copy.deepcopy(hf_config)
            hf_config.architectures = architectures

        model_config = copy.deepcopy(self.model_config)

        if (
            model_config.is_multimodal_model
            and hasattr(model_config.hf_config, "tie_word_embeddings")
            and not hasattr(hf_config.get_text_config(), "tie_word_embeddings")
        ):
            # In Transformers v5, tie_word_embeddings belongs to the config of the class
            # that can see both layers to be tied. For example:
            #
            # SomeVLModel:
            #   self.language_model = SomeLanguageModel()
            #   self.vision_model = SomeVisionModel()
            #
            # SomeVLModelForMultimodalLM:
            #   self.model = SomeVLModel()
            #   self.lm_head = nn.Linear()
            #
            # Therefore, tie_word_embeddings is defined in SomeVLModelForMultimodalLM's
            # config and is not present in SomeVLModel's config. In vLLM, the lm_head
            # belongs to the language_model, so we must ensure that tie_word_embeddings
            # is set in the language_model's config.
            tie_word_embeddings = model_config.hf_config.tie_word_embeddings
            hf_config.get_text_config().tie_word_embeddings = tie_word_embeddings

        model_config.hf_config = hf_config
        # Recomputed after hf_config has been swapped in.
        model_config.model_arch_config = model_config.get_model_arch_config()

        return replace(self, model_config=model_config)

    def _set_config_default(self, config_obj: Any, key: str, value: Any) -> None:
        """Set config attribute to default if not already set by user.

        Args:
            config_obj: Configuration object to update.
            key: Attribute name.
            value: Default value (static or callable).
        """
        if getattr(config_obj, key) is None:
            # Some config values are known before initialization and are
            # hard coded.
            # Other values depend on the user given configuration, so they are
            # implemented with lambda functions and decided at run time.
            setattr(config_obj, key, value(self) if callable(value) else value)

    def _apply_optimization_level_defaults(self, defaults: dict[str, Any]) -> None:
        """Apply optimization level defaults using self as root.

        Recursively applies values from defaults into nested config objects.
        Only fields present in defaults are overwritten.

        If the user configuration does not specify a value for a default field
        and if the default field is still None after all user selections are
Jiayi Yan's avatar
Jiayi Yan committed
601
        applied, then default values will be applied to the field. User specified
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
        fields will not be overridden by the default.

        Args:
            defaults: Dictionary of default values to apply.
        """

        def apply_recursive(config_obj: Any, config_defaults: dict[str, Any]) -> None:
            """Recursively apply defaults to config_obj, using self as root."""
            for key, value in config_defaults.items():
                if not hasattr(config_obj, key):
                    continue

                current = getattr(config_obj, key)
                if isinstance(value, dict) and is_dataclass(current):
                    apply_recursive(current, value)
                else:
                    self._set_config_default(config_obj, key, value)

        apply_recursive(self, defaults)

    def _post_init_kv_transfer_config(self) -> None:
        """Update KVTransferConfig based on top-level configs in VllmConfig.

        Right now, this function reads the offloading settings from
        CacheConfig and configures the KVTransferConfig accordingly.
        """
628
629
        # KV offloading is only activated when kv_offloading_size is set.
        if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
630
631
            return

632
633
        kv_offloading_backend = self.cache_config.kv_offloading_backend

634
635
636
637
638
639
640
641
642
643
644
        # If no KVTransferConfig is provided, create a default one.
        if self.kv_transfer_config is None:
            self.kv_transfer_config = KVTransferConfig()
        num_kv_ranks = (
            self.parallel_config.tensor_parallel_size
            * self.parallel_config.pipeline_parallel_size
        )

        if kv_offloading_backend == "native":
            self.kv_transfer_config.kv_connector = "OffloadingConnector"
            self.kv_transfer_config.kv_connector_extra_config.update(
645
                {"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
646
647
648
649
650
651
652
653
654
655
656
657
            )
        elif kv_offloading_backend == "lmcache":
            self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
            kv_gb_per_rank = kv_offloading_size / num_kv_ranks
            self.kv_transfer_config.kv_connector_extra_config = {
                "lmcache.local_cpu": True,
                "lmcache.max_local_cpu_size": kv_gb_per_rank,
            }

        # This is the same for all backends
        self.kv_transfer_config.kv_role = "kv_both"

    def __post_init__(self):
659
        """Verify configs are valid & consistent with each other."""
660

661
662
663
        # To give each torch profile run a unique instance name.
        self.instance_id = f"{time.time_ns()}"

664
665
666
667
668
        if self.performance_mode != "balanced":
            logger.info_once(
                "Performance mode set to '%s'.", self.performance_mode, scope="local"
            )

669
670
671
672
        self.try_verify_and_update_config()

        if self.model_config is not None:
            self.model_config.verify_with_parallel_config(self.parallel_config)
673
            self.model_config.verify_dual_chunk_attention_config(self.load_config)
674

675
676
            self.parallel_config.is_moe_model = self.model_config.is_moe

677
678
679
680
681
        if self.lora_config is not None:
            self.lora_config.verify_with_model_config(self.model_config)

        if self.quant_config is None and self.model_config is not None:
            self.quant_config = VllmConfig._get_quantization_config(
682
683
                self.model_config, self.load_config
            )
684

685
686
        from vllm.v1.executor.abstract import Executor

687
        executor_backend = self.parallel_config.distributed_executor_backend
688
689
        executor_class = Executor.get_class(self)
        executor_supports_async_sched = executor_class.supports_async_scheduling()
690
691
692

        if self.scheduler_config.async_scheduling:
            # Async scheduling explicitly enabled, hard fail any incompatibilities.
693
694
            # Currently, async scheduling only support eagle speculative
            # decoding.
695
            if self.speculative_config is not None:
696
697
                if (
                    self.speculative_config.method not in get_args(EagleModelTypes)
698
                    and self.speculative_config.method not in get_args(NgramGPUTypes)
699
700
                    and self.speculative_config.method != "draft_model"
                ):
701
702
                    raise ValueError(
                        "Currently, async scheduling is only supported "
703
704
                        "with EAGLE/MTP/Draft Model/NGram GPU kind of "
                        "speculative decoding"
705
706
707
                    )
                if self.speculative_config.disable_padded_drafter_batch:
                    raise ValueError(
708
709
                        "Async scheduling is not compatible with "
                        "disable_padded_drafter_batch=True."
710
                    )
711
712
            if not executor_supports_async_sched:
                raise ValueError(
713
                    f"`{executor_backend}` does not support async scheduling yet."
714
715
716
                )
        elif self.scheduler_config.async_scheduling is None:
            # Enable async scheduling unless there is an incompatible option.
717
            if (
718
719
                self.speculative_config is not None
                and self.speculative_config.method not in get_args(EagleModelTypes)
720
                and self.speculative_config.method not in get_args(NgramGPUTypes)
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
            ):
                logger.warning_once(
                    "Async scheduling not supported with %s-based "
                    "speculative decoding and will be disabled.",
                    self.speculative_config.method,
                    scope="local",
                )
                self.scheduler_config.async_scheduling = False
            elif (
                self.speculative_config is not None
                and self.speculative_config.disable_padded_drafter_batch
            ):
                logger.warning_once(
                    "Async scheduling is not compatible with "
                    "disable_padded_drafter_batch=True and will be disabled.",
                    scope="local",
                )
738
                self.scheduler_config.async_scheduling = False
739
            elif not executor_supports_async_sched:
740
                logger.warning_once(
741
                    "Async scheduling will be disabled because it is not supported "
742
                    "with the `%s` distributed executor backend. ",
743
                    executor_backend,
744
                    scope="local",
745
746
747
748
749
                )
                self.scheduler_config.async_scheduling = False
            else:
                self.scheduler_config.async_scheduling = True

750
751
752
753
754
        logger.info_once(
            "Asynchronous scheduling is %s.",
            "enabled" if self.scheduler_config.async_scheduling else "disabled",
        )

755
756
        if self.parallel_config.disable_nccl_for_dp_synchronization is None:
            if self.scheduler_config.async_scheduling:
757
758
759
760
761
762
763
764
                if self.parallel_config.data_parallel_size > 1 and (
                    self.model_config is None or self.model_config.is_moe
                ):
                    logger.info_once(
                        "Disabling NCCL for DP synchronization "
                        "when using async scheduling.",
                        scope="local",
                    )
765
766
767
768
                self.parallel_config.disable_nccl_for_dp_synchronization = True
            else:
                self.parallel_config.disable_nccl_for_dp_synchronization = False

769
770
771
772
773
774
775
776
777
778
779
780
781
        if (
            self.speculative_config is not None
            and self.scheduler_config.async_scheduling
            and self.model_config is not None
            and not self.model_config.disable_cascade_attn
        ):
            logger.warning_once(
                "Disabling cascade attention (not yet compatible with "
                "async speculative decoding).",
                scope="local",
            )
            self.model_config.disable_cascade_attn = True

782
783
784
785
786
787
788
789
790
791
792
        if (
            self.model_config is not None
            and self.model_config.multimodal_config is not None
            and self.model_config.multimodal_config.mm_tensor_ipc == "torch_shm"
            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
        ):
            raise ValueError(
                "torch_shm is known to fail without "
                "VLLM_WORKER_MULTIPROC_METHOD set to spawn"
            )

793
        from vllm.platforms import current_platform
794
795
796

        if (
            self.model_config is not None
797
            and self.scheduler_config.enable_chunked_prefill
798
799
800
            and self.model_config.dtype == torch.float32
            and current_platform.get_device_capability() == (7, 5)
        ):
801
802
803
            logger.warning_once(
                "Turing devices tensor cores do not support float32 matmul. "
                "To workaround this limitation, vLLM will set 'ieee' input "
804
805
                "precision for chunked prefill triton kernels."
            )
806

807
808
809
810
811
812
813
        if self.model_config is not None and self.model_config.enforce_eager:
            logger.warning(
                "Enforce eager set, disabling torch.compile and CUDAGraphs. "
                "This is equivalent to setting -cc.mode=none -cc.cudagraph_mode=none"
            )
            self.compilation_config.mode = CompilationMode.NONE
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
814
815
816
817
818
819

        if self.compilation_config.backend == "eager" or (
            self.compilation_config.mode is not None
            and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
        ):
            logger.warning(
820
821
822
                "Inductor compilation was disabled by user settings, "
                "optimizations settings that are only active during "
                "inductor compilation will be ignored."
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
            )

        def has_blocked_weights():
            if self.quant_config is not None:
                if hasattr(self.quant_config, "weight_block_size"):
                    return self.quant_config.weight_block_size is not None
                elif hasattr(self.quant_config, "has_blocked_weights"):
                    return self.quant_config.has_blocked_weights()
            return False

        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
        # On H100 the CUDA kernel is faster than
        # native implementation
        # https://github.com/vllm-project/vllm/issues/25094
        if has_blocked_weights():
            custom_ops = self.compilation_config.custom_ops
            if "-quant_fp8" not in custom_ops:
                custom_ops.append("+quant_fp8")

842
843
        current_platform.apply_config_platform_defaults(self)

844
        if self.compilation_config.mode is None:
845
            if self.optimization_level > OptimizationLevel.O0:
846
                self.compilation_config.mode = CompilationMode.VLLM_COMPILE
847
            else:
848
                self.compilation_config.mode = CompilationMode.NONE
849
850
851
852

        if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
            if (
                self.compilation_config.backend == "inductor"
853
                and self.compilation_config.mode != CompilationMode.NONE
854
855
856
857
            ):
                self.compilation_config.custom_ops.append("none")
            else:
                self.compilation_config.custom_ops.append("all")
858

859
860
        default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
        self._apply_optimization_level_defaults(default_config)
861
862
863
864
865
        if self.kernel_config.enable_flashinfer_autotune is None:
            raise ValueError(
                "KernelConfig.enable_flashinfer_autotune must be set after applying "
                "optimization level defaults."
            )
866

867
        if (
868
            self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
869
870
871
872
873
874
875
876
877
878
            and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
        ):
            logger.info(
                "Cudagraph mode %s is not compatible with compilation mode %s."
                "Overriding to NONE.",
                self.compilation_config.cudagraph_mode,
                self.compilation_config.mode,
            )
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

879
880
        # async tp is built on top of sequence parallelism
        # and requires it to be enabled.
881
882
883
        if self.compilation_config.pass_config.fuse_gemm_comms:
            self.compilation_config.pass_config.enable_sp = True
        if self.compilation_config.pass_config.enable_sp:
884
885
886
887
            if self.parallel_config.tensor_parallel_size == 1:
                logger.warning("Sequence Parallelism requires TP>1, disabling")
                self.compilation_config.pass_config.enable_sp = False
                self.compilation_config.pass_config.fuse_gemm_comms = False
888
889
            else:
                # Compute SP threshold early; disable if None (model too
890
                # small for SP to be beneficial).
891
892
893
894
895
896
897
898
                pass_config = self.compilation_config.pass_config
                if pass_config.sp_min_token_num is None:
                    from vllm.compilation.passes.fusion.sequence_parallelism import (
                        get_sequence_parallelism_threshold,
                    )

                    tp_size = self.parallel_config.tensor_parallel_size
                    hidden_size = self.model_config.get_hidden_size()
899
                    element_size = self.model_config.dtype.itemsize  # type: ignore[union-attr]
900
901
902
                    pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
                        hidden_size, tp_size, element_size
                    )
903

904
905
906
907
908
909
910
911
912
                if pass_config.sp_min_token_num is None:
                    logger.warning(
                        "Model hidden_size too small for the SP "
                        "threshold heuristic, disabling. To force SP, "
                        "set pass_config.sp_min_token_num manually."
                    )
                    self.compilation_config.pass_config.enable_sp = False
                    self.compilation_config.pass_config.fuse_gemm_comms = False

913
914
915
916
917
918
919
        from vllm.utils.torch_utils import HAS_OPAQUE_TYPE

        if HAS_OPAQUE_TYPE:
            # On torch >= 2.11 the hoisted OpaqueObject approach supersedes
            # fast_moe_cold_start, so force it off.
            self.compilation_config.fast_moe_cold_start = False
        elif self.compilation_config.fast_moe_cold_start is None:
920
921
922
923
924
925
926
            # resolve default behavior: try to be as safe as possible
            # this config is unsafe if any spec decoding draft model has a MOE.
            # We'll conservatively turn it off if we see spec decoding.
            self.compilation_config.fast_moe_cold_start = (
                self.speculative_config is None
            )

927
928
        self._set_max_num_scheduled_tokens()

929
        if current_platform.support_static_graph_mode():
930
            # if cudagraph_mode has full cudagraphs, we need to check support
931
932
933
934
935
            if model_config := self.model_config:
                if (
                    self.compilation_config.cudagraph_mode.has_full_cudagraphs()
                    and model_config.pooler_config is not None
                ):
936
                    logger.warning_once(
937
                        "Pooling models do not support full cudagraphs. "
938
939
940
                        "Overriding cudagraph_mode to PIECEWISE."
                    )
                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
941
942
943
944
945
946
947
948
949
950
951
952
                elif (
                    model_config.is_encoder_decoder
                    and self.compilation_config.cudagraph_mode
                    not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
                ):
                    logger.info_once(
                        "Encoder-decoder models do not support %s. "
                        "Overriding cudagraph_mode to FULL_DECODE_ONLY.",
                        self.compilation_config.cudagraph_mode.name,
                    )
                    self.compilation_config.cudagraph_mode = (
                        CUDAGraphMode.FULL_DECODE_ONLY
953
                    )
954

955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
            # Check if KV connector requires PIECEWISE mode for CUDA graphs
            if (
                self.kv_transfer_config is not None
                and self.kv_transfer_config.is_kv_transfer_instance
                and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
            ):
                # Lazy import to avoid circular dependencies
                from vllm.distributed.kv_transfer.kv_connector.factory import (
                    KVConnectorFactory,
                )

                connector_cls = KVConnectorFactory.get_connector_class(
                    self.kv_transfer_config
                )
                if connector_cls.requires_piecewise_for_cudagraph(
                    self.kv_transfer_config.kv_connector_extra_config
                ):
                    logger.warning_once(
                        "KV connector %s requires PIECEWISE CUDA graph mode "
                        "due to layerwise async operations that cannot be "
                        "captured in CUDA graphs. "
                        "Overriding cudagraph_mode from %s to PIECEWISE.",
                        connector_cls.__name__,
                        self.compilation_config.cudagraph_mode.name,
                    )
                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

982
            # disable cudagraph when enforce eager execution
983
            if self.model_config is not None and self.model_config.enforce_eager:
984
985
                logger.info("Cudagraph is disabled under eager mode")
                self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
986
987
988
                # override related settings when enforce eager
                self.compilation_config.max_cudagraph_capture_size = 0
                self.compilation_config.cudagraph_capture_sizes = []
989
            else:
990
991
992
993
994
995
996
                self.compilation_config.cudagraph_num_of_warmups = 1

            self._set_cudagraph_sizes()
        else:
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if self.cache_config.kv_sharing_fast_prefill:
997
998
999
1000
            if (
                self.speculative_config is not None
                and self.speculative_config.use_eagle()
            ):
1001
                raise ValueError(
1002
1003
1004
                    "Fast prefill optimization for KV sharing is not "
                    "compatible with EAGLE as EAGLE requires correct logits "
                    "for all tokens while fast prefill gives incorrect logits "
1005
1006
                    "for prompt tokens."
                )
1007
1008
1009

            logger.warning_once(
                "--kv-sharing-fast-prefill requires changes on model side for "
1010
                "correctness and to realize prefill savings."
1011
            )
1012

1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
        if (
            self.model_config
            and self.model_config.architecture == "WhisperForConditionalGeneration"
            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
        ):
            logger.warning(
                "Whisper is known to have issues with "
                "forked workers. If startup is hanging, "
                "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                "to 'spawn'."
1023
            )
1024

1025
1026
1027
1028
1029
        if (
            self.kv_events_config is not None
            and self.kv_events_config.enable_kv_cache_events
            and not self.cache_config.enable_prefix_caching
        ):
1030
            logger.warning(
1031
                "KV cache events are on, but prefix caching is not enabled. "
1032
1033
1034
1035
1036
1037
1038
1039
                "Use --enable-prefix-caching to enable."
            )
        if (
            self.kv_events_config is not None
            and self.kv_events_config.publisher != "null"
            and not self.kv_events_config.enable_kv_cache_events
        ):
            logger.warning(
1040
1041
1042
                "KV cache events are disabled, "
                "but the scheduler is configured to publish them. "
                "Modify KVEventsConfig.enable_kv_cache_events "
1043
1044
                "to True to enable."
            )
1045
1046
        current_platform.check_and_update_config(self)

1047
1048
1049
1050
        # Re-compute compile ranges after platform-specific config updates
        # (e.g., XPU may lower max_num_batched_tokens when MLA is enabled)
        self._set_compile_ranges()

1051
        # Do this after all the updates to compilation_config.mode
1052
1053
1054
1055
1056
        effective_dp_size = (
            self.parallel_config.data_parallel_size
            if self.model_config is None or self.model_config.is_moe
            else 1
        )
1057
1058
        self.compilation_config.set_splitting_ops_for_v1(
            all2all_backend=self.parallel_config.all2all_backend,
1059
            data_parallel_size=effective_dp_size,
1060
        )
1061

1062
        if self.compilation_config.pass_config.enable_sp:
1063
1064
1065
1066
1067
            # With pipeline parallelism or dynamo partitioning,
            # native rms norm tracing errors due to incorrect residual shape.
            # Use custom rms norm to unblock. In the future,
            # the pass will operate on higher-level IR to avoid the issue.
            # TODO: https://github.com/vllm-project/vllm/issues/27894
1068
1069
1070
1071
1072
1073
1074
            if self.compilation_config.mode != CompilationMode.VLLM_COMPILE:
                logger.warning(
                    "Sequence parallelism is enabled, but running in wrong "
                    "vllm compile mode: %s.",
                    self.compilation_config.mode,
                )

1075
1076
            is_fullgraph = (
                self.compilation_config.use_inductor_graph_partition
1077
                or len(self.compilation_config.splitting_ops or []) == 0
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
            )
            if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
                if "-rms_norm" not in self.compilation_config.custom_ops:
                    self.compilation_config.custom_ops.append("+rms_norm")
                else:
                    regime = (
                        "Dynamo partition"
                        if not is_fullgraph
                        else "pipeline parallelism"
                    )
                    logger.warning_once(
1089
                        "Sequence parallelism not supported with "
1090
1091
1092
1093
1094
                        "native rms_norm when using %s, "
                        "this will likely lead to an error.",
                        regime,
                    )

1095
        # final check of cudagraph mode after all possible updates
1096
        if current_platform.is_cuda_alike():
1097
1098
1099
1100
            if (
                self.compilation_config.cudagraph_mode.has_full_cudagraphs()
                and self.model_config is not None
                and not self.model_config.disable_cascade_attn
1101
                and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()  # noqa: E501
1102
            ):
1103
1104
1105
                logger.warning_once(
                    "No piecewise cudagraph for executing cascade attention."
                    " Will fall back to eager execution if a batch runs "
1106
                    "into cascade attentions."
1107
1108
1109
                )

            if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
1110
1111
                assert self.compilation_config.mode == CompilationMode.VLLM_COMPILE, (
                    "Compilation mode should be CompilationMode.VLLM_COMPILE "
1112
                    "when cudagraph_mode piecewise cudagraphs is used, "
1113
                    f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
1114
                )
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
        from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant

        if (
            self.model_config
            and vllm_is_batch_invariant()
            and not self.model_config.disable_cascade_attn
        ):
            self.model_config.disable_cascade_attn = True
            logger.warning_once(
                "Disabling cascade attention when VLLM_BATCH_INVARIANT is enabled.",
                scope="local",
            )
1127

1128
        if self.parallel_config.use_ubatching:
1129
            a2a_backend = self.parallel_config.all2all_backend
1130
1131
1132
1133
            assert a2a_backend in [
                "deepep_low_latency",
                "deepep_high_throughput",
            ], (
1134
1135
                "Microbatching currently only supports the deepep_low_latency and "
                f"deepep_high_throughput all2all backend. {a2a_backend} is not "
1136
1137
1138
                "supported. To fix use --all2all-backend=deepep_low_latency or "
                "--all2all-backend=deepep_high_throughput and install the DeepEP"
                " kernels."
1139
            )
1140
1141
1142

            if not self.model_config.disable_cascade_attn:
                self.model_config.disable_cascade_attn = True
1143
                logger.warning_once("Disabling cascade attention when DBO is enabled.")
1144
1145
1146
1147

        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
        # Hybrid KV cache manager (HMA) runtime rules:
        # - Explicit enable (--no-disable-kv-cache-manager): error if runtime
        #   disables it
        # - No preference: auto-disable for unsupported features (e.g. kv connector)
        # - Explicit disable (--disable-kv-cache-manager): always respect it
        need_disable_hybrid_kv_cache_manager = False
        # logger should only print warning message for hybrid models. As we
        # can't know whether the model is hybrid or not now, so we don't log
        # warning message here and will log it later.
        if not current_platform.support_hybrid_kv_cache():
            # Hybrid KV cache manager is not supported on non-GPU platforms.
            need_disable_hybrid_kv_cache_manager = True
        if self.kv_events_config is not None:
            # Hybrid KV cache manager is not compatible with KV events.
            need_disable_hybrid_kv_cache_manager = True
        if (
            self.model_config is not None
            and self.model_config.attention_chunk_size is not None
        ):
            if (
                self.speculative_config is not None
                and self.speculative_config.use_eagle()
            ):
                # Hybrid KV cache manager is not yet supported with chunked
                # local attention + eagle.
                need_disable_hybrid_kv_cache_manager = True
            elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
                logger.warning(
                    "There is a latency regression when using chunked local"
                    " attention with the hybrid KV cache manager. Disabling"
                    " it, by default. To enable it, set the environment "
                    "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
                )
                # Hybrid KV cache manager is not yet supported with chunked
                # local attention.
                need_disable_hybrid_kv_cache_manager = True

        if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
            # Default to disable HMA, but only if the user didn't express a preference.
1187
            if self.kv_transfer_config is not None:
1188
1189
                # NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
                need_disable_hybrid_kv_cache_manager = True
1190
1191
1192
1193
1194
1195
1196
                logger.warning(
                    "Turning off hybrid kv cache manager because "
                    "`--kv-transfer-config` is set. This will reduce the "
                    "performance of vLLM on LLMs with sliding window attention "
                    "or Mamba attention. If you are a developer of kv connector"
                    ", please consider supporting hybrid kv cache manager for "
                    "your connector by making sure your connector is a subclass"
1197
1198
                    " of `SupportsHMA` defined in kv_connector/v1/base.py and"
                    " use --no-disable-hybrid-kv-cache-manager to start vLLM."
1199
                )
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
            self.scheduler_config.disable_hybrid_kv_cache_manager = (
                need_disable_hybrid_kv_cache_manager
            )
        elif (
            self.scheduler_config.disable_hybrid_kv_cache_manager is False
            and need_disable_hybrid_kv_cache_manager
        ):
            raise ValueError(
                "Hybrid KV cache manager was explicitly enabled but is not "
                "supported in this configuration. Consider omitting the "
                "--no-disable-hybrid-kv-cache-manager flag to let vLLM decide"
                " automatically."
            )

        if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
            # Default to enable HMA if not explicitly disabled by user or logic above.
            self.scheduler_config.disable_hybrid_kv_cache_manager = False
1217
1218

        if self.compilation_config.debug_dump_path:
1219
            self.compilation_config.debug_dump_path = (
1220
                self.compilation_config.debug_dump_path.absolute().expanduser()
1221
            )
1222
1223
1224
1225
1226
        if envs.VLLM_DEBUG_DUMP_PATH is not None:
            env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser()
            if self.compilation_config.debug_dump_path:
                logger.warning(
                    "Config-specified debug dump path is overridden"
1227
1228
1229
                    " by VLLM_DEBUG_DUMP_PATH to %s",
                    env_path,
                )
1230
1231
            self.compilation_config.debug_dump_path = env_path

1232
        def has_blocked_weights():  # type: ignore[no-redef]
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
            if self.quant_config is not None:
                if hasattr(self.quant_config, "weight_block_size"):
                    return self.quant_config.weight_block_size is not None
                elif hasattr(self.quant_config, "has_blocked_weights"):
                    return self.quant_config.has_blocked_weights()
            return False

        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
        # On H100 the CUDA kernel is faster than
        # native implementation
        # https://github.com/vllm-project/vllm/issues/25094
        if has_blocked_weights():
            custom_ops = self.compilation_config.custom_ops
1246
            if "-quant_fp8" not in custom_ops:
1247
1248
                custom_ops.append("+quant_fp8")

1249
1250
1251
        # Handle the KV connector configs
        self._post_init_kv_transfer_config()

1252
1253
1254
        # Log the custom passes that are enabled
        self.compilation_config.pass_config.log_enabled_passes()

1255
    def update_sizes_for_sequence_parallelism(
        self, possible_sizes: list[int]
    ) -> list[int]:
        """Keep only the sizes that are a multiple of the tensor-parallel size.

        Applied when sequence parallelism is enabled. Sizes that are not a
        multiple of `parallel_config.tensor_parallel_size` are dropped and
        reported together in a single warning.

        Args:
            possible_sizes: Candidate batch sizes to filter. The list itself
                is not modified.

        Returns:
            A new list with the entries of `possible_sizes` that are divisible
            by the tensor-parallel size, in their original order.
        """
        tp_size = self.parallel_config.tensor_parallel_size

        # Partition in a single pass so the kept and removed lists are always
        # derived from the same divisibility check.
        kept_sizes: list[int] = []
        removed_sizes: list[int] = []
        for size in possible_sizes:
            if size % tp_size == 0:
                kept_sizes.append(size)
            else:
                removed_sizes.append(size)

        if removed_sizes:
            logger.warning(
                "Batch sizes %s are removed because they are not "
                "multiple of tp_size %d when "
                "sequence parallelism is enabled",
                removed_sizes,
                tp_size,
            )

        return kept_sizes

1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
    def _set_max_num_scheduled_tokens(self):
        """
        Derive and validate `scheduler_config.max_num_scheduled_tokens`.

        Normally the scheduler may fill a batch with as many tokens as the
        worker is configured to handle. Some speculative decoding methods let
        the drafter insert extra slots into the batch while drafting, so the
        scheduled-token budget must sit below `max_num_batched_tokens` by an
        upper bound on the number of slots that can be added. Nothing is
        changed when speculative decoding is not configured.

        Raises:
            ValueError: If `max_num_batched_tokens` is smaller than
                `max_num_scheduled_tokens` plus the drafting slot bound.
        """
        spec_config = self.speculative_config
        if spec_config is None:
            return

        sched_config = self.scheduler_config
        # Upper bound on the slots drafting can add to a batch.
        drafting_slots = (
            spec_config.max_num_new_slots_for_drafting * sched_config.max_num_seqs
        )

        # NOTE: the next two local names are printed verbatim in the error
        # message through the f-string `=` specifier; do not rename them.
        max_num_batched_tokens = sched_config.max_num_batched_tokens
        if sched_config.max_num_scheduled_tokens is None:
            # No explicit budget was given: use what remains after reserving
            # room for the drafting slots.
            sched_config.max_num_scheduled_tokens = (
                max_num_batched_tokens - drafting_slots
            )
        max_num_scheduled_tokens = sched_config.max_num_scheduled_tokens

        if max_num_scheduled_tokens + drafting_slots <= max_num_batched_tokens:
            return

        raise ValueError(
            "VllmConfig received max_num_scheduled_tokens but it does not have"
            " enough slots to support the speculative decoding settings."
            f" It should be greater by at least {drafting_slots}, but"
            f" got {max_num_batched_tokens=} and {max_num_scheduled_tokens=}."
        )

1309
1310
1311
1312
1313
1314
1315
    def _set_cudagraph_sizes(self):
        """
        Resolve the CUDA graph capture sizes and store them on the
        compilation config.

        vLLM defines the default candidate list of batch sizes for CUDA graph
        capture as:

        ```python
        # decode_query_len is 1, plus num_speculative_tokens when speculative
        # decoding is configured with a non-zero number of tokens
        max_graph_size = min(max_num_seqs * decode_query_len * 2, 512)
        max_graph_size = min(max_graph_size, max_num_batched_tokens)
        # 1, 2, 4, then multiples of 8 below 256 and then multiples of 16
        # up to max_graph_size (every part is cut off at max_graph_size)
        cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
            range(256, max_graph_size + 1, 16))
        ```

        When `performance_mode` is `"interactivity"`, the leading `[1, 2, 4]`
        is replaced by every size from 1 up to `min(max_graph_size, 32)`.

        In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
        will be the final sizes to capture cudagraph (in ascending order).

        These sizes are used to capture and reuse CUDA graphs for
        performance-critical paths (e.g., decoding). Capturing enables
        significantly faster kernel dispatch by avoiding Python overhead. The
        list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
        most GPUs), which controls the total allowed number of tokens in a
        batch. Since each sequence may have a variable number of tokens, the
        maximum usable batch size will depend on actual sequence lengths.

        Example:
            With `max_num_batched_tokens = 8192`, and typical sequences
            averaging ~32 tokens, most practical batch sizes fall below 256.
            However, the system will still allow capture sizes up to 512 if
            shape and memory permit.

        Note:
            If users explicitly specify cudagraph capture sizes in the
            compilation config, those will override this default logic
            (after de-duplication, removal of sizes larger than
            `max_num_batched_tokens`, and sorting).
            At runtime:

            - If batch size <= one of the `cudagraph_capture_sizes`, the closest
            padded CUDA graph will be used.
            - If batch size > largest `cudagraph_capture_sizes`, cudagraph will
            not be used.

        Raises:
            AssertionError: If the resolved maximum capture size is below 1, or
                a user-specified `cudagraph_capture_sizes` list is empty.
            ValueError: If both `max_cudagraph_capture_size` and
                `cudagraph_capture_sizes` are user-specified and the former
                differs from the largest retained capture size.
        """
        compilation_config = self.compilation_config
        cudagraph_in_use = (
            self.model_config is not None
            and not self.model_config.enforce_eager
            and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        )
        if not cudagraph_in_use:
            # no cudagraph in use
            compilation_config.max_cudagraph_capture_size = 0
            compilation_config.cudagraph_capture_sizes = []
            # complete the remaining process.
            compilation_config.post_init_cudagraph_sizes()
            return

        # determine the initial max_cudagraph_capture_size
        max_cudagraph_capture_size = compilation_config.max_cudagraph_capture_size
        if max_cudagraph_capture_size is None:
            decode_query_len = 1
            if (
                self.speculative_config
                and self.speculative_config.num_speculative_tokens
            ):
                decode_query_len += self.speculative_config.num_speculative_tokens
            max_cudagraph_capture_size = min(
                self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
            )
        max_num_tokens = self.scheduler_config.max_num_batched_tokens
        max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)

        assert max_cudagraph_capture_size >= 1, (
            "Maximum cudagraph size should be greater than or equal to 1 "
            "when using cuda graph."
        )

        # determine the cudagraph_capture_sizes
        if compilation_config.cudagraph_capture_sizes is not None:
            assert len(compilation_config.cudagraph_capture_sizes) > 0, (
                "cudagraph_capture_sizes should contain at least one element "
                "when using cuda graph."
            )
            # de-duplicate the sizes provided by the config, drop the ones
            # above max_num_batched_tokens and sort in ascending order
            cudagraph_capture_sizes = sorted(
                size
                for size in set(compilation_config.cudagraph_capture_sizes)
                if size <= max_num_tokens
            )
        else:
            if self.performance_mode == "interactivity":
                # Fine-grained CUDA graphs at small batch sizes
                # for minimal padding overhead
                interactivity_max = min(max_cudagraph_capture_size, 32)
                cudagraph_capture_sizes = list(range(1, interactivity_max + 1))
            else:
                cudagraph_capture_sizes = [
                    i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
                ]
            if max_cudagraph_capture_size >= 8:
                # Step size 8 for small batch sizes, up to 256(not included)
                cudagraph_capture_sizes += list(
                    range(8, min(max_cudagraph_capture_size + 1, 256), 8)
                )
            if max_cudagraph_capture_size >= 256:
                # Step size 16 for larger batch sizes
                cudagraph_capture_sizes += list(
                    range(256, max_cudagraph_capture_size + 1, 16)
                )
            # de-duplicate and sort the sizes
            cudagraph_capture_sizes = sorted(set(cudagraph_capture_sizes))

        if (
            self.parallel_config.tensor_parallel_size > 1
            and compilation_config.pass_config.enable_sp
        ):
            cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
                cudagraph_capture_sizes
            )

        # user-specific compilation_config.max_cudagraph_capture_size get
        # truncated to valid_max_size when they are inconsistent.
        valid_max_size = cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
        user_max_size = compilation_config.max_cudagraph_capture_size
        if user_max_size is not None and user_max_size != valid_max_size:
            # raise error only when both two flags are user-specified
            # and they are inconsistent with each other
            if compilation_config.cudagraph_capture_sizes is not None:
                raise ValueError(
                    "customized max_cudagraph_capture_size"
                    f"(={user_max_size}) "
                    "should be consistent with the max value of "
                    f"cudagraph_capture_sizes(={valid_max_size})"
                )

            logger.warning(
                "Truncating max_cudagraph_capture_size to %d",
                valid_max_size,
            )
        # always set the final max_cudagraph_capture_size
        compilation_config.max_cudagraph_capture_size = valid_max_size

        if compilation_config.cudagraph_capture_sizes is not None and len(
            cudagraph_capture_sizes
        ) < len(compilation_config.cudagraph_capture_sizes):
            # If users have specified capture sizes, we only need to
            # compare the lens before and after modification since the modified
            # list is only the subset of the original list.
            logger.warning(
                (
                    "cudagraph_capture_sizes specified in compilation_config"
                    " %s is overridden by config %s"
                ),
                compilation_config.cudagraph_capture_sizes,
                cudagraph_capture_sizes,
            )
        # always write back the final sizes
        compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes

        # complete the remaining process.
        compilation_config.post_init_cudagraph_sizes()
1470

1471
1472
1473
1474
1475
    def _set_compile_ranges(self):
        """
        Set the compile ranges for the compilation config.

        Candidate endpoints are gathered from:

        - `scheduler_config.max_num_batched_tokens` (the upper bound), when it
          is not None;
        - the allreduce-rms fusion size limit converted to a token count, when
          `fuse_allreduce_rms` is enabled and the limit is below the upper
          bound;
        - `sp_min_token_num - 1`, when sequence parallelism is enabled and the
          threshold lies strictly between 1 and the upper bound
          (`sp_min_token_num` is computed and stored first if it is None);
        - `rope_kvcache_fusion_max_token_num`, when `fuse_rope_kvcache` is
          enabled and the value is below the upper bound;
        - user-provided `compile_ranges_endpoints` lying strictly between 1
          and the upper bound.

        The endpoints are written back to
        `compilation_config.compile_ranges_endpoints` in ascending order with
        duplicates removed.

        Raises:
            AssertionError: If a user-provided endpoint is not an `int` or is
                not positive.
        """
        compilation_config = self.compilation_config
        pass_config = compilation_config.pass_config
        # A set, so that a threshold contributed by several sources (or
        # repeated by the user) produces a single endpoint.
        endpoints: set[int] = set()

        # The upper bound of the compile ranges is the max_num_batched_tokens.
        compile_range_end = self.scheduler_config.max_num_batched_tokens
        if compile_range_end is not None:
            endpoints.add(compile_range_end)

        # Add the compile ranges for flashinfer
        if pass_config.fuse_allreduce_rms:
            tp_size = self.parallel_config.tensor_parallel_size
            max_size = pass_config.flashinfer_max_size(tp_size)
            if max_size is not None:
                max_token_num = max_size // (
                    self.model_config.get_hidden_size()
                    * self.model_config.dtype.itemsize  # type: ignore[union-attr]
                )
                if compile_range_end is not None and max_token_num < compile_range_end:
                    endpoints.add(max_token_num)
                else:
                    logger.debug(
                        "Max num batched tokens below allreduce-rms fusion threshold, "
                        "allreduce-rms fusion will be enabled for all num_tokens."
                    )

        # Add the compile ranges for sequence parallelism
        if pass_config.enable_sp:
            # Calculate min_token_num if not explicitly provided
            # User override works regardless of hidden_size
            if pass_config.sp_min_token_num is None:
                from vllm.compilation.passes.fusion.sequence_parallelism import (
                    get_sequence_parallelism_threshold,
                )

                tp_size = self.parallel_config.tensor_parallel_size
                hidden_size = self.model_config.get_hidden_size()
                element_size = self.model_config.dtype.itemsize  # type: ignore[union-attr]
                pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
                    hidden_size, tp_size, element_size
                )

            min_token_num = pass_config.sp_min_token_num
            if (
                min_token_num is not None
                and compile_range_end is not None
                and 1 < min_token_num < compile_range_end
            ):
                # Add endpoint at min_token_num - 1 to ensure SP applies
                # starting from min_token_num
                # This creates ranges: [1, min-1] (no SP), [min, max] (SP applies)
                endpoints.add(min_token_num - 1)

        if pass_config.fuse_rope_kvcache:
            max_token_num = pass_config.rope_kvcache_fusion_max_token_num
            if max_token_num is not None:
                if compile_range_end is not None and max_token_num < compile_range_end:
                    endpoints.add(max_token_num)
                else:
                    # %s rather than %d: compile_range_end is handled as
                    # possibly None throughout this method, and %d cannot
                    # format None.
                    logger.debug(
                        "Max num batched tokens below rope+kvcache fusion threshold, "
                        "rope+kvcache fusion enabled for num_tokens <= %s.",
                        compile_range_end,
                    )

        if compilation_config.compile_ranges_endpoints is not None:
            for x in compilation_config.compile_ranges_endpoints:
                assert isinstance(x, int)
                assert x > 0, f"Invalid compile range endpoint: {x}"
                if compile_range_end is not None and 1 < x < compile_range_end:
                    endpoints.add(x)

        compilation_config.compile_ranges_endpoints = sorted(endpoints)

1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
    def try_verify_and_update_config(self):
        """Run the model-specific config verification/update hooks once.

        Returns immediately when there is no model config, when the model
        config is already flagged `config_updated`, or (after setting that
        flag) when the model architecture is None.

        Otherwise the following hooks are invoked with this config, in order:

        - the `MODELS_CONFIG_MAP` entry for the architecture, if present;
        - `HybridAttentionMambaModelConfig`, when the model is hybrid;
        - `SequenceClassificationConfig`, when `convert_type` is "classify".

        Finally, when the model config has a `model_weights` attribute that is
        a Run:ai object URI, a `load_format` of "auto" is switched to
        "runai_streamer".

        Raises:
            ValueError: If the weights are a Run:ai object URI and
                `load_format` is not "auto", "runai_streamer" or
                "runai_streamer_sharded".
        """
        if self.model_config is None:
            return

        # Avoid running try_verify_and_update_config multiple times
        if getattr(self.model_config, "config_updated", False):
            return
        self.model_config.config_updated = True

        architecture = self.model_config.architecture
        if architecture is None:
            return

        from vllm.model_executor.models.config import (
            MODELS_CONFIG_MAP,
            HybridAttentionMambaModelConfig,
        )

        model_specific_config = MODELS_CONFIG_MAP.get(architecture)
        if model_specific_config is not None:
            model_specific_config.verify_and_update_config(self)

        if self.model_config.is_hybrid:
            HybridAttentionMambaModelConfig.verify_and_update_config(self)

        if self.model_config.convert_type == "classify":
            # Maybe convert ForCausalLM into ForSequenceClassification model.
            from vllm.model_executor.models.adapters import SequenceClassificationConfig

            SequenceClassificationConfig.verify_and_update_config(self)

        weights_on_object_storage = hasattr(
            self.model_config, "model_weights"
        ) and is_runai_obj_uri(self.model_config.model_weights)
        if not weights_on_object_storage:
            return

        load_format = self.load_config.load_format
        if load_format == "auto":
            logger.info(
                "Detected Run:ai model config. "
                "Overriding `load_format` to 'runai_streamer'"
            )
            self.load_config.load_format = "runai_streamer"
        elif load_format not in ("runai_streamer", "runai_streamer_sharded"):
            raise ValueError(
                "To load a model from object storage (S3/GCS/Azure), "
                "'load_format' must be 'runai_streamer' or "
                "'runai_streamer_sharded', "
                f"but got '{load_format}'. "
                f"Model: {self.model_config.model}"
            )
1605

1606
    def compile_debug_dump_path(self) -> Path | None:
1607
        """Returns a rank-aware path for dumping
1608
1609
1610
1611
1612
        torch.compile debug information.
        """
        if self.compilation_config.debug_dump_path is None:
            return None
        tp_rank = self.parallel_config.rank
1613
1614
        dp_rank = self.parallel_config.data_parallel_index
        append_path = f"rank_{tp_rank}_dp_{dp_rank}"
1615
1616
1617
1618
1619
1620
1621
        path = self.compilation_config.debug_dump_path / append_path
        return path

    def __str__(self):
        """Return a one-line, comma-separated summary of the key settings.

        Selected fields of the model, load, parallel, cache, device and
        scheduler configs are rendered as `name=value` pairs; the speculative,
        structured-outputs, observability, pooler and compilation configs are
        embedded through their `repr`.
        """
        model = self.model_config
        load = self.load_config
        parallel = self.parallel_config
        cache = self.cache_config
        parts = (
            f"model={model.model!r}",
            f"speculative_config={self.speculative_config!r}",
            f"tokenizer={model.tokenizer!r}",
            f"skip_tokenizer_init={model.skip_tokenizer_init}",
            f"tokenizer_mode={model.tokenizer_mode}",
            f"revision={model.revision}",
            f"tokenizer_revision={model.tokenizer_revision}",
            f"trust_remote_code={model.trust_remote_code}",
            f"dtype={model.dtype}",
            # Printed as max_seq_len, read from max_model_len.
            f"max_seq_len={model.max_model_len}",
            f"download_dir={load.download_dir!r}",
            f"load_format={load.load_format}",
            f"tensor_parallel_size={parallel.tensor_parallel_size}",
            f"pipeline_parallel_size={parallel.pipeline_parallel_size}",
            f"data_parallel_size={parallel.data_parallel_size}",
            f"decode_context_parallel_size={parallel.decode_context_parallel_size}",
            f"dcp_comm_backend={parallel.dcp_comm_backend}",
            f"disable_custom_all_reduce={parallel.disable_custom_all_reduce}",
            f"quantization={model.quantization}",
            f"enforce_eager={model.enforce_eager}",
            f"enable_return_routed_experts={model.enable_return_routed_experts}",
            f"kv_cache_dtype={cache.cache_dtype}",
            f"device_config={self.device_config.device}",
            f"structured_outputs_config={self.structured_outputs_config!r}",
            f"observability_config={self.observability_config!r}",
            f"seed={model.seed}",
            f"served_model_name={model.served_model_name}",
            f"enable_prefix_caching={cache.enable_prefix_caching}",
            f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}",
            f"pooler_config={model.pooler_config!r}",
            f"compilation_config={self.compilation_config!r}",
        )
        return ", ".join(parts)
1652

1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
    def validate_block_size(self) -> None:
        """Validate block_size against DCP and mamba constraints.

        Called after Platform.update_block_size_for_backend() has
        finalised block_size.

        With decode context parallelism (`decode_context_parallel_size > 1`),
        a `dcp_kv_cache_interleave_size` greater than 1 overrides a differing
        `cp_kv_cache_interleave_size` (with a one-time warning); block_size
        must then be at least, and divisible by, `cp_kv_cache_interleave_size`.

        In Mamba cache "align" mode, block_size must not exceed
        `max_num_batched_tokens`, a positive `long_prefill_token_threshold`
        must be at least block_size, and chunked multimodal input must not be
        disabled.

        The checks raise explicitly instead of using `assert` statements so
        that they are still enforced when Python runs with `-O`.

        Raises:
            AssertionError: If any of the constraints above is violated.
        """
        block_size = self.cache_config.block_size
        parallel_config = self.parallel_config
        scheduler_config = self.scheduler_config

        # DCP interleave-size compatibility
        if parallel_config.decode_context_parallel_size > 1:
            dcp_interleave_size = parallel_config.dcp_kv_cache_interleave_size
            if (
                dcp_interleave_size > 1
                and parallel_config.cp_kv_cache_interleave_size != dcp_interleave_size
            ):
                parallel_config.cp_kv_cache_interleave_size = dcp_interleave_size
                logger.warning_once(
                    "cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
                    "_interleave_size. And dcp-kv-cache-interleave-size will be "
                    "deprecated when PCP is fully supported."
                )
            # Read after the possible override above.
            cp_interleave_size = parallel_config.cp_kv_cache_interleave_size
            if (
                cp_interleave_size > block_size
                or block_size % cp_interleave_size != 0
            ):
                raise AssertionError(
                    f"Block_size({block_size}) should be greater "
                    "than or equal to and divisible by cp_kv_cache_interleave_size "
                    f"({cp_interleave_size})."
                )

        # Mamba cache align-mode constraints
        if self.cache_config.mamba_cache_mode == "align":
            max_num_batched_tokens = scheduler_config.max_num_batched_tokens
            if block_size > max_num_batched_tokens:
                raise AssertionError(
                    "In Mamba cache align mode, block_size "
                    f"({block_size}) must be <= "
                    "max_num_batched_tokens "
                    f"({max_num_batched_tokens})."
                )
            long_prefill_threshold = scheduler_config.long_prefill_token_threshold
            # The threshold is only checked when it is positive.
            if 0 < long_prefill_threshold < block_size:
                raise AssertionError(
                    "In Mamba cache align mode, long_prefill_token_threshold "
                    f"({long_prefill_threshold}) must be >= "
                    f"block_size ({block_size})."
                )
            if scheduler_config.disable_chunked_mm_input:
                raise AssertionError(
                    "Chunked MM input is required because we need the flexibility "
                    "to schedule a multiple of block_size tokens even if they are "
                    "in the middle of a mm input"
                )

1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
    @model_validator(mode="after")
    def validate_mamba_block_size(self) -> "VllmConfig":
        """Reject a custom mamba block size when prefix caching is disabled.

        A mamba block size counts as custom when it is not None and differs
        from `model_config.max_model_len`. Nothing is checked when there is no
        model config.

        Returns:
            This config, unchanged.

        Raises:
            ValueError: If a custom mamba block size is set while
                `cache_config.enable_prefix_caching` is falsy.
        """
        if self.model_config is None:
            return self

        requested_size = self.cache_config.mamba_block_size
        if requested_size is None or requested_size == self.model_config.max_model_len:
            # Not customised: nothing to validate.
            return self

        if not self.cache_config.enable_prefix_caching:
            raise ValueError(
                "--mamba-block-size can only be set with --enable-prefix-caching"
            )
        return self

1714

1715
1716
# Config installed by the active set_current_vllm_config() context, which
# restores the previous value on exit. None until a context sets it.
_current_vllm_config: VllmConfig | None = None
# `prefix` argument of the active set_current_vllm_config() context, saved and
# restored together with the config above.
_current_prefix: str | None = None
1717
1718
1719


@contextmanager
def set_current_vllm_config(
    vllm_config: VllmConfig, check_compile: bool = False, prefix: str | None = None
):
    """
    Temporarily set the current vLLM config.
    Used during model initialization.
    We save the current vLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the vLLM config to determine how to dispatch.

    Args:
        vllm_config: The config to expose as the current config inside the
            `with` block.
        check_compile: When True and the `with` block finishes without
            raising, run `compilation_config.custom_op_log_check()` and warn
            if the compilation mode is `VLLM_COMPILE` but
            `compilation_counter.num_models_seen` did not change inside the
            block.
        prefix: Value stored as the current prefix inside the `with` block.

    The previous config and prefix are restored on exit, whether or not the
    block raised, and the cached compilation config is cleared on both entry
    and exit.
    """
    global _current_vllm_config, _current_prefix
    from vllm.compilation.counter import compilation_counter

    old_vllm_config = _current_vllm_config
    old_prefix = _current_prefix
    num_models_seen = compilation_counter.num_models_seen
    try:
        # Clear the compilation config cache when context changes.
        # This is needed since the old config may have been accessed
        # and cached before the new config is set.
        get_cached_compilation_config.cache_clear()

        _current_vllm_config = vllm_config
        _current_prefix = prefix
        yield

        # Only reached when the managed block completed without raising: an
        # exception from the block is re-raised at the `yield` above and goes
        # straight to the `finally` clause.
        if check_compile:
            vllm_config.compilation_config.custom_op_log_check()

            if (
                vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
                and compilation_counter.num_models_seen == num_models_seen
            ):
                # If the model supports compilation,
                # compilation_counter.num_models_seen should be increased
                # by at least 1.
                # If it is not increased, it means the model does not support
                # compilation (does not have @support_torch_compile decorator).
                logger.warning(
                    "`torch.compile` is turned on, but the model %s"
                    " does not support it. Please open an issue on GitHub"
                    " if you want it to be supported.",
                    vllm_config.model_config.model,
                )
    finally:
        _current_vllm_config = old_vllm_config
        _current_prefix = old_prefix
        # Clear the compilation config cache when context changes
        get_cached_compilation_config.cache_clear()


@lru_cache(maxsize=1)
def get_cached_compilation_config():
    """Return the current vLLM config's compilation config, memoized.

    The single cached entry avoids repeated get_current_vllm_config() calls;
    set_current_vllm_config() clears it whenever the current config changes.
    """
    current_config = get_current_vllm_config()
    return current_config.compilation_config


def get_current_vllm_config() -> VllmConfig:
    """Return the vLLM config set by the active set_current_vllm_config().

    Raises:
        AssertionError: If no vLLM config is currently set.
    """
    current_config = _current_vllm_config
    if current_config is not None:
        return current_config
    raise AssertionError(
        "Current vLLM config is not set. This typically means "
        "get_current_vllm_config() was called outside of a "
        "set_current_vllm_config() context, or a CustomOp was instantiated "
        "at module import time or model forward time when config is not set. "
        "For tests that directly test custom ops/modules, use the "
        "'default_vllm_config' pytest fixture from tests/conftest.py."
    )


def get_current_vllm_config_or_none() -> VllmConfig | None:
    """Return the current vLLM config, or None when none is set.

    Unlike get_current_vllm_config(), this never raises.
    """
    return _current_vllm_config


# Layer type requested from, and returned by, get_layers_from_vllm_config().
T = TypeVar("T")


def get_layers_from_vllm_config(
    vllm_config: VllmConfig,
    layer_type: type[T],
    layer_names: list[str] | None = None,
) -> dict[str, T]:
    """
    Get layers from the vLLM config.

    Args:
        vllm_config: The vLLM config.
        layer_type: The type of the layer to get.
        layer_names: The names of the layers to get. If None, return all layers.

    Returns:
        A dict mapping layer name to layer, taken from
        `compilation_config.static_forward_context`. Requested names that are
        missing from that context, and layers that are not instances of
        `layer_type`, are left out.
    """
    forward_context = vllm_config.compilation_config.static_forward_context

    if layer_names is None:
        # Every registered layer, in the context's own order.
        candidates = forward_context.items()
    else:
        # Only the requested names that exist, in the requested order.
        candidates = (
            (layer_name, forward_context[layer_name])
            for layer_name in layer_names
            if layer_name in forward_context
        )

    return {
        layer_name: layer
        for layer_name, layer in candidates
        if isinstance(layer, layer_type)
    }