vllm.py 63.4 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
5
import getpass
6
7
import json
import os
8
9
import tempfile
import threading
10
import time
11
from contextlib import contextmanager
12
from dataclasses import is_dataclass, replace
13
from datetime import datetime
14
from enum import IntEnum
15
16
from functools import lru_cache
from pathlib import Path
17
from typing import TYPE_CHECKING, Any, TypeVar, get_args
18
19

import torch
20
from pydantic import ConfigDict, Field, model_validator
21
22
23
from pydantic.dataclasses import dataclass

import vllm.envs as envs
24
from vllm.config.speculative import EagleModelTypes
25
from vllm.logger import enable_trace_function_call, init_logger
26
27
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
28
from vllm.utils.hashing import safe_hash
29

30
from .attention import AttentionConfig
31
from .cache import CacheConfig
32
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
33
from .device import DeviceConfig
34
from .ec_transfer import ECTransferConfig
35
36
37
38
39
40
41
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .parallel import ParallelConfig
42
from .profiler import ProfilerConfig
43
44
45
46
47
48
49
50
from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
from .utils import SupportsHash, config

if TYPE_CHECKING:
    from transformers import PretrainedConfig

51
    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
52
    from vllm.v1.kv_cache_interface import KVCacheConfig
53
54
55
56
57
else:
    PretrainedConfig = Any

    QuantizationConfig = Any

58
59
    KVCacheConfig = Any

60
61
62
# Module-level logger for this config module.
logger = init_logger(__name__)


