vllm.py 68.3 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
5
import getpass
6
7
import json
import os
8
9
import tempfile
import threading
10
import time
11
from contextlib import contextmanager
12
from dataclasses import is_dataclass
13
from datetime import datetime
14
from enum import IntEnum
15
16
from functools import lru_cache
from pathlib import Path
17
from typing import TYPE_CHECKING, Any, TypeVar, get_args
18
19

import torch
20
from pydantic import ConfigDict, Field, model_validator
21
22

import vllm.envs as envs
23
from vllm.logger import enable_trace_function_call, init_logger
24
25
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
26
from vllm.utils.hashing import safe_hash
27

28
from .attention import AttentionConfig
29
from .cache import CacheConfig
30
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
31
from .device import DeviceConfig
32
from .ec_transfer import ECTransferConfig
33
from .kernel import KernelConfig
34
35
36
37
38
39
40
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .parallel import ParallelConfig
41
from .profiler import ProfilerConfig
42
from .scheduler import SchedulerConfig
43
from .speculative import EagleModelTypes, SpeculativeConfig
44
from .structured_outputs import StructuredOutputsConfig
45
from .utils import SupportsHash, config, replace
46
from .weight_transfer import WeightTransferConfig
47
48
49
50

if TYPE_CHECKING:
    from transformers import PretrainedConfig

51
    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
52
    from vllm.v1.kv_cache_interface import KVCacheConfig
53
54
55
56
57
else:
    PretrainedConfig = Any

    QuantizationConfig = Any

58
59
    KVCacheConfig = Any

60
61
62
logger = init_logger(__name__)


63
64
65
66
67
68
69
class OptimizationLevel(IntEnum):
    """Optimization level enum.

    Levels trade startup time for performance: -O0 starts fastest, -O3 is
    intended to run fastest.
    """

    O0 = 0
    """O0: No optimization. No compilation, no cudagraphs, no other
    optimization, just starting up immediately."""
    O1 = 1
    """O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
    cudagraphs."""
    O2 = 2
    """O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""
    O3 = 3
    """O3: Currently the same as -O2."""


# Model properties that some optimization-level defaults depend on
# (fuse_attn_quant, enable_sp, fuse_gemm_comms). For now both are hard-coded
# to False in all cases. The intended model-dependent versions, applied only
# when model_config is not None, would be:
#     IS_QUANTIZED = lambda c: c.model_config.is_quantized()
#     IS_DENSE = lambda c: not c.model_config.is_model_moe()
# See https://github.com/vllm-project/vllm/issues/25689.
IS_QUANTIZED: bool = False
IS_DENSE: bool = False


88
89
90
91
def enable_norm_fusion(cfg: "VllmConfig") -> bool:
    """Decide the default for the ``fuse_norm_quant`` pass.

    The pass is enabled when either the ``rms_norm`` or the ``quant_fp8``
    custom op is active; otherwise Inductor handles the fusion itself.

    Args:
        cfg: The vLLM configuration; only ``cfg.compilation_config`` is read.

    Returns:
        True if at least one of the two custom ops is enabled.
    """
    compilation_config = cfg.compilation_config
    return any(
        compilation_config.is_custom_op_enabled(op) for op in ("rms_norm", "quant_fp8")
    )


97
def enable_act_fusion(cfg: "VllmConfig") -> bool:
    """Decide the default for the ``fuse_act_quant`` pass.

    Enabled when either the ``silu_and_mul`` or the ``quant_fp8`` custom op is
    active; otherwise Inductor handles the fusion itself. Also enabled for
    NVFP4-quantized models, because FP4 quant is always a custom op and
    Inductor therefore cannot fuse it.

    Args:
        cfg: The vLLM configuration; ``compilation_config`` and
            ``model_config`` (which may be None) are read.

    Returns:
        True if the fusion pass should be enabled by default.
    """
    compilation_config = cfg.compilation_config
    if any(
        compilation_config.is_custom_op_enabled(op)
        for op in ("silu_and_mul", "quant_fp8")
    ):
        return True
    # FP4 quant is always a custom op, so Inductor cannot fuse it on its own.
    model_config = cfg.model_config
    return model_config is not None and bool(model_config.is_nvfp4_quantized())
108
109


110
def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
    """Decide the default for the ``fuse_allreduce_rms`` pass.

    Enabled only when TP > 1, DP == 1, the platform is CUDA with device
    capability 9.0 or 10.0 (Hopper/Blackwell), and flashinfer is installed.

    Args:
        cfg: The vLLM configuration; only ``cfg.parallel_config`` is read.

    Returns:
        True if all conditions hold.
    """
    parallel_config = cfg.parallel_config
    # Check the cheap config conditions first so the platform/flashinfer
    # imports and probes are skipped in the common TP == 1 case.
    if parallel_config.tensor_parallel_size <= 1:
        return False
    # tp-dp combination broken:
    # https://github.com/vllm-project/vllm/issues/34458
    if parallel_config.data_parallel_size != 1:
        return False

    from vllm.platforms import current_platform
    from vllm.utils.flashinfer import has_flashinfer

    return bool(
        current_platform.is_cuda()
        and has_flashinfer()
        and (
            current_platform.is_device_capability(100)
            or current_platform.is_device_capability(90)
        )
    )


129
130
131
132
133
134
135
136
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
    """Decide the default for the ``fuse_act_padding`` pass.

    Enabled when using AITER RMSNorm and AITER Triton GEMMs (all of
    ``VLLM_ROCM_USE_AITER``, ``VLLM_ROCM_USE_AITER_RMSNORM`` and
    ``VLLM_ROCM_USE_AITER_TRITON_GEMM`` are set) and the model's hidden size is
    2880, i.e. gpt-oss; otherwise Inductor handles the fusion.

    Args:
        cfg: The vLLM configuration; only ``cfg.model_config`` (which may be
            None) is read.

    Returns:
        True if all conditions hold.
    """
    model_config = cfg.model_config
    return bool(
        envs.VLLM_ROCM_USE_AITER
        and envs.VLLM_ROCM_USE_AITER_RMSNORM
        and envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
        and model_config is not None
        and model_config.get_hidden_size() == 2880
    )


142
143
144
# Defaults for -O0: every optional compilation pass, cudagraphs and FlashInfer
# autotuning are off.
OPTIMIZATION_LEVEL_00 = {
    "compilation_config": {
        "pass_config": {
            "fuse_norm_quant": False,
            "fuse_act_quant": False,
            "fuse_allreduce_rms": False,
            "fuse_attn_quant": False,
            "enable_sp": False,
            "fuse_gemm_comms": False,
            "fuse_act_padding": False,
        },
        "cudagraph_mode": CUDAGraphMode.NONE,
        "use_inductor_graph_partition": False,
    },
    "kernel_config": {
        "enable_flashinfer_autotune": False,
    },
}
# Defaults for -O1: cheap, config-dependent fusions (callables are evaluated
# against the VllmConfig), piecewise cudagraphs, FlashInfer autotuning on.
OPTIMIZATION_LEVEL_01 = {
    "compilation_config": {
        "pass_config": {
            "fuse_norm_quant": enable_norm_fusion,
            "fuse_act_quant": enable_act_fusion,
            "fuse_allreduce_rms": False,
            "fuse_attn_quant": False,
            "enable_sp": False,
            "fuse_gemm_comms": False,
            "fuse_act_padding": enable_norm_pad_fusion,
        },
        "cudagraph_mode": CUDAGraphMode.PIECEWISE,
        "use_inductor_graph_partition": False,
    },
    "kernel_config": {
        "enable_flashinfer_autotune": True,
    },
}
# Defaults for -O2 (the default level): -O1 plus allreduce+RMSNorm fusion, the
# IS_QUANTIZED / IS_DENSE dependent passes, and full + piecewise cudagraphs.
OPTIMIZATION_LEVEL_02 = {
    "compilation_config": {
        "pass_config": {
            "fuse_norm_quant": enable_norm_fusion,
            "fuse_act_quant": enable_act_fusion,
            "fuse_allreduce_rms": enable_allreduce_rms_fusion,
            "fuse_attn_quant": IS_QUANTIZED,
            "enable_sp": IS_DENSE,
            "fuse_gemm_comms": IS_DENSE,
            "fuse_act_padding": enable_norm_pad_fusion,
        },
        "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
        "use_inductor_graph_partition": False,
    },
    "kernel_config": {
        "enable_flashinfer_autotune": True,
    },
}
# Defaults for -O3: currently identical to -O2; kept as a separate table so the
# two levels can diverge independently.
OPTIMIZATION_LEVEL_03 = {
    "compilation_config": {
        "pass_config": {
            "fuse_norm_quant": enable_norm_fusion,
            "fuse_act_quant": enable_act_fusion,
            "fuse_allreduce_rms": enable_allreduce_rms_fusion,
            "fuse_attn_quant": IS_QUANTIZED,
            "enable_sp": IS_DENSE,
            "fuse_gemm_comms": IS_DENSE,
            "fuse_act_padding": enable_norm_pad_fusion,
        },
        "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
        "use_inductor_graph_partition": False,
    },
    "kernel_config": {
        "enable_flashinfer_autotune": True,
    },
}

# Pair each OptimizationLevel member, in definition order (O0, O1, O2, O3),
# with its table of default-config values. strict=True guarantees every level
# has exactly one table.
OPTIMIZATION_LEVEL_TO_CONFIG = dict(
    zip(
        OptimizationLevel,
        (
            OPTIMIZATION_LEVEL_00,
            OPTIMIZATION_LEVEL_01,
            OPTIMIZATION_LEVEL_02,
            OPTIMIZATION_LEVEL_03,
        ),
        strict=True,
    )
)


223
@config(config=ConfigDict(arbitrary_types_allowed=True))
224
225
226
227
228
229
230
class VllmConfig:
    """Dataclass which contains all vllm-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    # TODO: use default_factory once default constructing ModelConfig doesn't
    # try to download a model