class OptimizationLevel(IntEnum):
    """Optimization level enum.

    Levels are selected with -O0 ... -O3 and trade startup time for runtime
    performance: -O0 starts fastest, -O3 aims for the best performance.
    """

    O0 = 0
    """O0: No optimization. No compilation, no cudagraphs, no other
    optimization; just starting up immediately."""
    O1 = 1
    """O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
    cudagraphs."""
    O2 = 2
    """O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""
    O3 = 3
    """O3: Currently the same as -O2."""


# Model properties that some optimization-level defaults are keyed on.
# Both are pinned to False for every model for now; the intended
# model-dependent definitions would look like:
#
#   if model_config is not None:
#       IS_QUANTIZED = lambda c: c.model_config.is_quantized()
#       IS_DENSE = lambda c: not c.model_config.is_model_moe()
#
# See https://github.com/vllm-project/vllm/issues/25689.
IS_QUANTIZED = False
IS_DENSE = False


def enable_norm_fusion(cfg: "VllmConfig") -> bool:
    """Decide whether the norm + quant fusion pass should be on.

    Enable if either the RMS norm or the quant FP8 custom op is active;
    otherwise Inductor handles fusion on its own.

    Args:
        cfg: The config whose ``compilation_config`` is queried.

    Returns:
        True if ``rms_norm`` or ``quant_fp8`` is an enabled custom op.
    """
    is_enabled = cfg.compilation_config.is_custom_op_enabled
    return is_enabled("rms_norm") or is_enabled("quant_fp8")


def enable_act_fusion(cfg: "VllmConfig") -> bool:
    """Decide whether the activation + quant fusion pass should be on.

    The pass is worthwhile only when the SiLU+Mul or the quant FP8 custom op
    is active; otherwise Inductor performs the fusion itself.

    Args:
        cfg: The config whose ``compilation_config`` is queried.

    Returns:
        True if ``silu_and_mul`` or ``quant_fp8`` is an enabled custom op.
    """
    is_enabled = cfg.compilation_config.is_custom_op_enabled
    return is_enabled("silu_and_mul") or is_enabled("quant_fp8")


# Defaults for -O0: every compilation pass off and no cudagraphs.
OPTIMIZATION_LEVEL_00 = {
    "compilation_config": {
        "pass_config": {
            "eliminate_noops": False,
            "fuse_norm_quant": False,
            "fuse_act_quant": False,
            "fuse_allreduce_rms": False,
            "fuse_attn_quant": False,
            "enable_sp": False,
            "fuse_gemm_comms": False,
        },
        "cudagraph_mode": CUDAGraphMode.NONE,
        "use_inductor_graph_partition": False,
    },
}
# Defaults for -O1: no-op elimination, norm/act quant fusion (callables,
# resolved against the final VllmConfig when the defaults are applied), and
# piecewise cudagraphs.
OPTIMIZATION_LEVEL_01 = {
    "compilation_config": {
        "pass_config": {
            "eliminate_noops": True,
            "fuse_norm_quant": enable_norm_fusion,
            "fuse_act_quant": enable_act_fusion,
            "fuse_allreduce_rms": False,
            "fuse_attn_quant": False,
            "enable_sp": False,
            "fuse_gemm_comms": False,
        },
        "cudagraph_mode": CUDAGraphMode.PIECEWISE,
        "use_inductor_graph_partition": False,
    },
}
# Defaults for -O2 (the default level): -O1 passes plus full and piecewise
# cudagraphs. Attention-quant fusion and sequence-parallel / GEMM-comm passes
# are keyed on IS_QUANTIZED / IS_DENSE, which are currently always False.
OPTIMIZATION_LEVEL_02 = {
    "compilation_config": {
        "pass_config": {
            "eliminate_noops": True,
            "fuse_norm_quant": enable_norm_fusion,
            "fuse_act_quant": enable_act_fusion,
            "fuse_allreduce_rms": False,
            "fuse_attn_quant": IS_QUANTIZED,
            "enable_sp": IS_DENSE,
            "fuse_gemm_comms": IS_DENSE,
        },
        "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
        "use_inductor_graph_partition": False,
    },
}
# Defaults for -O3: currently identical to -O2. Kept as a separate dict so
# the two levels can diverge independently.
OPTIMIZATION_LEVEL_03 = {
    "compilation_config": {
        "pass_config": {
            "eliminate_noops": True,
            "fuse_norm_quant": enable_norm_fusion,
            "fuse_act_quant": enable_act_fusion,
            "fuse_allreduce_rms": False,
            "fuse_attn_quant": IS_QUANTIZED,
            "enable_sp": IS_DENSE,
            "fuse_gemm_comms": IS_DENSE,
        },
        "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
        "use_inductor_graph_partition": False,
    },
}

# Optimization level -> the defaults dict applied in VllmConfig.__post_init__.
# Enum members iterate in definition order (O0, O1, O2, O3); strict=True
# guards against the two sequences drifting apart in length.
OPTIMIZATION_LEVEL_TO_CONFIG = dict(
    zip(
        OptimizationLevel,
        (
            OPTIMIZATION_LEVEL_00,
            OPTIMIZATION_LEVEL_01,
            OPTIMIZATION_LEVEL_02,
            OPTIMIZATION_LEVEL_03,
        ),
        strict=True,
    )
)


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
    """Dataclass which contains all vllm-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    # TODO: use default_factory once default constructing ModelConfig doesn't
    # try to download a model
    model_config: ModelConfig = Field(default=None)
    """Model configuration."""
    cache_config: CacheConfig = Field(default_factory=CacheConfig)
    """Cache configuration."""
    parallel_config: ParallelConfig = Field(default_factory=ParallelConfig)
    """Parallel configuration."""
    scheduler_config: SchedulerConfig = Field(
        default_factory=SchedulerConfig.default_factory,
    )
    """Scheduler configuration."""
    device_config: DeviceConfig = Field(default_factory=DeviceConfig)
    """Device configuration."""
    load_config: LoadConfig = Field(default_factory=LoadConfig)
    """Load configuration."""
    attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
    """Attention configuration."""
    lora_config: LoRAConfig | None = None
    """LoRA configuration."""
    speculative_config: SpeculativeConfig | None = None
    """Speculative decoding configuration."""
    structured_outputs_config: StructuredOutputsConfig = Field(
        default_factory=StructuredOutputsConfig
    )
    """Structured outputs configuration."""
    observability_config: ObservabilityConfig = Field(
        default_factory=ObservabilityConfig
    )
    """Observability configuration."""
    quant_config: QuantizationConfig | None = None
    """Quantization configuration."""
    compilation_config: CompilationConfig = Field(default_factory=CompilationConfig)
    """`torch.compile` and cudagraph capture configuration for the model.

    As a shorthand, one can append compilation arguments via
    -cc.parameter=argument such as `-cc.mode=3` (same as `-cc='{"mode":3}'`).

    You can specify the full compilation config like so:
    `{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
    """
    profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig)
    """Profiling configuration."""
    kv_transfer_config: KVTransferConfig | None = None
    """The configurations for distributed KV cache transfer."""
    kv_events_config: KVEventsConfig | None = None
    """The configurations for event publishing."""
    ec_transfer_config: ECTransferConfig | None = None
    """The configurations for distributed EC cache transfer."""
    # some opaque config, only used to provide additional information
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
    additional_config: dict | SupportsHash = Field(default_factory=dict)
    """Additional config for specified platform. Different platforms may
    support different configs. Make sure the configs are valid for the platform
    you are using. Contents must be hashable: a dict must be JSON-serializable,
    any other value must provide `compute_hash()`."""
    instance_id: str = ""
    """The ID of the vLLM instance."""
    optimization_level: OptimizationLevel = OptimizationLevel.O2
    """The optimization level. These levels trade startup time cost for
    performance, with -O0 having the best startup time and -O3 having the best
    performance. -O2 is used by default. See OptimizationLevel for full
    description."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Returns:
            The first 10 hex characters of the digest over all factors.
        """
        from vllm import __version__

        # Sub-configs that feed the hash, in a fixed order. Each contributes
        # exactly one factor (its own hash, or the "None" placeholder when
        # unset) so factor positions stay aligned between configs.
        # quant_config is intentionally absent: it is derived from
        # model_config.quantization, which model_config's hash covers.
        # NOTE(review): kv_events_config is not hashed either — presumably
        # because event publishing does not affect the graph; confirm.
        hashed_sub_configs = (
            self.model_config,
            self.cache_config,
            self.parallel_config,
            self.scheduler_config,
            self.device_config,
            self.load_config,
            self.attention_config,
            self.lora_config,
            self.speculative_config,
            self.structured_outputs_config,
            self.profiler_config,
            self.observability_config,
            self.compilation_config,
            self.kv_transfer_config,
            self.ec_transfer_config,
        )

        vllm_factors: list[Any] = [__version__]
        vllm_factors.extend(
            sub_config.compute_hash() if sub_config else "None"
            for sub_config in hashed_sub_configs
        )

        additional_config = self.additional_config
        if not additional_config:
            vllm_factors.append("None")
        elif isinstance(additional_config, dict):
            # sort_keys makes the hash independent of dict insertion order.
            vllm_factors.append(
                safe_hash(
                    json.dumps(additional_config, sort_keys=True).encode(),
                    usedforsecurity=False,
                ).hexdigest()
            )
        else:
            vllm_factors.append(additional_config.compute_hash())

        factors: list[Any] = [vllm_factors]
        digest = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return digest[:10]

    def pad_for_cudagraph(self, batch_size: int) -> int:
        """Return the cudagraph capture size that ``batch_size`` pads up to.

        The caller must make sure ``batch_size`` is within range, i.e.
        ``batch_size <= compilation_config.max_cudagraph_capture_size``;
        a larger value is expected to raise an ``IndexError`` from the
        lookup table.
        """
        return self.compilation_config.bs_to_padded_graph_size[batch_size]

    @property
    def needs_dp_coordinator(self) -> bool:
        """
        Whether a DPCoordinator process has to be launched.

        It is never needed without data parallelism (DP <= 1). With DP > 1
        it is needed when either:
        1. the model is MoE, or no model config is set: wave coordination
           runs in the coordinator, even in external LB mode; or
        2. load balancing is internal/hybrid (not external): the coordinator
           collects and publishes queue stats for balancing across DP ranks.

        Returns:
            True if DPCoordinator process is needed, False otherwise.
        """
        if self.parallel_config.data_parallel_size <= 1:
            return False
        if self.model_config is None:
            return True
        # MoE models always need it; non-MoE models only in internal/hybrid
        # LB mode, for stats collection.
        return (
            self.model_config.is_moe
            or not self.parallel_config.data_parallel_external_lb
        )

    def enable_trace_function_call_for_thread(self) -> None:
        """
        Turn on function-call tracing for the calling thread.

        Does nothing unless the `VLLM_TRACE_FUNCTION` environment variable is
        enabled. The log goes under
        ``<tempdir>/<user>/vllm/vllm-instance-<instance_id>/``, named after
        the process id, thread id and current timestamp (spaces replaced by
        underscores).
        """
        if not envs.VLLM_TRACE_FUNCTION:
            return

        # A per-user subdirectory of the temp dir avoids permission issues.
        user_tmp_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())
        log_name = (
            f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
            f"_thread_{threading.get_ident()}_at_{datetime.now()}.log"
        ).replace(" ", "_")
        log_dir = os.path.join(
            user_tmp_dir, "vllm", f"vllm-instance-{self.instance_id}"
        )
        os.makedirs(log_dir, exist_ok=True)
        enable_trace_function_call(os.path.join(log_dir, log_name))

    @staticmethod
    def _get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> QuantizationConfig | None:
        """Build and validate the quantization config for ``model_config``.

        NOTE: this may mutate ``model_config``; ``get_quantization_config``
        is the variant that works on a deep copy.

        Returns:
            ``None`` when ``model_config.quantization`` is unset, otherwise
            the validated quantization config.

        Raises:
            ValueError: if the device capability (when known) is below the
                method's minimum, or if ``model_config.dtype`` is not one of
                the method's supported activation dtypes.
        """
        from vllm.platforms import current_platform

        if model_config.quantization is None:
            return None

        from vllm.model_executor.model_loader.weight_utils import get_quant_config

        quant_config = get_quant_config(model_config, load_config)

        # The platform may not report a capability; skip the check then.
        capability_tuple = current_platform.get_device_capability()
        if capability_tuple is not None:
            capability = capability_tuple.to_int()
            min_capability = quant_config.get_min_capability()
            if capability < min_capability:
                raise ValueError(
                    f"The quantization method {model_config.quantization} "
                    "is not supported for the current GPU. Minimum "
                    f"capability: {min_capability}. "
                    f"Current capability: {capability}."
                )

        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}"
            )

        quant_config.maybe_update_config(model_config.model)
        return quant_config

    @staticmethod
    def get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> QuantizationConfig | None:
        """Return the quantization config without touching ``model_config``.

        ``_get_quantization_config`` can modify the model config it is given,
        so it is run on a deep copy and the caller's object stays unchanged.
        """
        return VllmConfig._get_quantization_config(
            copy.deepcopy(model_config), load_config
        )

    def with_hf_config(
        self,
        hf_config: PretrainedConfig,
        architectures: list[str] | None = None,
    ) -> "VllmConfig":
        """Return a copy of this config whose model uses ``hf_config``.

        Args:
            hf_config: HuggingFace config to install on a deep copy of
                ``self.model_config``.
            architectures: If given, overrides ``hf_config.architectures``.
                ``hf_config`` is deep-copied first so the caller's object is
                not modified.

        Returns:
            A new ``VllmConfig`` created via ``dataclasses.replace``; ``self``
            and its ``model_config`` are left untouched.
        """
        if architectures is not None:
            hf_config = copy.deepcopy(hf_config)
            hf_config.architectures = architectures

        model_config = copy.deepcopy(self.model_config)
        model_config.hf_config = hf_config
        # Refresh model_arch_config now that hf_config has changed.
        model_config.model_arch_config = model_config.get_model_arch_config()

        return replace(self, model_config=model_config)

    def _set_config_default(self, config_obj: Any, key: str, value: Any) -> None:
        """Give ``config_obj.<key>`` a default if it is still ``None``.

        Attributes that already hold a non-``None`` value (e.g. chosen by the
        user) are left untouched.

        Args:
            config_obj: Configuration object to update.
            key: Attribute name.
            value: Default to apply. A callable is treated as a factory and
                called with this ``VllmConfig``; anything else is used as is.
        """
        if getattr(config_obj, key) is not None:
            return
        # Static defaults are known up front; callable ones depend on the rest
        # of the user's configuration and are resolved now, at run time.
        resolved = value(self) if callable(value) else value
        setattr(config_obj, key, resolved)

    def _apply_optimization_level_defaults(self, defaults: dict[str, Any]) -> None:
        """Fill still-unset config fields from an optimization-level preset.

        ``defaults`` mirrors the attribute tree rooted at this ``VllmConfig``.
        A nested dict is descended into when the matching attribute is a
        dataclass. Any other entry is passed to ``_set_config_default``, which
        writes the field only if it is still ``None``. Fields that already
        hold a value (e.g. specified by the user) are therefore not
        overridden. Keys with no matching attribute are skipped.

        Args:
            defaults: Dictionary of default values to apply.
        """

        def fill(target: Any, level_defaults: dict[str, Any]) -> None:
            for name, default in level_defaults.items():
                if not hasattr(target, name):
                    continue
                child = getattr(target, name)
                if is_dataclass(child) and isinstance(default, dict):
                    fill(child, default)
                else:
                    self._set_config_default(target, name, default)

        fill(self, defaults)

    def _post_init_kv_transfer_config(self) -> None:
        """Update KVTransferConfig based on top-level configs in VllmConfig.

        Right now, this function reads the offloading settings from
        CacheConfig and configures the KVTransferConfig accordingly. Nothing
        is changed unless ``cache_config.kv_offloading_size`` is set.
        """
        # KV offloading is only activated when kv_offloading_size is set.
        if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
            return

        kv_offloading_backend = self.cache_config.kv_offloading_backend

        # If no KVTransferConfig is provided, create a default one.
        if self.kv_transfer_config is None:
            self.kv_transfer_config = KVTransferConfig()

        if kv_offloading_backend == "native":
            self.kv_transfer_config.kv_connector = "OffloadingConnector"
            # kv_offloading_size is scaled by 2**30 here, i.e. treated as GiB
            # and converted to bytes.
            self.kv_transfer_config.kv_connector_extra_config.update(
                {"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
            )
        elif kv_offloading_backend == "lmcache":
            self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
            # The total size is split evenly across all TP x PP ranks.
            num_kv_ranks = (
                self.parallel_config.tensor_parallel_size
                * self.parallel_config.pipeline_parallel_size
            )
            kv_gb_per_rank = kv_offloading_size / num_kv_ranks
            # NOTE(review): unlike the native branch, this replaces (rather
            # than merges into) any user-supplied kv_connector_extra_config —
            # confirm this is intended.
            self.kv_transfer_config.kv_connector_extra_config = {
                "lmcache.local_cpu": True,
                "lmcache.max_local_cpu_size": kv_gb_per_rank,
            }

        # This is the same for all backends
        self.kv_transfer_config.kv_role = "kv_both"

    def __post_init__(self):
532
        """Verify configs are valid & consistent with each other."""
533

534
535
536
        # To give each torch profile run a unique instance name.
        self.instance_id = f"{time.time_ns()}"

537
538
539
540
        self.try_verify_and_update_config()

        if self.model_config is not None:
            self.model_config.verify_with_parallel_config(self.parallel_config)
541
            self.model_config.verify_dual_chunk_attention_config(self.load_config)
542

543
544
            self.parallel_config.is_moe_model = self.model_config.is_moe

545
546
547
548
549
550
551
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config is not None:
            self.lora_config.verify_with_model_config(self.model_config)

        if self.quant_config is None and self.model_config is not None:
            self.quant_config = VllmConfig._get_quantization_config(
552
553
                self.model_config, self.load_config
            )
554

555
556
557
558
559
560
561
562
563
        executor_backend = self.parallel_config.distributed_executor_backend
        executor_supports_async_sched = executor_backend in (
            "mp",
            "uni",
            "external_launcher",
        )

        if self.scheduler_config.async_scheduling:
            # Async scheduling explicitly enabled, hard fail any incompatibilities.
564
565
            # Currently, async scheduling only support eagle speculative
            # decoding.
566
            if self.speculative_config is not None:
567
568
569
                if self.speculative_config.method not in get_args(EagleModelTypes):
                    raise ValueError(
                        "Currently, async scheduling is only supported "
570
                        "with EAGLE/MTP kind of speculative decoding."
571
572
573
                    )
                if self.speculative_config.disable_padded_drafter_batch:
                    raise ValueError(
574
575
                        "Async scheduling is not compatible with "
                        "disable_padded_drafter_batch=True."
576
                    )
577
578
579
580
581
582
583
584
            if not executor_supports_async_sched:
                raise ValueError(
                    "Currently, async scheduling only supports `mp`, `uni`, or "
                    "`external_launcher` distributed executor backend, but you chose "
                    f"`{executor_backend}`."
                )
        elif self.scheduler_config.async_scheduling is None:
            # Enable async scheduling unless there is an incompatible option.
585
            if (
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
                self.speculative_config is not None
                and self.speculative_config.method not in get_args(EagleModelTypes)
            ):
                logger.warning_once(
                    "Async scheduling not supported with %s-based "
                    "speculative decoding and will be disabled.",
                    self.speculative_config.method,
                    scope="local",
                )
                self.scheduler_config.async_scheduling = False
            elif (
                self.speculative_config is not None
                and self.speculative_config.disable_padded_drafter_batch
            ):
                logger.warning_once(
                    "Async scheduling is not compatible with "
                    "disable_padded_drafter_batch=True and will be disabled.",
                    scope="local",
                )
605
                self.scheduler_config.async_scheduling = False
606
            elif not executor_supports_async_sched:
607
                logger.warning_once(
608
609
610
611
                    "Async scheduling will be disabled because it is not supported "
                    "with the `%s` distributed executor backend (only `mp`, `uni`, and "
                    "`external_launcher` are supported).",
                    executor_backend,
612
                    scope="local",
613
614
615
616
617
                )
                self.scheduler_config.async_scheduling = False
            else:
                self.scheduler_config.async_scheduling = True

618
619
620
621
622
        logger.info_once(
            "Asynchronous scheduling is %s.",
            "enabled" if self.scheduler_config.async_scheduling else "disabled",
        )

623
624
        if self.parallel_config.disable_nccl_for_dp_synchronization is None:
            if self.scheduler_config.async_scheduling:
625
626
627
628
629
630
631
632
                if self.parallel_config.data_parallel_size > 1 and (
                    self.model_config is None or self.model_config.is_moe
                ):
                    logger.info_once(
                        "Disabling NCCL for DP synchronization "
                        "when using async scheduling.",
                        scope="local",
                    )
633
634
635
636
                self.parallel_config.disable_nccl_for_dp_synchronization = True
            else:
                self.parallel_config.disable_nccl_for_dp_synchronization = False

637
        from vllm.platforms import current_platform
638
639
640

        if (
            self.model_config is not None
641
            and self.scheduler_config.enable_chunked_prefill
642
643
644
            and self.model_config.dtype == torch.float32
            and current_platform.get_device_capability() == (7, 5)
        ):
645
646
647
            logger.warning_once(
                "Turing devices tensor cores do not support float32 matmul. "
                "To workaround this limitation, vLLM will set 'ieee' input "
648
649
                "precision for chunked prefill triton kernels."
            )
650

651
652
653
654
655
656
657
658
659
660
661
662
663
        if (
            self.optimization_level > OptimizationLevel.O0
            and self.model_config is not None
            and self.model_config.enforce_eager
        ):
            logger.warning("Enforce eager set, overriding optimization level to -O0")
            self.optimization_level = OptimizationLevel.O0

        if self.compilation_config.backend == "eager" or (
            self.compilation_config.mode is not None
            and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
        ):
            logger.warning(
664
665
666
                "Inductor compilation was disabled by user settings, "
                "optimizations settings that are only active during "
                "inductor compilation will be ignored."
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
            )

        def has_blocked_weights():
            if self.quant_config is not None:
                if hasattr(self.quant_config, "weight_block_size"):
                    return self.quant_config.weight_block_size is not None
                elif hasattr(self.quant_config, "has_blocked_weights"):
                    return self.quant_config.has_blocked_weights()
            return False

        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
        # On H100 the CUDA kernel is faster than
        # native implementation
        # https://github.com/vllm-project/vllm/issues/25094
        if has_blocked_weights():
            custom_ops = self.compilation_config.custom_ops
            if "-quant_fp8" not in custom_ops:
                custom_ops.append("+quant_fp8")

686
        if self.compilation_config.mode is None:
687
            if self.optimization_level > OptimizationLevel.O0:
688
                self.compilation_config.mode = CompilationMode.VLLM_COMPILE
689
            else:
690
                self.compilation_config.mode = CompilationMode.NONE
691
692
693
694

        if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
            if (
                self.compilation_config.backend == "inductor"
695
                and self.compilation_config.mode != CompilationMode.NONE
696
697
698
699
            ):
                self.compilation_config.custom_ops.append("none")
            else:
                self.compilation_config.custom_ops.append("all")
700

701
702
        default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
        self._apply_optimization_level_defaults(default_config)
703

704
        if (
705
            self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
706
707
708
709
710
711
712
713
714
715
            and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
        ):
            logger.info(
                "Cudagraph mode %s is not compatible with compilation mode %s."
                "Overriding to NONE.",
                self.compilation_config.cudagraph_mode,
                self.compilation_config.mode,
            )
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

716
717
        # async tp is built on top of sequence parallelism
        # and requires it to be enabled.
718
719
720
        if self.compilation_config.pass_config.fuse_gemm_comms:
            self.compilation_config.pass_config.enable_sp = True
        if self.compilation_config.pass_config.enable_sp:
721
722
723
724
725
726
            if "-rms_norm" in self.compilation_config.custom_ops:
                logger.warning(
                    "RMS norm force disabled, sequence parallelism might break"
                )
            else:
                self.compilation_config.custom_ops.append("+rms_norm")
727
728

        if current_platform.support_static_graph_mode():
729
            # if cudagraph_mode has full cudagraphs, we need to check support
730
731
732
733
734
            if model_config := self.model_config:
                if (
                    self.compilation_config.cudagraph_mode.has_full_cudagraphs()
                    and model_config.pooler_config is not None
                ):
735
                    logger.warning_once(
736
                        "Pooling models do not support full cudagraphs. "
737
738
739
                        "Overriding cudagraph_mode to PIECEWISE."
                    )
                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
740
741
742
743
744
745
746
747
748
749
750
751
                elif (
                    model_config.is_encoder_decoder
                    and self.compilation_config.cudagraph_mode
                    not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
                ):
                    logger.info_once(
                        "Encoder-decoder models do not support %s. "
                        "Overriding cudagraph_mode to FULL_DECODE_ONLY.",
                        self.compilation_config.cudagraph_mode.name,
                    )
                    self.compilation_config.cudagraph_mode = (
                        CUDAGraphMode.FULL_DECODE_ONLY
752
                    )
753
754

            # disable cudagraph when enforce eager execution
755
            if self.model_config is not None and self.model_config.enforce_eager:
756
757
                logger.info("Cudagraph is disabled under eager mode")
                self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
758
759
760
                # override related settings when enforce eager
                self.compilation_config.max_cudagraph_capture_size = 0
                self.compilation_config.cudagraph_capture_sizes = []
761
            else:
762
763
764
765
766
767
768
                self.compilation_config.cudagraph_num_of_warmups = 1

            self._set_cudagraph_sizes()
        else:
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if self.cache_config.kv_sharing_fast_prefill:
769
770
771
772
            if (
                self.speculative_config is not None
                and self.speculative_config.use_eagle()
            ):
773
                raise ValueError(
774
775
776
                    "Fast prefill optimization for KV sharing is not "
                    "compatible with EAGLE as EAGLE requires correct logits "
                    "for all tokens while fast prefill gives incorrect logits "
777
778
                    "for prompt tokens."
                )
779
780
781

            logger.warning_once(
                "--kv-sharing-fast-prefill requires changes on model side for "
782
                "correctness and to realize prefill savings."
783
            )
784
785
        # TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
        self._set_compile_ranges()
786

787
788
789
790
791
792
793
794
795
796
        if (
            self.model_config
            and self.model_config.architecture == "WhisperForConditionalGeneration"
            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
        ):
            logger.warning(
                "Whisper is known to have issues with "
                "forked workers. If startup is hanging, "
                "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                "to 'spawn'."
797
            )
798

799
800
801
802
803
        if (
            self.kv_events_config is not None
            and self.kv_events_config.enable_kv_cache_events
            and not self.cache_config.enable_prefix_caching
        ):
804
            logger.warning(
805
                "KV cache events are on, but prefix caching is not enabled. "
806
807
808
809
810
811
812
813
                "Use --enable-prefix-caching to enable."
            )
        if (
            self.kv_events_config is not None
            and self.kv_events_config.publisher != "null"
            and not self.kv_events_config.enable_kv_cache_events
        ):
            logger.warning(
814
815
816
                "KV cache events are disabled, "
                "but the scheduler is configured to publish them. "
                "Modify KVEventsConfig.enable_kv_cache_events "
817
818
                "to True to enable."
            )
819
820
        current_platform.check_and_update_config(self)

821
822
        # If DCP, ensure the block size is right.
        if self.parallel_config.decode_context_parallel_size > 1:
823
824
825
826
827
828
829
830
831
832
833
834
            if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
                self.parallel_config.cp_kv_cache_interleave_size
                != self.parallel_config.dcp_kv_cache_interleave_size
            ):
                self.parallel_config.cp_kv_cache_interleave_size = (
                    self.parallel_config.dcp_kv_cache_interleave_size
                )
                logger.warning_once(
                    "cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
                    "_interleave_size. And dcp-kv-cache-interleave-size will be "
                    "deprecated when PCP is fully supported."
                )
835
            assert (
836
                self.parallel_config.cp_kv_cache_interleave_size
837
838
                <= self.cache_config.block_size
                and self.cache_config.block_size
839
                % self.parallel_config.cp_kv_cache_interleave_size
840
841
842
                == 0
            ), (
                f"Block_size({self.cache_config.block_size}) should be greater "
843
844
                "than or equal to and divisible by cp_kv_cache_interleave_size "
                f"({self.parallel_config.cp_kv_cache_interleave_size})."
845
            )
846

847
        # Do this after all the updates to compilation_config.mode
848
849
850
851
852
        effective_dp_size = (
            self.parallel_config.data_parallel_size
            if self.model_config is None or self.model_config.is_moe
            else 1
        )
853
854
        self.compilation_config.set_splitting_ops_for_v1(
            all2all_backend=self.parallel_config.all2all_backend,
855
            data_parallel_size=effective_dp_size,
856
        )
857

858
        if self.compilation_config.pass_config.enable_sp:
859
860
861
862
863
            # With pipeline parallelism or dynamo partitioning,
            # native rms norm tracing errors due to incorrect residual shape.
            # Use custom rms norm to unblock. In the future,
            # the pass will operate on higher-level IR to avoid the issue.
            # TODO: https://github.com/vllm-project/vllm/issues/27894