231
    model_config: ModelConfig = Field(default=None)
232
    """Model configuration."""
233
    cache_config: CacheConfig = Field(default_factory=CacheConfig)
234
    """Cache configuration."""
235
    parallel_config: ParallelConfig = Field(default_factory=ParallelConfig)
236
    """Parallel configuration."""
237
238
239
    scheduler_config: SchedulerConfig = Field(
        default_factory=SchedulerConfig.default_factory,
    )
240
    """Scheduler configuration."""
241
    device_config: DeviceConfig = Field(default_factory=DeviceConfig)
242
    """Device configuration."""
243
    load_config: LoadConfig = Field(default_factory=LoadConfig)
244
    """Load configuration."""
245
246
    attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
    """Attention configuration."""
247
248
    kernel_config: KernelConfig = Field(default_factory=KernelConfig)
    """Kernel configuration."""
249
    lora_config: LoRAConfig | None = None
250
    """LoRA configuration."""
251
    speculative_config: SpeculativeConfig | None = None
252
    """Speculative decoding configuration."""
253
    structured_outputs_config: StructuredOutputsConfig = Field(
254
255
        default_factory=StructuredOutputsConfig
    )
256
    """Structured outputs configuration."""
257
258
259
    observability_config: ObservabilityConfig = Field(
        default_factory=ObservabilityConfig
    )
260
    """Observability configuration."""
261
    quant_config: QuantizationConfig | None = None
262
    """Quantization configuration."""
263
    compilation_config: CompilationConfig = Field(default_factory=CompilationConfig)
264
265
    """`torch.compile` and cudagraph capture configuration for the model.

266
267
    As a shorthand, one can append compilation arguments via
    -cc.parameter=argument such as `-cc.mode=3` (same as `-cc='{"mode":3}'`).
268
269

    You can specify the full compilation config like so:
270
    `{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
271
    """
272
273
    profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig)
    """Profiling configuration."""
274
    kv_transfer_config: KVTransferConfig | None = None
275
    """The configurations for distributed KV cache transfer."""
276
    kv_events_config: KVEventsConfig | None = None
277
    """The configurations for event publishing."""
278
279
    ec_transfer_config: ECTransferConfig | None = None
    """The configurations for distributed EC cache transfer."""
280
281
282
    # some opaque config, only used to provide additional information
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
283
    additional_config: dict | SupportsHash = Field(default_factory=dict)
284
285
286
287
288
    """Additional config for specified platform. Different platforms may
    support different configs. Make sure the configs are valid for the platform
    you are using. Contents must be hashable."""
    instance_id: str = ""
    """The ID of the vLLM instance."""
289
290
291
    optimization_level: OptimizationLevel = OptimizationLevel.O2
    """The optimization level. These levels trade startup time cost for
    performance, with -O0 having the best startup time and -O3 having the best
292
    performance. -O2 is used by default. See OptimizationLevel for full
293
    description."""
294

295
296
297
    weight_transfer_config: WeightTransferConfig | None = None
    """The configurations for weight transfer during RL training."""