864
865
866
867
868
869
870
            if self.compilation_config.mode != CompilationMode.VLLM_COMPILE:
                logger.warning(
                    "Sequence parallelism is enabled, but running in wrong "
                    "vllm compile mode: %s.",
                    self.compilation_config.mode,
                )

871
872
873
874
875
876
877
878
879
880
881
882
883
884
            is_fullgraph = (
                self.compilation_config.use_inductor_graph_partition
                or len(self.compilation_config.splitting_ops) == 0
            )
            if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
                if "-rms_norm" not in self.compilation_config.custom_ops:
                    self.compilation_config.custom_ops.append("+rms_norm")
                else:
                    regime = (
                        "Dynamo partition"
                        if not is_fullgraph
                        else "pipeline parallelism"
                    )
                    logger.warning_once(
885
                        "Sequence parallelism not supported with "
886
887
888
889
890
                        "native rms_norm when using %s, "
                        "this will likely lead to an error.",
                        regime,
                    )

891
        # final check of cudagraph mode after all possible updates
892
        if current_platform.is_cuda_alike():
893
894
895
896
            if (
                self.compilation_config.cudagraph_mode.has_full_cudagraphs()
                and self.model_config is not None
                and not self.model_config.disable_cascade_attn
897
                and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()  # noqa: E501
898
            ):
899
900
901
                logger.warning_once(
                    "No piecewise cudagraph for executing cascade attention."
                    " Will fall back to eager execution if a batch runs "
902
                    "into cascade attentions."
903
904
905
                )

            if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
906
907
                assert self.compilation_config.mode == CompilationMode.VLLM_COMPILE, (
                    "Compilation mode should be CompilationMode.VLLM_COMPILE "
908
                    "when cudagraph_mode piecewise cudagraphs is used, "
909
                    f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
910
                )
911

912
        if self.parallel_config.use_ubatching:
913
            a2a_backend = self.parallel_config.all2all_backend
914
915
916
917
            assert a2a_backend in [
                "deepep_low_latency",
                "deepep_high_throughput",
            ], (
918
919
                "Microbatching currently only supports the deepep_low_latency and "
                f"deepep_high_throughput all2all backend. {a2a_backend} is not "
920
921
922
                "supported. To fix use --all2all-backend=deepep_low_latency or "
                "--all2all-backend=deepep_high_throughput and install the DeepEP"
                " kernels."
923
            )
924
925
926

            if not self.model_config.disable_cascade_attn:
                self.model_config.disable_cascade_attn = True
927
                logger.warning_once("Disabling cascade attention when DBO is enabled.")
928
929
930
931

        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
        # Hybrid KV cache manager (HMA) runtime rules:
        # - Explicit enable (--no-disable-kv-cache-manager): error if runtime
        #   disables it
        # - No preference: auto-disable for unsupported features (e.g. kv connector)
        # - Explicit disable (--disable-kv-cache-manager): always respect it
        need_disable_hybrid_kv_cache_manager = False
        # logger should only print warning message for hybrid models. As we
        # can't know whether the model is hybrid or not now, so we don't log
        # warning message here and will log it later.
        if not current_platform.support_hybrid_kv_cache():
            # Hybrid KV cache manager is not supported on non-GPU platforms.
            need_disable_hybrid_kv_cache_manager = True
        if self.kv_events_config is not None:
            # Hybrid KV cache manager is not compatible with KV events.
            need_disable_hybrid_kv_cache_manager = True
        if (
            self.model_config is not None
            and self.model_config.attention_chunk_size is not None
        ):
            if (
                self.speculative_config is not None
                and self.speculative_config.use_eagle()
            ):
                # Hybrid KV cache manager is not yet supported with chunked
                # local attention + eagle.
                need_disable_hybrid_kv_cache_manager = True
            elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
                logger.warning(
                    "There is a latency regression when using chunked local"
                    " attention with the hybrid KV cache manager. Disabling"
                    " it, by default. To enable it, set the environment "
                    "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
                )
                # Hybrid KV cache manager is not yet supported with chunked
                # local attention.
                need_disable_hybrid_kv_cache_manager = True

        if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
            # Default to disable HMA, but only if the user didn't express a preference.
971
            if self.kv_transfer_config is not None:
972
973
                # NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
                need_disable_hybrid_kv_cache_manager = True
974
975
976
977
978
979
980
                logger.warning(
                    "Turning off hybrid kv cache manager because "
                    "`--kv-transfer-config` is set. This will reduce the "
                    "performance of vLLM on LLMs with sliding window attention "
                    "or Mamba attention. If you are a developer of kv connector"
                    ", please consider supporting hybrid kv cache manager for "
                    "your connector by making sure your connector is a subclass"
981
982
                    " of `SupportsHMA` defined in kv_connector/v1/base.py and"
                    " use --no-disable-hybrid-kv-cache-manager to start vLLM."
983
                )
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
            self.scheduler_config.disable_hybrid_kv_cache_manager = (
                need_disable_hybrid_kv_cache_manager
            )
        elif (
            self.scheduler_config.disable_hybrid_kv_cache_manager is False
            and need_disable_hybrid_kv_cache_manager
        ):
            raise ValueError(
                "Hybrid KV cache manager was explicitly enabled but is not "
                "supported in this configuration. Consider omitting the "
                "--no-disable-hybrid-kv-cache-manager flag to let vLLM decide"
                " automatically."
            )

        if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
            # Default to enable HMA if not explicitly disabled by user or logic above.
            self.scheduler_config.disable_hybrid_kv_cache_manager = False