298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Each sub-config contributes one factor, in a fixed order. An unset
        (falsy) sub-config contributes the placeholder string ``"None"``, so
        factor positions never shift. ``quant_config`` is deliberately not
        hashed; it should be captured by ``model_config.quantization``.
        ``kernel_config``, ``kv_events_config``, ``weight_transfer_config``,
        ``instance_id`` and ``optimization_level`` are not hashed directly.

        Returns:
            The first 10 hex characters of the digest over all factors.
        """
        from vllm import __version__

        def sub_config_factor(sub_config: Any) -> str:
            # Placeholder for unset configs keeps every factor at a stable
            # position in the list.
            return sub_config.compute_hash() if sub_config else "None"

        vllm_factors: list[Any] = [__version__]

        if self.model_config:
            vllm_factors.append(self.model_config.compute_hash())
            # The multimodal config is only hashed when the MM encoder is
            # compiled as well.
            if (
                self.compilation_config
                and getattr(self.compilation_config, "compile_mm_encoder", False)
                and self.model_config.multimodal_config
            ):
                vllm_factors.append(self.model_config.multimodal_config.compute_hash())
        else:
            vllm_factors.append("None")

        # Order matters: it is part of the hashed representation.
        vllm_factors.extend(
            sub_config_factor(sub_config)
            for sub_config in (
                self.cache_config,
                self.parallel_config,
                self.scheduler_config,
                self.device_config,
                self.load_config,
                self.attention_config,
                self.lora_config,
                self.speculative_config,
                self.structured_outputs_config,
                self.profiler_config,
                self.observability_config,
                self.compilation_config,
                self.kv_transfer_config,
                self.ec_transfer_config,
            )
        )

        if self.additional_config:
            additional_config = self.additional_config
            if isinstance(additional_config, dict):
                # Plain dicts are hashed via a key-sorted JSON encoding so the
                # result does not depend on insertion order.
                additional_config_hash = safe_hash(
                    json.dumps(additional_config, sort_keys=True).encode(),
                    usedforsecurity=False,
                ).hexdigest()
            else:
                additional_config_hash = additional_config.compute_hash()
            vllm_factors.append(additional_config_hash)
        else:
            vllm_factors.append("None")

        factors: list[Any] = [vllm_factors]
        digest = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return digest[:10]

398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
    @property
    def needs_dp_coordinator(self) -> bool:
        """
        Determine if the DPCoordinator process is needed.

        The DPCoordinator is needed in two cases:
        1. For MoE models with DP > 1: to handle wave coordination
           (even in external LB mode, since wave coordination runs in the coordinator)
        2. For non-MoE models in internal/hybrid LB mode: to collect and publish
           queue stats for load balancing across DP ranks

        Returns:
            True if DPCoordinator process is needed, False otherwise.
        """

        # For non-MoE models, only need coordinator in internal/hybrid LB mode
        # (for stats collection).
        return self.parallel_config.data_parallel_size > 1 and (
            self.model_config is None
            or self.model_config.is_moe
            or not self.parallel_config.data_parallel_external_lb
        )

421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
    def enable_trace_function_call_for_thread(self) -> None:
        """
        Turn on function-call tracing for the calling thread.

        Does nothing unless the `VLLM_TRACE_FUNCTION` environment variable is
        enabled. Otherwise the trace log is written to
        ``<tmp>/<user>/vllm/vllm-instance-<instance_id>/``, with a file name
        built from the process id, thread id and current timestamp (spaces
        replaced by underscores). The directory is created if missing.
        """
        if not envs.VLLM_TRACE_FUNCTION:
            return

        # A per-user subdirectory of the temp dir avoids permission issues.
        user_tmp_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())
        log_name = (
            f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
            f"_thread_{threading.get_ident()}_at_{datetime.now()}.log"
        ).replace(" ", "_")
        instance_dir = os.path.join(
            user_tmp_dir, "vllm", f"vllm-instance-{self.instance_id}"
        )
        os.makedirs(instance_dir, exist_ok=True)
        enable_trace_function_call(os.path.join(instance_dir, log_name))

443
444
    @staticmethod
    def _get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> QuantizationConfig | None:
        """Build and validate the quantization config for a model.

        This operates directly on ``model_config``: it is passed to
        ``get_quant_config``, and ``model_config.model`` is passed to
        ``quant_config.maybe_update_config``. Use ``get_quantization_config``
        for a variant that works on a deep copy instead.

        Args:
            model_config: Model configuration. ``quantization`` selects the
                method and ``dtype`` is checked against the method's supported
                activation dtypes.
            load_config: Load configuration forwarded to ``get_quant_config``.

        Returns:
            The quantization config, or None if ``model_config.quantization``
            is None.

        Raises:
            ValueError: If the current device capability is below the method's
                minimum, or ``model_config.dtype`` is not supported by it.
        """
        if model_config.quantization is None:
            return None

        from vllm.model_executor.model_loader.weight_utils import get_quant_config
        from vllm.platforms import current_platform

        quant_config = get_quant_config(model_config, load_config)

        # The platform may not report a capability, in which case the
        # minimum-capability check is skipped.
        capability_tuple = current_platform.get_device_capability()
        if capability_tuple is not None:
            capability = capability_tuple.to_int()
            min_capability = quant_config.get_min_capability()
            if capability < min_capability:
                raise ValueError(
                    f"The quantization method {model_config.quantization} "
                    "is not supported for the current GPU. Minimum "
                    f"capability: {min_capability}. "
                    f"Current capability: {capability}."
                )

        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}"
            )

        quant_config.maybe_update_config(model_config.model)
        return quant_config

    @staticmethod
    def get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> QuantizationConfig | None:
        """Build the quantization config without touching ``model_config``.

        ``_get_quantization_config`` can modify the ``model_config`` it is
        given, so it is run on a deep copy here.

        Args:
            model_config: Model configuration; left unmodified.
            load_config: Load configuration forwarded unchanged.

        Returns:
            The quantization config, or None if the model is not quantized.

        Raises:
            ValueError: Propagated from ``_get_quantization_config`` when the
                device capability or dtype is unsupported.
        """
        # `copy` is the module-level import; no local re-import needed.
        return VllmConfig._get_quantization_config(
            copy.deepcopy(model_config), load_config
        )
487
488
489
490

    def with_hf_config(
        self,
        hf_config: PretrainedConfig,
        architectures: list[str] | None = None,
    ) -> "VllmConfig":
        """Return a VllmConfig whose model config uses ``hf_config``.

        ``self.model_config`` is deep-copied before it is changed, and the
        result is built with ``replace``. The caller's ``hf_config`` object is
        never mutated: it is copied whenever it must be modified.

        Args:
            hf_config: The HF config to install on the copied model config.
            architectures: If given, overrides ``architectures`` on a copy of
                ``hf_config``.

        Returns:
            A new VllmConfig with the updated model config.
        """
        # True once `hf_config` refers to a private copy that is safe to mutate.
        hf_config_is_private = False
        if architectures is not None:
            hf_config = copy.deepcopy(hf_config)
            hf_config.architectures = architectures
            hf_config_is_private = True

        model_config = copy.deepcopy(self.model_config)

        if (
            model_config.is_multimodal_model
            and hasattr(model_config.hf_config, "tie_word_embeddings")
            and not hasattr(hf_config.get_text_config(), "tie_word_embeddings")
        ):
            # In Transformers v5, tie_word_embeddings belongs to the config of the class
            # that can see both layers to be tied. For example:
            #
            # SomeVLModel:
            #   self.language_model = SomeLanguageModel()
            #   self.vision_model = SomeVisionModel()
            #
            # SomeVLModelForMultimodalLM:
            #   self.model = SomeVLModel()
            #   self.lm_head = nn.Linear()
            #
            # Therefore, tie_word_embeddings is defined in SomeVLModelForMultimodalLM's
            # config and is not present in SomeVLModel's config. In vLLM, the lm_head
            # belongs to the language_model, so we must ensure that tie_word_embeddings
            # is set in the language_model's config.
            if not hf_config_is_private:
                # Copy-on-write: don't leak this change into an object the
                # caller still holds.
                hf_config = copy.deepcopy(hf_config)
                hf_config_is_private = True
            tie_word_embeddings = model_config.hf_config.tie_word_embeddings
            hf_config.get_text_config().tie_word_embeddings = tie_word_embeddings

        model_config.hf_config = hf_config
        # Recompute the derived arch config now that hf_config has changed.
        model_config.model_arch_config = model_config.get_model_arch_config()

        return replace(self, model_config=model_config)

527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
    def _set_config_default(self, config_obj: Any, key: str, value: Any) -> None:
        """Set config attribute to default if not already set by user.

        Args:
            config_obj: Configuration object to update.
            key: Attribute name.
            value: Default value (static or callable).
        """
        if getattr(config_obj, key) is None:
            # Some config values are known before initialization and are
            # hard coded.
            # Other values depend on the user given configuration, so they are
            # implemented with lambda functions and decided at run time.
            setattr(config_obj, key, value(self) if callable(value) else value)

    def _apply_optimization_level_defaults(self, defaults: dict[str, Any]) -> None:
        """Apply optimization level defaults using self as root.

        Recursively applies values from defaults into nested config objects.
        Only fields present in defaults are overwritten.

        If the user configuration does not specify a value for a default field
        and if the default field is still None after all user selections are
        applied, then default values will be applied to the field. User speciied
        fields will not be overridden by the default.

        Args:
            defaults: Dictionary of default values to apply.
        """

        def apply_recursive(config_obj: Any, config_defaults: dict[str, Any]) -> None:
            """Recursively apply defaults to config_obj, using self as root."""
            for key, value in config_defaults.items():
                if not hasattr(config_obj, key):
                    continue

                current = getattr(config_obj, key)
                if isinstance(value, dict) and is_dataclass(current):
                    apply_recursive(current, value)
                else:
                    self._set_config_default(config_obj, key, value)

        apply_recursive(self, defaults)

571
572
573
574
575
576
    def _post_init_kv_transfer_config(self) -> None:
        """Update KVTransferConfig based on top-level configs in VllmConfig.

        Right now, this function reads the offloading settings from
        CacheConfig and configures the KVTransferConfig accordingly.
        """
577
578
        # KV offloading is only activated when kv_offloading_size is set.
        if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
579
580
            return

581
582
        kv_offloading_backend = self.cache_config.kv_offloading_backend

583
584
585
586
587
588
589
590
591
592
593
        # If no KVTransferConfig is provided, create a default one.
        if self.kv_transfer_config is None:
            self.kv_transfer_config = KVTransferConfig()
        num_kv_ranks = (
            self.parallel_config.tensor_parallel_size
            * self.parallel_config.pipeline_parallel_size
        )

        if kv_offloading_backend == "native":
            self.kv_transfer_config.kv_connector = "OffloadingConnector"
            self.kv_transfer_config.kv_connector_extra_config.update(
594
                {"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
595
596
597
598
599
600
601
602
603
604
605
606
            )
        elif kv_offloading_backend == "lmcache":
            self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
            kv_gb_per_rank = kv_offloading_size / num_kv_ranks
            self.kv_transfer_config.kv_connector_extra_config = {
                "lmcache.local_cpu": True,
                "lmcache.max_local_cpu_size": kv_gb_per_rank,
            }

        # This is the same for all backends
        self.kv_transfer_config.kv_role = "kv_both"

607
    def __post_init__(self):
608
        """Verify configs are valid & consistent with each other."""
609

610
611
612
        # To give each torch profile run a unique instance name.
        self.instance_id = f"{time.time_ns()}"

613
614
615
616
        self.try_verify_and_update_config()

        if self.model_config is not None:
            self.model_config.verify_with_parallel_config(self.parallel_config)
617
            self.model_config.verify_dual_chunk_attention_config(self.load_config)
618

619
620
            self.parallel_config.is_moe_model = self.model_config.is_moe

621
622
623
624
625
626
627
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config is not None:
            self.lora_config.verify_with_model_config(self.model_config)

        if self.quant_config is None and self.model_config is not None:
            self.quant_config = VllmConfig._get_quantization_config(
628
629
                self.model_config, self.load_config
            )
630

631
632
633
634
635
636
637
638
639
        executor_backend = self.parallel_config.distributed_executor_backend
        executor_supports_async_sched = executor_backend in (
            "mp",
            "uni",
            "external_launcher",
        )

        if self.scheduler_config.async_scheduling:
            # Async scheduling explicitly enabled, hard fail any incompatibilities.
640
641
            # Currently, async scheduling only support eagle speculative
            # decoding.
642
            if self.speculative_config is not None:
643
644
645
646
                if (
                    self.speculative_config.method not in get_args(EagleModelTypes)
                    and self.speculative_config.method != "draft_model"
                ):
647
648
                    raise ValueError(
                        "Currently, async scheduling is only supported "
649
                        "with EAGLE/MTP/Draft Model kind of speculative decoding."
650
651
652
                    )
                if self.speculative_config.disable_padded_drafter_batch:
                    raise ValueError(
653
654
                        "Async scheduling is not compatible with "
                        "disable_padded_drafter_batch=True."
655
                    )
656
657
658
659
660
661
662
663
            if not executor_supports_async_sched:
                raise ValueError(
                    "Currently, async scheduling only supports `mp`, `uni`, or "
                    "`external_launcher` distributed executor backend, but you chose "
                    f"`{executor_backend}`."
                )
        elif self.scheduler_config.async_scheduling is None:
            # Enable async scheduling unless there is an incompatible option.
664
            if (
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
                self.speculative_config is not None
                and self.speculative_config.method not in get_args(EagleModelTypes)
            ):
                logger.warning_once(
                    "Async scheduling not supported with %s-based "
                    "speculative decoding and will be disabled.",
                    self.speculative_config.method,
                    scope="local",
                )
                self.scheduler_config.async_scheduling = False
            elif (
                self.speculative_config is not None
                and self.speculative_config.disable_padded_drafter_batch
            ):
                logger.warning_once(
                    "Async scheduling is not compatible with "
                    "disable_padded_drafter_batch=True and will be disabled.",
                    scope="local",
                )
684
                self.scheduler_config.async_scheduling = False
685
            elif not executor_supports_async_sched:
686
                logger.warning_once(
687
688
689
690
                    "Async scheduling will be disabled because it is not supported "
                    "with the `%s` distributed executor backend (only `mp`, `uni`, and "
                    "`external_launcher` are supported).",
                    executor_backend,
691
                    scope="local",
692
693
694
695
696
                )
                self.scheduler_config.async_scheduling = False
            else:
                self.scheduler_config.async_scheduling = True

697
698
699
700
701
        logger.info_once(
            "Asynchronous scheduling is %s.",
            "enabled" if self.scheduler_config.async_scheduling else "disabled",
        )

702
703
        if self.parallel_config.disable_nccl_for_dp_synchronization is None:
            if self.scheduler_config.async_scheduling:
704
705
706
707
708
709
710
711
                if self.parallel_config.data_parallel_size > 1 and (
                    self.model_config is None or self.model_config.is_moe
                ):
                    logger.info_once(
                        "Disabling NCCL for DP synchronization "
                        "when using async scheduling.",
                        scope="local",
                    )
712
713
714
715
                self.parallel_config.disable_nccl_for_dp_synchronization = True
            else:
                self.parallel_config.disable_nccl_for_dp_synchronization = False

716
        from vllm.platforms import current_platform
717
718
719

        if (
            self.model_config is not None
720
            and self.scheduler_config.enable_chunked_prefill
721
722
723
            and self.model_config.dtype == torch.float32
            and current_platform.get_device_capability() == (7, 5)
        ):
724
725
726
            logger.warning_once(
                "Turing devices tensor cores do not support float32 matmul. "
                "To workaround this limitation, vLLM will set 'ieee' input "
727
728
                "precision for chunked prefill triton kernels."
            )
729

730
731
732
733
734
735
736
        if self.model_config is not None and self.model_config.enforce_eager:
            logger.warning(
                "Enforce eager set, disabling torch.compile and CUDAGraphs. "
                "This is equivalent to setting -cc.mode=none -cc.cudagraph_mode=none"
            )
            self.compilation_config.mode = CompilationMode.NONE
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
737
738
739
740
741
742

        if self.compilation_config.backend == "eager" or (
            self.compilation_config.mode is not None
            and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
        ):
            logger.warning(
743
744
745
                "Inductor compilation was disabled by user settings, "
                "optimizations settings that are only active during "
                "inductor compilation will be ignored."
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
            )

        def has_blocked_weights():
            if self.quant_config is not None:
                if hasattr(self.quant_config, "weight_block_size"):
                    return self.quant_config.weight_block_size is not None
                elif hasattr(self.quant_config, "has_blocked_weights"):
                    return self.quant_config.has_blocked_weights()
            return False

        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
        # On H100 the CUDA kernel is faster than
        # native implementation
        # https://github.com/vllm-project/vllm/issues/25094
        if has_blocked_weights():
            custom_ops = self.compilation_config.custom_ops
            if "-quant_fp8" not in custom_ops:
                custom_ops.append("+quant_fp8")

765
        if self.compilation_config.mode is None:
766
            if self.optimization_level > OptimizationLevel.O0:
767
                self.compilation_config.mode = CompilationMode.VLLM_COMPILE
768
            else:
769
                self.compilation_config.mode = CompilationMode.NONE
770
771
772
773

        if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
            if (
                self.compilation_config.backend == "inductor"
774
                and self.compilation_config.mode != CompilationMode.NONE
775
776
777
778
            ):
                self.compilation_config.custom_ops.append("none")
            else:
                self.compilation_config.custom_ops.append("all")
779

780
781
        default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
        self._apply_optimization_level_defaults(default_config)
782
783
784
785
786
        if self.kernel_config.enable_flashinfer_autotune is None:
            raise ValueError(
                "KernelConfig.enable_flashinfer_autotune must be set after applying "
                "optimization level defaults."
            )
787

788
        if (
789
            self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
790
791
792
793
794
795
796
797
798
799
            and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
        ):
            logger.info(
                "Cudagraph mode %s is not compatible with compilation mode %s."
                "Overriding to NONE.",
                self.compilation_config.cudagraph_mode,
                self.compilation_config.mode,
            )
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

800
801
        # async tp is built on top of sequence parallelism
        # and requires it to be enabled.
802
803
804
        if self.compilation_config.pass_config.fuse_gemm_comms:
            self.compilation_config.pass_config.enable_sp = True
        if self.compilation_config.pass_config.enable_sp:
805
806
807
808
809
810
            if self.parallel_config.tensor_parallel_size == 1:
                logger.warning("Sequence Parallelism requires TP>1, disabling")
                self.compilation_config.pass_config.enable_sp = False
                self.compilation_config.pass_config.fuse_gemm_comms = False

            elif "-rms_norm" in self.compilation_config.custom_ops:
811
812
813
814
815
                logger.warning(
                    "RMS norm force disabled, sequence parallelism might break"
                )
            else:
                self.compilation_config.custom_ops.append("+rms_norm")
816

817
818
819
820
821
822
823
824
        if self.compilation_config.fast_moe_cold_start is None:
            # resolve default behavior: try to be as safe as possible
            # this config is unsafe if any spec decoding draft model has a MOE.
            # We'll conservatively turn it off if we see spec decoding.
            self.compilation_config.fast_moe_cold_start = (
                self.speculative_config is None
            )

825
        if current_platform.support_static_graph_mode():
826
            # if cudagraph_mode has full cudagraphs, we need to check support
827
828
829
830
831
            if model_config := self.model_config:
                if (
                    self.compilation_config.cudagraph_mode.has_full_cudagraphs()
                    and model_config.pooler_config is not None
                ):
832
                    logger.warning_once(
833
                        "Pooling models do not support full cudagraphs. "
834
835
836
                        "Overriding cudagraph_mode to PIECEWISE."
                    )
                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
837
838
839
840
841
842
843
844
845
846
847
848
                elif (
                    model_config.is_encoder_decoder
                    and self.compilation_config.cudagraph_mode
                    not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
                ):
                    logger.info_once(
                        "Encoder-decoder models do not support %s. "
                        "Overriding cudagraph_mode to FULL_DECODE_ONLY.",
                        self.compilation_config.cudagraph_mode.name,
                    )
                    self.compilation_config.cudagraph_mode = (
                        CUDAGraphMode.FULL_DECODE_ONLY
849
                    )
850
851

            # disable cudagraph when enforce eager execution
852
            if self.model_config is not None and self.model_config.enforce_eager:
853
854
                logger.info("Cudagraph is disabled under eager mode")
                self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
855
856
857
                # override related settings when enforce eager
                self.compilation_config.max_cudagraph_capture_size = 0
                self.compilation_config.cudagraph_capture_sizes = []
858
            else:
859
860
861
862
863
864
865
                self.compilation_config.cudagraph_num_of_warmups = 1

            self._set_cudagraph_sizes()
        else:
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if self.cache_config.kv_sharing_fast_prefill:
866
867
868
869
            if (
                self.speculative_config is not None
                and self.speculative_config.use_eagle()
            ):
870
                raise ValueError(
871
872
873
                    "Fast prefill optimization for KV sharing is not "
                    "compatible with EAGLE as EAGLE requires correct logits "
                    "for all tokens while fast prefill gives incorrect logits "
874
875
                    "for prompt tokens."
                )
876
877
878

            logger.warning_once(
                "--kv-sharing-fast-prefill requires changes on model side for "
879
                "correctness and to realize prefill savings."
880
            )
881
882
        # TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
        self._set_compile_ranges()
883

884
885
886
887
888
889
890
891
892
893
        if (
            self.model_config
            and self.model_config.architecture == "WhisperForConditionalGeneration"
            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
        ):
            logger.warning(
                "Whisper is known to have issues with "
                "forked workers. If startup is hanging, "
                "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                "to 'spawn'."
894
            )
895

896
897
898
899
900
        if (
            self.kv_events_config is not None
            and self.kv_events_config.enable_kv_cache_events
            and not self.cache_config.enable_prefix_caching
        ):
901
            logger.warning(
902
                "KV cache events are on, but prefix caching is not enabled. "
903
904
905
906
907
908
909
910
                "Use --enable-prefix-caching to enable."
            )
        if (
            self.kv_events_config is not None
            and self.kv_events_config.publisher != "null"
            and not self.kv_events_config.enable_kv_cache_events
        ):
            logger.warning(
911
912
913
                "KV cache events are disabled, "
                "but the scheduler is configured to publish them. "
                "Modify KVEventsConfig.enable_kv_cache_events "
914
915
                "to True to enable."
            )
916
917
        current_platform.check_and_update_config(self)

918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
        # If DCP, ensure the block size is right.
        if self.parallel_config.decode_context_parallel_size > 1:
            if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
                self.parallel_config.cp_kv_cache_interleave_size
                != self.parallel_config.dcp_kv_cache_interleave_size
            ):
                self.parallel_config.cp_kv_cache_interleave_size = (
                    self.parallel_config.dcp_kv_cache_interleave_size
                )
                logger.warning_once(
                    "cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
                    "_interleave_size. And dcp-kv-cache-interleave-size will be "
                    "deprecated when PCP is fully supported."
                )
            assert (
                self.parallel_config.cp_kv_cache_interleave_size
                <= self.cache_config.block_size
                and self.cache_config.block_size
                % self.parallel_config.cp_kv_cache_interleave_size
                == 0
            ), (
                f"Block_size({self.cache_config.block_size}) should be greater "
                "than or equal to and divisible by cp_kv_cache_interleave_size "
                f"({self.parallel_config.cp_kv_cache_interleave_size})."
            )

944
        # Do this after all the updates to compilation_config.mode
945
946
947
948
949
        effective_dp_size = (
            self.parallel_config.data_parallel_size
            if self.model_config is None or self.model_config.is_moe
            else 1
        )
950
951
        self.compilation_config.set_splitting_ops_for_v1(
            all2all_backend=self.parallel_config.all2all_backend,
952
            data_parallel_size=effective_dp_size,
953
        )
954

955
        if self.compilation_config.pass_config.enable_sp:
956
957
958
959
960
            # With pipeline parallelism or dynamo partitioning,
            # native rms norm tracing errors due to incorrect residual shape.
            # Use custom rms norm to unblock. In the future,
            # the pass will operate on higher-level IR to avoid the issue.
            # TODO: https://github.com/vllm-project/vllm/issues/27894
961
962
963
964
965
966
967
            if self.compilation_config.mode != CompilationMode.VLLM_COMPILE:
                logger.warning(
                    "Sequence parallelism is enabled, but running in wrong "
                    "vllm compile mode: %s.",
                    self.compilation_config.mode,
                )

968
969
970
971
972
973
974
975
976
977
978
979
980
981
            is_fullgraph = (
                self.compilation_config.use_inductor_graph_partition
                or len(self.compilation_config.splitting_ops) == 0
            )
            if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
                if "-rms_norm" not in self.compilation_config.custom_ops:
                    self.compilation_config.custom_ops.append("+rms_norm")
                else:
                    regime = (
                        "Dynamo partition"
                        if not is_fullgraph
                        else "pipeline parallelism"
                    )
                    logger.warning_once(
982
                        "Sequence parallelism not supported with "
983
984
985
986
987
                        "native rms_norm when using %s, "
                        "this will likely lead to an error.",
                        regime,
                    )

988
        # final check of cudagraph mode after all possible updates
989
        if current_platform.is_cuda_alike():
990
991
992
993
            if (
                self.compilation_config.cudagraph_mode.has_full_cudagraphs()
                and self.model_config is not None
                and not self.model_config.disable_cascade_attn
994
                and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()  # noqa: E501
995
            ):
996
997
998
                logger.warning_once(
                    "No piecewise cudagraph for executing cascade attention."
                    " Will fall back to eager execution if a batch runs "
999
                    "into cascade attentions."
1000
1001
1002
                )

            if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
1003
1004
                assert self.compilation_config.mode == CompilationMode.VLLM_COMPILE, (
                    "Compilation mode should be CompilationMode.VLLM_COMPILE "
1005
                    "when cudagraph_mode piecewise cudagraphs is used, "
1006
                    f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
1007
                )
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
        from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant

        if (
            self.model_config
            and vllm_is_batch_invariant()
            and not self.model_config.disable_cascade_attn
        ):
            self.model_config.disable_cascade_attn = True
            logger.warning_once(
                "Disabling cascade attention when VLLM_BATCH_INVARIANT is enabled.",
                scope="local",
            )
1020

1021
        if self.parallel_config.use_ubatching:
1022
            a2a_backend = self.parallel_config.all2all_backend
1023
1024
1025
1026
            assert a2a_backend in [
                "deepep_low_latency",
                "deepep_high_throughput",
            ], (
1027
1028
                "Microbatching currently only supports the deepep_low_latency and "
                f"deepep_high_throughput all2all backend. {a2a_backend} is not "
1029
1030
1031
                "supported. To fix use --all2all-backend=deepep_low_latency or "
                "--all2all-backend=deepep_high_throughput and install the DeepEP"
                " kernels."
1032
            )
1033
1034
1035

            if not self.model_config.disable_cascade_attn:
                self.model_config.disable_cascade_attn = True
1036
                logger.warning_once("Disabling cascade attention when DBO is enabled.")
1037
1038
1039
1040

        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
        # Hybrid KV cache manager (HMA) runtime rules:
        # - Explicit enable (--no-disable-kv-cache-manager): error if runtime
        #   disables it
        # - No preference: auto-disable for unsupported features (e.g. kv connector)
        # - Explicit disable (--disable-kv-cache-manager): always respect it
        need_disable_hybrid_kv_cache_manager = False
        # logger should only print warning message for hybrid models. As we
        # can't know whether the model is hybrid or not now, so we don't log
        # warning message here and will log it later.
        if not current_platform.support_hybrid_kv_cache():
            # Hybrid KV cache manager is not supported on non-GPU platforms.
            need_disable_hybrid_kv_cache_manager = True
        if self.kv_events_config is not None:
            # Hybrid KV cache manager is not compatible with KV events.
            need_disable_hybrid_kv_cache_manager = True
        if (
            self.model_config is not None
            and self.model_config.attention_chunk_size is not None
        ):
            if (
                self.speculative_config is not None
                and self.speculative_config.use_eagle()
            ):
                # Hybrid KV cache manager is not yet supported with chunked
                # local attention + eagle.
                need_disable_hybrid_kv_cache_manager = True
            elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
                logger.warning(
                    "There is a latency regression when using chunked local"
                    " attention with the hybrid KV cache manager. Disabling"
                    " it, by default. To enable it, set the environment "
                    "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
                )
                # Hybrid KV cache manager is not yet supported with chunked
                # local attention.
                need_disable_hybrid_kv_cache_manager = True

        if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
            # Default to disable HMA, but only if the user didn't express a preference.
1080
            if self.kv_transfer_config is not None:
1081
1082
                # NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
                need_disable_hybrid_kv_cache_manager = True
1083
1084
1085
1086
1087
1088
1089
                logger.warning(
                    "Turning off hybrid kv cache manager because "
                    "`--kv-transfer-config` is set. This will reduce the "
                    "performance of vLLM on LLMs with sliding window attention "
                    "or Mamba attention. If you are a developer of kv connector"
                    ", please consider supporting hybrid kv cache manager for "
                    "your connector by making sure your connector is a subclass"
1090
1091
                    " of `SupportsHMA` defined in kv_connector/v1/base.py and"
                    " use --no-disable-hybrid-kv-cache-manager to start vLLM."
1092
                )
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
            self.scheduler_config.disable_hybrid_kv_cache_manager = (
                need_disable_hybrid_kv_cache_manager
            )
        elif (
            self.scheduler_config.disable_hybrid_kv_cache_manager is False
            and need_disable_hybrid_kv_cache_manager
        ):
            raise ValueError(
                "Hybrid KV cache manager was explicitly enabled but is not "
                "supported in this configuration. Consider omitting the "
                "--no-disable-hybrid-kv-cache-manager flag to let vLLM decide"
                " automatically."
            )

        if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
            # Default to enable HMA if not explicitly disabled by user or logic above.
            self.scheduler_config.disable_hybrid_kv_cache_manager = False
1110

1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
        if self.cache_config.mamba_cache_mode == "align":
            assert (
                self.cache_config.block_size
                <= self.scheduler_config.max_num_batched_tokens
            ), (
                "In Mamba cache align mode, block_size "
                f"({self.cache_config.block_size}) must be <= "
                "max_num_batched_tokens "
                f"({self.scheduler_config.max_num_batched_tokens})."
            )
            if self.scheduler_config.long_prefill_token_threshold > 0:
                assert (
                    self.scheduler_config.long_prefill_token_threshold
                    >= self.cache_config.block_size
                )
            assert not self.scheduler_config.disable_chunked_mm_input, (
                "Chunked MM input is required because we need the flexibility to "
                "schedule a multiple of block_size tokens even if they are in the "
                "middle of a mm input"
            )
1131
        if self.compilation_config.debug_dump_path:
1132
            self.compilation_config.debug_dump_path = (
1133
                self.compilation_config.debug_dump_path.absolute().expanduser()
1134
            )
1135
1136
1137
1138
1139
        if envs.VLLM_DEBUG_DUMP_PATH is not None:
            env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser()
            if self.compilation_config.debug_dump_path:
                logger.warning(
                    "Config-specified debug dump path is overridden"
1140
1141
1142
                    " by VLLM_DEBUG_DUMP_PATH to %s",
                    env_path,
                )
1143
1144
            self.compilation_config.debug_dump_path = env_path

1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
        def has_blocked_weights():
            if self.quant_config is not None:
                if hasattr(self.quant_config, "weight_block_size"):
                    return self.quant_config.weight_block_size is not None
                elif hasattr(self.quant_config, "has_blocked_weights"):
                    return self.quant_config.has_blocked_weights()
            return False

        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
        # On H100 the CUDA kernel is faster than
        # native implementation
        # https://github.com/vllm-project/vllm/issues/25094
        if has_blocked_weights():
            custom_ops = self.compilation_config.custom_ops
1159
            if "-quant_fp8" not in custom_ops:
1160
1161
                custom_ops.append("+quant_fp8")

1162
1163
1164
        # Handle the KV connector configs
        self._post_init_kv_transfer_config()

1165
    def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
        """Keep only the sizes that are a multiple of ``tensor_parallel_size``.

        Intended for use when sequence parallelism is enabled. Sizes that are
        not a multiple of the tensor-parallel size are dropped and reported
        with a single warning. The relative order of the kept sizes is
        preserved.

        Args:
            possible_sizes: Candidate batch sizes. Any iterable of ints works;
                it is traversed exactly once.

        Returns:
            A new list with the sizes divisible by ``tensor_parallel_size``.
        """
        tp_size = self.parallel_config.tensor_parallel_size
        kept_sizes: list = []
        removed_sizes: list = []
        # Single pass: partition into kept / removed. Traversing only once
        # also keeps iterator/generator inputs from being silently exhausted.
        for size in possible_sizes:
            if size % tp_size == 0:
                kept_sizes.append(size)
            else:
                removed_sizes.append(size)

        if removed_sizes:
            logger.warning(
                "Batch sizes %s are removed because they are not "
                "multiple of tp_size %d when "
                "sequence parallelism is enabled",
                removed_sizes,
                tp_size,
            )

        return kept_sizes

    def _set_cudagraph_sizes(self):
        """
        vLLM defines the default candidate list of batch sizes for CUDA graph
        capture as:

        ```python
        # decode_query_len is 1, plus num_speculative_tokens when speculative
        # decoding is configured
        max_graph_size = min(max_num_seqs * decode_query_len * 2, 512)
        # 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
        # up to max_graph_size
        cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
            range(256, max_graph_size + 1, 16))
        ```

        If `compilation_config.max_cudagraph_capture_size` is set, it replaces
        `max_graph_size` above. Either way the maximum is capped at
        `max_num_batched_tokens`. When sequence parallelism is enabled with
        TP > 1, sizes that are not a multiple of the TP size are dropped.

        In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
        will be the final sizes to capture cudagraph (in ascending order), and
        `compilation_config.max_cudagraph_capture_size` is set to the largest
        of them (or 0 when none remain / cudagraphs are not used).

        These sizes are used to capture and reuse CUDA graphs for
        performance-critical paths (e.g., decoding). Capturing enables
        significantly faster kernel dispatch by avoiding Python overhead. The
        list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
        most GPUs), which controls the total allowed number of tokens in a
        batch. Since each sequence may have a variable number of tokens, the
        maximum usable batch size will depend on actual sequence lengths.

        Example:
            With `max_num_batched_tokens = 8192`, and typical sequences
            averaging ~32 tokens, most practical batch sizes fall below 256.
            However, the system will still allow capture sizes up to 512 if
            shape and memory permit.

        Note:
            If users explicitly specify cudagraph capture sizes in the
            compilation config, those will override this default logic.
            At runtime:

            - If batch size <= one of the `cudagraph_capture_sizes`, the closest
            padded CUDA graph will be used.
            - If batch size > largest `cudagraph_capture_sizes`, cudagraph will
            not be used.

        Raises:
            ValueError: if the user specified both `max_cudagraph_capture_size`
                and `cudagraph_capture_sizes` and they disagree after filtering.
        """
        if (
            self.model_config is not None
            and not self.model_config.enforce_eager
            and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
            # determine the initial max_cudagraph_capture_size; a user-provided
            # value takes precedence over the computed default
            max_cudagraph_capture_size = (
                self.compilation_config.max_cudagraph_capture_size
            )
            if max_cudagraph_capture_size is None:
                # per-sequence decode query length: 1, plus the number of
                # speculative tokens when speculative decoding is configured
                decode_query_len = 1
                if (
                    self.speculative_config
                    and self.speculative_config.num_speculative_tokens
                ):
                    decode_query_len += self.speculative_config.num_speculative_tokens
                max_cudagraph_capture_size = min(
                    self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
                )
            # never capture beyond the per-step token budget
            max_num_tokens = self.scheduler_config.max_num_batched_tokens
            max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)

            assert max_cudagraph_capture_size >= 1, (
                "Maximum cudagraph size should be greater than or equal to 1 "
                "when using cuda graph."
            )

            # determine the cudagraph_capture_sizes
            user_capture_sizes = self.compilation_config.cudagraph_capture_sizes
            if user_capture_sizes is not None:
                assert len(user_capture_sizes) > 0, (
                    "cudagraph_capture_sizes should contain at least one element "
                    "when using cuda graph."
                )
                # de-duplicate the sizes provided by the config, drop those
                # above the token budget, and sort ascending
                dedup_sizes = set(user_capture_sizes)
                cudagraph_capture_sizes = sorted(
                    i for i in dedup_sizes if i <= max_num_tokens
                )
            else:
                cudagraph_capture_sizes = [
                    i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
                ]
                if max_cudagraph_capture_size >= 8:
                    # Step size 8 for small batch sizes, up to 256(not included)
                    cudagraph_capture_sizes += list(
                        range(8, min(max_cudagraph_capture_size + 1, 256), 8)
                    )
                if max_cudagraph_capture_size >= 256:
                    # Step size 16 for larger batch sizes
                    cudagraph_capture_sizes += list(
                        range(256, max_cudagraph_capture_size + 1, 16)
                    )

            if (
                self.parallel_config.tensor_parallel_size > 1
                and self.compilation_config.pass_config.enable_sp
            ):
                cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
                    cudagraph_capture_sizes
                )

            # user-specific compilation_config.max_cudagraph_capture_size get
            # truncated to valid_max_size when they are inconsistent.
            valid_max_size = (
                cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
            )
            if (
                self.compilation_config.max_cudagraph_capture_size is not None
                and self.compilation_config.max_cudagraph_capture_size != valid_max_size
            ):
                # raise error only when both two flags are user-specified
                # and they are inconsistent with each other
                if user_capture_sizes is not None:
                    raise ValueError(
                        "customized max_cudagraph_capture_size"
                        f"(={self.compilation_config.max_cudagraph_capture_size}) "
                        "should be consistent with the max value of "
                        f"cudagraph_capture_sizes(={valid_max_size})"
                    )

                logger.warning(
                    "Truncating max_cudagraph_capture_size to %d",
                    valid_max_size,
                )
            # always set the final max_cudagraph_capture_size
            self.compilation_config.max_cudagraph_capture_size = valid_max_size

            # The final list is a subset of the de-duplicated user list, so a
            # shorter length means some sizes were actually filtered out.
            # Compare against the de-duplicated count so that merely passing
            # duplicate sizes does not trigger a spurious warning.
            if user_capture_sizes is not None and len(cudagraph_capture_sizes) < len(
                set(user_capture_sizes)
            ):
                logger.warning(
                    (
                        "cudagraph_capture_sizes specified in compilation_config"
                        " %s is overridden by config %s"
                    ),
                    user_capture_sizes,
                    cudagraph_capture_sizes,
                )
            # always write back the final sizes
            self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes

        else:
            # no cudagraph in use
            self.compilation_config.max_cudagraph_capture_size = 0
            self.compilation_config.cudagraph_capture_sizes = []

        # complete the remaining process.
        self.compilation_config.post_init_cudagraph_sizes()
1341

1342
1343
1344
1345
1346
1347
1348
    def _set_compile_ranges(self):
        """
        Set the compile ranges for the compilation config.
        """
        compilation_config = self.compilation_config
        computed_compile_ranges_split_points = []

1349
        # The upper bound of the compile ranges is the max_num_batched_tokens.
1350
1351
1352
        # For speculative decoding, the compile range must be extended
        # - Sequential: + 1 * max_num_seqs (one draft token per iteration)
        # - Parallel draft: + num_speculative_tokens * max_num_seqs
1353
1354
        compile_range_end = self.scheduler_config.max_num_batched_tokens
        if compile_range_end is not None:
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
            if self.speculative_config is not None and (
                self.speculative_config.uses_draft_model()
                or self.speculative_config.use_eagle()
            ):
                multiplier = (
                    self.speculative_config.num_speculative_tokens
                    if self.speculative_config.parallel_drafting
                    else 1
                )
                compile_range_end += multiplier * self.scheduler_config.max_num_seqs
1365
1366

            computed_compile_ranges_split_points.append(compile_range_end)
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376

        # Add the compile ranges for flashinfer
        if compilation_config.pass_config.fuse_allreduce_rms:
            tp_size = self.parallel_config.tensor_parallel_size
            max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
            if max_size is not None:
                max_token_num = max_size // (
                    self.model_config.get_hidden_size()
                    * self.model_config.dtype.itemsize
                )
1377
                if compile_range_end is not None and max_token_num < compile_range_end:
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
                    computed_compile_ranges_split_points.append(max_token_num)
                else:
                    logger.debug(
                        "Max num batched tokens below allreduce-rms fusion threshold, "
                        "allreduce-rms fusion will be enabled for all num_tokens."
                    )

        if compilation_config.compile_ranges_split_points is not None:
            for x in compilation_config.compile_ranges_split_points:
                assert isinstance(x, int)
                assert x > 0, f"Invalid compile range split point: {x}"
1389
                if compile_range_end is not None and x < compile_range_end and x > 1:
1390
1391
1392
1393
1394
                    computed_compile_ranges_split_points.append(x)
        compilation_config.compile_ranges_split_points = sorted(
            computed_compile_ranges_split_points
        )

1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
    def try_verify_and_update_config(self):
        """
        Run the model-specific config verification/update hooks, at most once.

        Steps, in order:

        * mark ``model_config`` with ``config_updated = True`` so later calls
          return immediately (this happens even if the architecture is
          unknown, in which case nothing else is done);
        * call the ``verify_and_update_config`` hook registered for the
          architecture in ``MODELS_CONFIG_MAP``, if any;
        * call the hybrid attention/mamba hook when ``is_hybrid`` is set;
        * call the sequence-classification hook when ``convert_type`` is
          ``"classify"``;
        * when the model weights are a Run:ai object URI, switch an ``auto``
          ``load_format`` to ``runai_streamer``.

        Raises:
            ValueError: if the weights are a Run:ai object URI and
                ``load_format`` is neither ``auto`` nor one of the Run:ai
                streamer formats.
        """
        # Nothing to do without a model config, and never run twice for the
        # same model config.
        if self.model_config is None or getattr(
            self.model_config, "config_updated", False
        ):
            return
        self.model_config.config_updated = True

        architecture = self.model_config.architecture
        if architecture is None:
            return

        from vllm.model_executor.models.config import (
            MODELS_CONFIG_MAP,
            HybridAttentionMambaModelConfig,
        )

        arch_config_cls = MODELS_CONFIG_MAP.get(architecture)
        if arch_config_cls is not None:
            arch_config_cls.verify_and_update_config(self)

        if self.model_config.is_hybrid:
            HybridAttentionMambaModelConfig.verify_and_update_config(self)

        if self.model_config.convert_type == "classify":
            # Maybe convert ForCausalLM into ForSequenceClassification model.
            from vllm.model_executor.models.adapters import (
                SequenceClassificationConfig,
            )

            SequenceClassificationConfig.verify_and_update_config(self)

        uses_runai_weights = hasattr(
            self.model_config, "model_weights"
        ) and is_runai_obj_uri(self.model_config.model_weights)
        if not uses_runai_weights:
            return

        load_format = self.load_config.load_format
        if load_format == "auto":
            logger.info(
                "Detected Run:ai model config. "
                "Overriding `load_format` to 'runai_streamer'"
            )
            self.load_config.load_format = "runai_streamer"
        elif load_format not in ("runai_streamer", "runai_streamer_sharded"):
            raise ValueError(
                "To load a model from S3, 'load_format' "
                "must be 'runai_streamer' or 'runai_streamer_sharded', "
                f"but got '{load_format}'. "
                f"Model: {self.model_config.model}"
            )

    def compile_debug_dump_path(self) -> Path | None:
1447
        """Returns a rank-aware path for dumping