1001

1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
        if self.cache_config.mamba_cache_mode == "align":
            if self.scheduler_config.long_prefill_token_threshold > 0:
                assert (
                    self.scheduler_config.long_prefill_token_threshold
                    >= self.cache_config.block_size
                )
            assert not self.scheduler_config.disable_chunked_mm_input, (
                "Chunked MM input is required because we need the flexibility to "
                "schedule a multiple of block_size tokens even if they are in the "
                "middle of a mm input"
            )
1013
        if self.compilation_config.debug_dump_path:
1014
            self.compilation_config.debug_dump_path = (
1015
                self.compilation_config.debug_dump_path.absolute().expanduser()
1016
            )
1017
1018
1019
1020
1021
        if envs.VLLM_DEBUG_DUMP_PATH is not None:
            env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser()
            if self.compilation_config.debug_dump_path:
                logger.warning(
                    "Config-specified debug dump path is overridden"
1022
1023
1024
                    " by VLLM_DEBUG_DUMP_PATH to %s",
                    env_path,
                )
1025
1026
            self.compilation_config.debug_dump_path = env_path

1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
        def has_blocked_weights():
            if self.quant_config is not None:
                if hasattr(self.quant_config, "weight_block_size"):
                    return self.quant_config.weight_block_size is not None
                elif hasattr(self.quant_config, "has_blocked_weights"):
                    return self.quant_config.has_blocked_weights()
            return False

        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
        # On H100 the CUDA kernel is faster than
        # native implementation
        # https://github.com/vllm-project/vllm/issues/25094
        if has_blocked_weights():
            custom_ops = self.compilation_config.custom_ops
1041
            if "-quant_fp8" not in custom_ops:
1042
1043
                custom_ops.append("+quant_fp8")

1044
1045
1046
        # Handle the KV connector configs
        self._post_init_kv_transfer_config()

1047
    def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
        """Keep only the sizes that are multiples of the tensor-parallel size.

        Intended for use when sequence parallelism is enabled. Sizes that are
        not divisible by ``parallel_config.tensor_parallel_size`` are dropped
        and reported in a single warning.

        Args:
            possible_sizes: Candidate batch sizes.

        Returns:
            A new list with the divisible sizes, in their original order.
        """
        tp_size = self.parallel_config.tensor_parallel_size
        kept_sizes: list = []
        dropped_sizes: list = []
        # Single pass: partition the candidates instead of scanning twice.
        for candidate in possible_sizes:
            if candidate % tp_size != 0:
                dropped_sizes.append(candidate)
            else:
                kept_sizes.append(candidate)

        if dropped_sizes:
            logger.warning(
                "Batch sizes %s are removed because they are not "
                "multiple of tp_size %d when "
                "sequence parallelism is enabled",
                dropped_sizes,
                tp_size,
            )

        return kept_sizes

    def _set_cudagraph_sizes(self):
        """
        vLLM defines the default candidate list of batch sizes for CUDA graph
        capture as:

        ```python
        max_graph_size = min(max_num_seqs * 2, 512)
1077
1078
        # 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
        # up to max_graph_size
1079
        cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
1080
            range(256, max_graph_size + 1, 16))
1081
1082

        In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
1083
        will be the final sizes to capture cudagraph (in ascending order).
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109

        These sizes are used to capture and reuse CUDA graphs for
        performance-critical paths (e.g., decoding). Capturing enables
        significantly faster kernel dispatch by avoiding Python overhead. The
        list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
        most GPUs), which controls the total allowed number of tokens in a
        batch. Since each sequence may have a variable number of tokens, the
        maximum usable batch size will depend on actual sequence lengths.

        Example:
            With `max_num_batched_tokens = 8192`, and typical sequences
            averaging ~32 tokens, most practical batch sizes fall below 256.
            However, the system will still allow capture sizes up to 512 if
            shape and memory permit.

        Note:
            If users explicitly specify cudagraph capture sizes in the
            compilation config, those will override this default logic.
            At runtime:

            - If batch size <= one of the `cudagraph_capture_sizes`, the closest
            padded CUDA graph will be used.
            - If batch size > largest `cudagraph_capture_sizes`, cudagraph will
            not be used.
        """

1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
        if (
            self.model_config is not None
            and not self.model_config.enforce_eager
            and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
            # determine the initial max_cudagraph_capture_size
            max_cudagraph_capture_size = (
                self.compilation_config.max_cudagraph_capture_size
            )
            if max_cudagraph_capture_size is None:
1120
1121
1122
1123
1124
1125
                decode_query_len = 1
                if (
                    self.speculative_config
                    and self.speculative_config.num_speculative_tokens
                ):
                    decode_query_len += self.speculative_config.num_speculative_tokens
1126
                max_cudagraph_capture_size = min(
1127
                    self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
1128
                )
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
            max_num_tokens = self.scheduler_config.max_num_batched_tokens
            max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)

            assert max_cudagraph_capture_size >= 1, (
                "Maximum cudagraph size should be greater than or equal to 1 "
                "when using cuda graph."
            )

            # determine the cudagraph_capture_sizes
            if self.compilation_config.cudagraph_capture_sizes is not None:
                assert len(self.compilation_config.cudagraph_capture_sizes) > 0, (
                    "cudagraph_capture_sizes should contain at least one element "
                    "when using cuda graph."
                )
                # de-duplicate the sizes provided by the config
                dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
1145
1146
1147
                cudagraph_capture_sizes = [
                    i for i in dedup_sizes if i <= max_num_tokens
                ]
1148
1149
                # sort to make sure the sizes are in ascending order
                cudagraph_capture_sizes.sort()
1150
            else:
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
                cudagraph_capture_sizes = [
                    i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
                ]
                if max_cudagraph_capture_size >= 8:
                    # Step size 8 for small batch sizes, up to 256(not included)
                    cudagraph_capture_sizes += list(
                        range(8, min(max_cudagraph_capture_size + 1, 256), 8)
                    )
                if max_cudagraph_capture_size >= 256:
                    # Step size 16 for larger batch sizes
                    cudagraph_capture_sizes += list(
                        range(256, max_cudagraph_capture_size + 1, 16)
                    )

1165
1166
            if (
                self.parallel_config.tensor_parallel_size > 1
1167
                and self.compilation_config.pass_config.enable_sp
1168
            ):
1169
1170
                cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
                    cudagraph_capture_sizes
1171
                )
1172

1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
            # user-specific compilation_config.max_cudagraph_capture_size get
            # truncated to valid_max_size when they are inconsistent.
            valid_max_size = (
                cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
            )
            if (
                self.compilation_config.max_cudagraph_capture_size is not None
                and self.compilation_config.max_cudagraph_capture_size != valid_max_size
            ):
                # raise error only when both two flags are user-specified
                # and they are inconsistent with each other
                if self.compilation_config.cudagraph_capture_sizes is not None:
                    raise ValueError(
                        "customized max_cudagraph_capture_size"
                        f"(={self.compilation_config.max_cudagraph_capture_size}) "
                        "should be consistent with the max value of "
                        f"cudagraph_capture_sizes(={valid_max_size})"
                    )

                logger.warning(
                    "Truncating max_cudagraph_capture_size to %d",
                    valid_max_size,
                )
            # always set the final max_cudagraph_capture_size
            self.compilation_config.max_cudagraph_capture_size = valid_max_size

            if self.compilation_config.cudagraph_capture_sizes is not None and len(
                cudagraph_capture_sizes
            ) < len(self.compilation_config.cudagraph_capture_sizes):
                # If users have specified capture sizes, we only need to
                # compare the lens before and after modification since the modified
                # list is only the subset of the original list.
                logger.warning(
                    (
                        "cudagraph_capture_sizes specified in compilation_config"
                        " %s is overridden by config %s"
                    ),
                    self.compilation_config.cudagraph_capture_sizes,
                    cudagraph_capture_sizes,
                )
            # always write back the final sizes
            self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes

        else:
            # no cudagraph in use
            self.compilation_config.max_cudagraph_capture_size = 0
            self.compilation_config.cudagraph_capture_sizes = []

        # complete the remaining process.
        self.compilation_config.post_init_cudagraph_sizes()
1223

1224
1225
1226
1227
1228
1229
1230
    def _set_compile_ranges(self):
        """
        Set the compile ranges for the compilation config.
        """
        compilation_config = self.compilation_config
        computed_compile_ranges_split_points = []

1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
        # The upper bound of the compile ranges is the max_num_batched_tokens.
        # For speculative decoding with draft model, the compile range must be extended
        # by 1 for each sequence.
        compile_range_end = self.scheduler_config.max_num_batched_tokens
        if compile_range_end is not None:
            do_extend: bool = (
                self.speculative_config is not None
                and self.speculative_config.uses_draft_model()
            )
            if do_extend:
                compile_range_end += self.scheduler_config.max_num_seqs

            computed_compile_ranges_split_points.append(compile_range_end)
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253

        # Add the compile ranges for flashinfer
        if compilation_config.pass_config.fuse_allreduce_rms:
            tp_size = self.parallel_config.tensor_parallel_size
            max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
            if max_size is not None:
                max_token_num = max_size // (
                    self.model_config.get_hidden_size()
                    * self.model_config.dtype.itemsize
                )
1254
                if compile_range_end is not None and max_token_num < compile_range_end:
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
                    computed_compile_ranges_split_points.append(max_token_num)
                else:
                    logger.debug(
                        "Max num batched tokens below allreduce-rms fusion threshold, "
                        "allreduce-rms fusion will be enabled for all num_tokens."
                    )

        if compilation_config.compile_ranges_split_points is not None:
            for x in compilation_config.compile_ranges_split_points:
                assert isinstance(x, int)
                assert x > 0, f"Invalid compile range split point: {x}"
1266
                if compile_range_end is not None and x < compile_range_end and x > 1:
1267
1268
1269
1270
1271
                    computed_compile_ranges_split_points.append(x)
        compilation_config.compile_ranges_split_points = sorted(
            computed_compile_ranges_split_points
        )

1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
    def try_verify_and_update_config(self):
        """Apply architecture-specific config hooks, at most once per model config.

        Returns immediately when there is no model config, or when
        ``model_config.config_updated`` is already truthy. Otherwise the flag
        is set, and, if the architecture is known, the following are applied
        in order:

        - the architecture's entry in ``MODELS_CONFIG_MAP``, if present;
        - ``HybridAttentionMambaModelConfig`` for hybrid models;
        - ``SequenceClassificationConfig`` when ``convert_type == "classify"``;
        - for model weights behind a Run:ai object URI, ``load_format`` is
          switched from ``"auto"`` to ``"runai_streamer"``.

        Raises:
            ValueError: if the weights are behind a Run:ai object URI and
                ``load_format`` is neither ``"auto"``, ``"runai_streamer"``
                nor ``"runai_streamer_sharded"``.
        """
        if self.model_config is None:
            return

        # Guard against applying the hooks more than once.
        already_updated = getattr(self.model_config, "config_updated", False)
        if already_updated:
            return
        self.model_config.config_updated = True

        architecture = self.model_config.architecture
        if architecture is None:
            return

        from vllm.model_executor.models.config import (
            MODELS_CONFIG_MAP,
            HybridAttentionMambaModelConfig,
        )

        arch_config = MODELS_CONFIG_MAP.get(architecture, None)
        if arch_config is not None:
            arch_config.verify_and_update_config(self)

        if self.model_config.is_hybrid:
            HybridAttentionMambaModelConfig.verify_and_update_config(self)

        if self.model_config.convert_type == "classify":
            # Maybe convert ForCausalLM into ForSequenceClassification model.
            from vllm.model_executor.models.adapters import SequenceClassificationConfig

            SequenceClassificationConfig.verify_and_update_config(self)

        loads_from_runai = hasattr(
            self.model_config, "model_weights"
        ) and is_runai_obj_uri(self.model_config.model_weights)
        if not loads_from_runai:
            return

        current_format = self.load_config.load_format
        if current_format == "auto":
            logger.info(
                "Detected Run:ai model config. "
                "Overriding `load_format` to 'runai_streamer'"
            )
            self.load_config.load_format = "runai_streamer"
        elif current_format not in ("runai_streamer", "runai_streamer_sharded"):
            raise ValueError(
                "To load a model from S3, 'load_format' "
                "must be 'runai_streamer' or 'runai_streamer_sharded', "
                f"but got '{current_format}'. "
                f"Model: {self.model_config.model}"
            )
1322

1323
    def compile_debug_dump_path(self) -> Path | None:
1324
        """Returns a rank-aware path for dumping
1325
1326
1327
1328
1329
        torch.compile debug information.
        """
        if self.compilation_config.debug_dump_path is None:
            return None
        tp_rank = self.parallel_config.rank
1330
1331
        dp_rank = self.parallel_config.data_parallel_index
        append_path = f"rank_{tp_rank}_dp_{dp_rank}"
1332
1333
1334
        path = self.compilation_config.debug_dump_path / append_path
        return path

1335
1336
1337
1338
1339
1340
1341
1342
    def replace(self, **kwargs):
        """Return a copy of this config with the given fields replaced.

        Delegates to ``dataclasses.replace``, which constructs a new instance
        through ``__init__()`` and therefore runs ``__post_init__()`` again, so
        the returned config is 'recomputed'. Source:
        https://docs.python.org/3/library/dataclasses.html#dataclasses.replace
        """
        updated_config = replace(self, **kwargs)
        return updated_config