1448
1449
1450
1451
1452
        torch.compile debug information.
        """
        if self.compilation_config.debug_dump_path is None:
            return None
        tp_rank = self.parallel_config.rank
1453
1454
        dp_rank = self.parallel_config.data_parallel_index
        append_path = f"rank_{tp_rank}_dp_{dp_rank}"
1455
1456
1457
1458
1459
1460
1461
        path = self.compilation_config.debug_dump_path / append_path
        return path

    def __str__(self):
        return (
            f"model={self.model_config.model!r}, "
            f"speculative_config={self.speculative_config!r}, "
1462
1463
1464
            f"tokenizer={self.model_config.tokenizer!r}, "
            f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, "
            f"tokenizer_mode={self.model_config.tokenizer_mode}, "
1465
            f"revision={self.model_config.revision}, "
1466
            f"tokenizer_revision={self.model_config.tokenizer_revision}, "
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
            f"trust_remote_code={self.model_config.trust_remote_code}, "
            f"dtype={self.model_config.dtype}, "
            f"max_seq_len={self.model_config.max_model_len}, "
            f"download_dir={self.load_config.download_dir!r}, "
            f"load_format={self.load_config.load_format}, "
            f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, "  # noqa
            f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, "  # noqa
            f"data_parallel_size={self.parallel_config.data_parallel_size}, "  # noqa
            f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, "  # noqa
            f"quantization={self.model_config.quantization}, "
            f"enforce_eager={self.model_config.enforce_eager}, "
1478
            f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, "  # noqa
1479
1480
1481
1482
1483
1484
1485
            f"kv_cache_dtype={self.cache_config.cache_dtype}, "
            f"device_config={self.device_config.device}, "
            f"structured_outputs_config={self.structured_outputs_config!r}, "
            f"observability_config={self.observability_config!r}, "
            f"seed={self.model_config.seed}, "
            f"served_model_name={self.model_config.served_model_name}, "
            f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
1486
            f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, "  # noqa
1487
            f"pooler_config={self.model_config.pooler_config!r}, "
1488
1489
            f"compilation_config={self.compilation_config!r}"
        )
1490

1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
    @model_validator(mode="after")
    def validate_mamba_block_size(self) -> "VllmConfig":
        """Reject an explicitly set mamba block size without prefix caching.

        A ``mamba_block_size`` that is not None and differs from
        ``max_model_len`` is treated as explicitly set; that is only allowed
        together with ``enable_prefix_caching``. Does nothing when there is
        no model config.

        Raises:
            ValueError: if the block size is set but prefix caching is off.
        """
        if self.model_config is None:
            return self
        block_size = self.cache_config.mamba_block_size
        explicitly_set = (
            block_size is not None and block_size != self.model_config.max_model_len
        )
        if explicitly_set and not self.cache_config.enable_prefix_caching:
            raise ValueError(
                "--mamba-block-size can only be set with --enable-prefix-caching"
            )
        return self


# Module-level state managed by ``set_current_vllm_config``: the ``prefix``
# argument and the config of the active context (None when no context is set).
_current_prefix: str | None = None
_current_vllm_config: VllmConfig | None = None


@contextmanager
def set_current_vllm_config(
    vllm_config: VllmConfig, check_compile=False, prefix: str | None = None
):
    """
    Install ``vllm_config`` (and ``prefix``) as the current vLLM config for
    the duration of a ``with`` block; used during model initialization.

    The config is stored in module-level globals so any module can reach it
    through ``get_current_vllm_config()`` (e.g. custom ops deciding how to
    dispatch). The previously installed config and prefix are restored on
    exit, even when the block raises, and the cached compilation config is
    invalidated both on entry and on exit.

    If ``check_compile`` is true and the block finishes without raising, the
    custom-op log check runs, and a warning is logged when ``VLLM_COMPILE``
    mode is enabled but no model was registered with the compilation counter
    while the context was active.
    """
    global _current_vllm_config, _current_prefix
    saved_config = _current_vllm_config
    saved_prefix = _current_prefix
    from vllm.compilation.counter import compilation_counter

    models_seen_on_entry = compilation_counter.num_models_seen
    try:
        # The previous config may already have been read through the cache;
        # drop that cached value before the new config takes effect.
        get_cached_compilation_config.cache_clear()

        _current_vllm_config = vllm_config
        _current_prefix = prefix
        yield

        # Only reached when the body of the ``with`` block did not raise.
        if check_compile:
            vllm_config.compilation_config.custom_op_log_check()
            compile_mode_on = (
                vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
            )
            if (
                compile_mode_on
                and compilation_counter.num_models_seen == models_seen_on_entry
            ):
                # A model that supports compilation bumps num_models_seen by
                # at least 1; no bump means the model lacks the
                # @support_torch_compile decorator.
                logger.warning(
                    "`torch.compile` is turned on, but the model %s"
                    " does not support it. Please open an issue on GitHub"
                    " if you want it to be supported.",
                    vllm_config.model_config.model,
                )
    finally:
        _current_vllm_config = saved_config
        _current_prefix = saved_prefix
        # Invalidate again so the restored config is not shadowed by a value
        # cached while the inner config was active.
        get_cached_compilation_config.cache_clear()


@lru_cache(maxsize=1)
def get_cached_compilation_config():
    """Return the current config's ``compilation_config``, memoized.

    Memoizing avoids a ``get_current_vllm_config()`` lookup on every call;
    ``set_current_vllm_config`` clears this cache when it installs or
    restores a config.
    """
    current_config = get_current_vllm_config()
    return current_config.compilation_config


def get_current_vllm_config() -> VllmConfig:
    """Return the config installed by ``set_current_vllm_config()``.

    Raises:
        AssertionError: if no config is currently set.
    """
    current = _current_vllm_config
    if current is not None:
        return current
    raise AssertionError(
        "Current vLLM config is not set. This typically means "
        "get_current_vllm_config() was called outside of a "
        "set_current_vllm_config() context, or a CustomOp was instantiated "
        "at module import time or model forward time when config is not set. "
        "For tests that directly test custom ops/modules, use the "
        "'default_vllm_config' pytest fixture from tests/conftest.py."
    )


def get_current_vllm_config_or_none() -> VllmConfig | None:
    """Return the current vLLM config, or None when none is set.

    Non-raising counterpart of ``get_current_vllm_config()``.
    """
    current = _current_vllm_config
    return current


T = TypeVar("T")


def get_layers_from_vllm_config(
    vllm_config: VllmConfig,
    layer_type: type[T],
    layer_names: list[str] | None = None,
) -> dict[str, T]:
    """
    Get layers of a given type from the vLLM config.

    Layers are looked up in ``compilation_config.static_forward_context``;
    only those that are instances of ``layer_type`` are returned.

    Args:
        vllm_config: The vLLM config.
        layer_type: The type of the layer to get.
        layer_names: The names of the layers to get. If None, every layer in
            the forward context is considered.

    Returns:
        A dict mapping each selected name to its layer, in the iteration
        order of ``layer_names`` (or of the forward context when None).

    Raises:
        KeyError: if a name in ``layer_names`` is not in the forward context.
    """
    forward_context = vllm_config.compilation_config.static_forward_context
    if layer_names is None:
        candidates = list(forward_context.keys())
    else:
        candidates = layer_names

    selected: dict[str, T] = {}
    for name in candidates:
        layer = forward_context[name]
        if isinstance(layer, layer_type):
            selected[name] = layer
    return selected