1343
1344
1345
1346
    def __str__(self):
        return (
            f"model={self.model_config.model!r}, "
            f"speculative_config={self.speculative_config!r}, "
1347
1348
1349
            f"tokenizer={self.model_config.tokenizer!r}, "
            f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, "
            f"tokenizer_mode={self.model_config.tokenizer_mode}, "
1350
            f"revision={self.model_config.revision}, "
1351
            f"tokenizer_revision={self.model_config.tokenizer_revision}, "
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
            f"trust_remote_code={self.model_config.trust_remote_code}, "
            f"dtype={self.model_config.dtype}, "
            f"max_seq_len={self.model_config.max_model_len}, "
            f"download_dir={self.load_config.download_dir!r}, "
            f"load_format={self.load_config.load_format}, "
            f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, "  # noqa
            f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, "  # noqa
            f"data_parallel_size={self.parallel_config.data_parallel_size}, "  # noqa
            f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, "  # noqa
            f"quantization={self.model_config.quantization}, "
            f"enforce_eager={self.model_config.enforce_eager}, "
1363
            f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, "  # noqa
1364
1365
1366
1367
1368
1369
1370
            f"kv_cache_dtype={self.cache_config.cache_dtype}, "
            f"device_config={self.device_config.device}, "
            f"structured_outputs_config={self.structured_outputs_config!r}, "
            f"observability_config={self.observability_config!r}, "
            f"seed={self.model_config.seed}, "
            f"served_model_name={self.model_config.served_model_name}, "
            f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
1371
            f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, "  # noqa
1372
            f"pooler_config={self.model_config.pooler_config!r}, "
1373
1374
            f"compilation_config={self.compilation_config!r}"
        )
1375

1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
    @model_validator(mode="after")
    def validate_mamba_block_size(self) -> "VllmConfig":
        """Reject a custom mamba block size unless prefix caching is enabled.

        A ``mamba_block_size`` of ``None`` or equal to ``max_model_len`` is
        treated as "not set". Configs without a model config pass unchanged.

        Raises:
            ValueError: if a custom mamba block size is configured while
                ``enable_prefix_caching`` is off.
        """
        if self.model_config is None:
            return self
        block_size = self.cache_config.mamba_block_size
        uses_default_block_size = (
            block_size is None or block_size == self.model_config.max_model_len
        )
        if uses_default_block_size or self.cache_config.enable_prefix_caching:
            return self
        raise ValueError(
            "--mamba-block-size can only be set with --enable-prefix-caching"
        )

1390

1391
1392
# Module-level slots assigned by `set_current_vllm_config`: the prefix passed
# alongside the config, and the config itself.
_current_prefix: str | None = None
_current_vllm_config: VllmConfig | None = None
1393
1394
1395


@contextmanager
def set_current_vllm_config(
    vllm_config: VllmConfig, check_compile: bool = False, prefix: str | None = None
):
    """
    Temporarily set the current vLLM config.
    Used during model initialization.
    We save the current vLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the vLLM config to determine how to dispatch.

    The previously active config and prefix are restored on exit, whether the
    ``with`` body completes normally or raises. The cached compilation config
    is invalidated both on entry and on exit.

    Args:
        vllm_config: Config to make current for the duration of the context.
        check_compile: When True and the ``with`` body completes without an
            exception, call ``compilation_config.custom_op_log_check()`` and
            log a warning if ``CompilationMode.VLLM_COMPILE`` is enabled but
            ``compilation_counter.num_models_seen`` did not increase.
        prefix: Value made current alongside the config for the duration of
            the context.
    """
    global _current_vllm_config, _current_prefix
    old_vllm_config = _current_vllm_config
    old_prefix = _current_prefix
    # Function-local import; presumably deferred to avoid an import cycle at
    # module load. NOTE(review): confirm.
    from vllm.compilation.counter import compilation_counter

    # Snapshot so we can tell afterwards whether a compilable model was seen.
    num_models_seen = compilation_counter.num_models_seen
    try:
        # Clear the compilation config cache when context changes.
        # This is needed since the old config may have been accessed
        # and cached before the new config is set.
        get_cached_compilation_config.cache_clear()

        _current_vllm_config = vllm_config
        _current_prefix = prefix
        yield

        # Only reached when the ``with`` body exits normally. An exception
        # raised in the body is re-raised at ``yield`` and skips these checks;
        # the ``finally`` below still restores the previous state.
        if check_compile:
            vllm_config.compilation_config.custom_op_log_check()

            if (
                vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
                and compilation_counter.num_models_seen == num_models_seen
            ):
                # If the model supports compilation,
                # compilation_counter.num_models_seen should be increased
                # by at least 1.
                # If it is not increased, it means the model does not support
                # compilation (does not have @support_torch_compile decorator).
                logger.warning(
                    "`torch.compile` is turned on, but the model %s"
                    " does not support it. Please open an issue on GitHub"
                    " if you want it to be supported.",
                    vllm_config.model_config.model,
                )
    finally:
        _current_vllm_config = old_vllm_config
        _current_prefix = old_prefix
        # Clear the compilation config cache when context changes
        get_cached_compilation_config.cache_clear()


@lru_cache(maxsize=1)
def get_cached_compilation_config():
    """Return the compilation config of the current vLLM config, memoized.

    This avoids calling get_current_vllm_config() on every lookup. Because the
    result is cached, the cache has to be cleared whenever the current config
    changes. set_current_vllm_config() calls ``cache_clear()`` on entry and
    on exit for this reason.
    """
    active_config = get_current_vllm_config()
    return active_config.compilation_config


def get_current_vllm_config() -> VllmConfig:
    """Return the vLLM config installed by set_current_vllm_config().

    Raises:
        AssertionError: If no config is currently set. It is raised explicitly
            rather than via ``assert``, so the check is not stripped under
            ``python -O``.
    """
    if _current_vllm_config is None:
        raise AssertionError(
            "Current vLLM config is not set. This typically means "
            "get_current_vllm_config() was called outside of a "
            "set_current_vllm_config() context, or a CustomOp was instantiated "
            "at module import time or model forward time when config is not set. "
            "For tests that directly test custom ops/modules, use the "
            "'default_vllm_config' pytest fixture from tests/conftest.py."
        )
    return _current_vllm_config


def get_current_vllm_config_or_none() -> VllmConfig | None:
    """Return the current vLLM config, or ``None`` if none is set.

    Non-raising counterpart of get_current_vllm_config().
    """
    return _current_vllm_config


# Generic type variable; get_layers_from_vllm_config() uses it to type its result.
T = TypeVar("T")


def get_layers_from_vllm_config(
    vllm_config: VllmConfig,
    layer_type: type[T],
    layer_names: list[str] | None = None,
) -> dict[str, T]:
    """
    Get layers from the vLLM config.

    Args:
        vllm_config: The vLLM config.
        layer_type: The type of the layer to get.
        layer_names: The names of the layers to get. If None, return all layers.

    Returns:
        A dict mapping each requested layer name to its entry in
        ``vllm_config.compilation_config.static_forward_context``. Only
        entries that are instances of ``layer_type`` are kept. Order follows
        ``layer_names``, or the context's key order when it is None.

    Raises:
        KeyError: If a name in ``layer_names`` is not present in the static
            forward context.
    """
    # Fetch the context once and reuse it for both name listing and lookup.
    forward_context = vllm_config.compilation_config.static_forward_context
    # Snapshot the keys, the same as the former list(forward_context.keys()).
    names = list(forward_context) if layer_names is None else layer_names

    return {
        layer_name: forward_context[layer_name]
        for layer_name in names
        if isinstance(forward_context[layer_name], layer_type)
    }