vllm.py 56.5 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
5
import getpass
6
7
import json
import os
8
9
import tempfile
import threading
10
import time
11
from contextlib import contextmanager
12
from dataclasses import is_dataclass, replace
13
from datetime import datetime
14
from enum import IntEnum
15
16
from functools import lru_cache
from pathlib import Path
17
from typing import TYPE_CHECKING, Any, TypeVar, get_args
18
19

import torch
20
from pydantic import ConfigDict, Field, model_validator
21
22
23
from pydantic.dataclasses import dataclass

import vllm.envs as envs
24
from vllm.config.speculative import EagleModelTypes
25
from vllm.logger import enable_trace_function_call, init_logger
26
27
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
28
from vllm.utils.hashing import safe_hash
29
30

from .cache import CacheConfig
31
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
32
from .device import DeviceConfig
33
from .ec_transfer import ECTransferConfig
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .parallel import ParallelConfig
from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
from .utils import SupportsHash, config

if TYPE_CHECKING:
    from transformers import PretrainedConfig

49
    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
50
    from vllm.v1.kv_cache_interface import KVCacheConfig
51
52
53
54
55
else:
    PretrainedConfig = Any

    QuantizationConfig = Any

56
57
    KVCacheConfig = Any

58
59
60
logger = init_logger(__name__)


61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class OptimizationLevel(IntEnum):
    """Optimization level enum.

    Higher levels trade a longer startup time for better runtime performance.
    """

    O0 = 0
    """O0: No optimization. No compilation, no cudagraphs, no other
    optimization; vLLM just starts up immediately."""
    O1 = 1
    """O1: Quick optimizations. Dynamo+Inductor compilation and piecewise
    cudagraphs."""
    O2 = 2
    """O2: Full optimizations. Everything in -O1 plus full and piecewise
    cudagraphs."""
    O3 = 3
    """O3: Currently the same as -O2."""


# Model properties that some optimization-level defaults depend on. Both are
# pinned to False for now, so every optimization gated on them stays disabled
# in all configurations. The intended model-dependent form (once a
# model_config is available) is roughly:
#     IS_QUANTIZED = lambda c: c.model_config.is_quantized()
#     IS_DENSE = lambda c: not c.model_config.is_model_moe()
# See https://github.com/vllm-project/vllm/issues/25689.
IS_QUANTIZED: bool = False
IS_DENSE: bool = False


86
87
88
89
def enable_norm_fusion(cfg: "VllmConfig") -> bool:
    """Decide whether the norm + quant fusion pass should be enabled.

    Returns True when either the ``rms_norm`` or the ``quant_fp8`` custom op
    is enabled; otherwise the fusion is left to Inductor.
    """
    compilation = cfg.compilation_config
    return any(
        compilation.is_custom_op_enabled(op_name)
        for op_name in ("rms_norm", "quant_fp8")
    )


95
96
97
98
99
100
101
102
def enable_act_fusion(cfg: "VllmConfig") -> bool:
    """Decide whether the activation + quant fusion pass should be enabled.

    Returns True when either the ``silu_and_mul`` or the ``quant_fp8`` custom
    op is enabled; otherwise the fusion is left to Inductor.
    """
    compilation = cfg.compilation_config
    return any(
        compilation.is_custom_op_enabled(op_name)
        for op_name in ("silu_and_mul", "quant_fp8")
    )


103
104
105
def _make_optimization_level_config(
    *,
    eliminate_noops: bool,
    fuse_norm_quant: Any,
    fuse_act_quant: Any,
    fuse_attn_quant: bool,
    enable_sp: bool,
    fuse_gemm_comms: bool,
    cudagraph_mode: CUDAGraphMode,
) -> dict[str, Any]:
    """Build a fresh defaults table for one optimization level.

    The nested layout mirrors the attribute tree of VllmConfig
    (``compilation_config.pass_config.<flag>``). Leaf values may be static
    values or callables taking the VllmConfig; callables are resolved when the
    defaults are applied. Settings shared by every level (no allreduce+RMS
    fusion, no Inductor graph partition) are fixed here.
    """
    return {
        "compilation_config": {
            "pass_config": {
                "eliminate_noops": eliminate_noops,
                "fuse_norm_quant": fuse_norm_quant,
                "fuse_act_quant": fuse_act_quant,
                "fuse_allreduce_rms": False,
                "fuse_attn_quant": fuse_attn_quant,
                "enable_sp": enable_sp,
                "fuse_gemm_comms": fuse_gemm_comms,
            },
            "cudagraph_mode": cudagraph_mode,
            "use_inductor_graph_partition": False,
        },
    }


# -O0: everything off, no cudagraphs.
OPTIMIZATION_LEVEL_00 = _make_optimization_level_config(
    eliminate_noops=False,
    fuse_norm_quant=False,
    fuse_act_quant=False,
    fuse_attn_quant=False,
    enable_sp=False,
    fuse_gemm_comms=False,
    cudagraph_mode=CUDAGraphMode.NONE,
)
# -O1: cheap passes plus custom-op-dependent fusions, piecewise cudagraphs.
OPTIMIZATION_LEVEL_01 = _make_optimization_level_config(
    eliminate_noops=True,
    fuse_norm_quant=enable_norm_fusion,
    fuse_act_quant=enable_act_fusion,
    fuse_attn_quant=False,
    enable_sp=False,
    fuse_gemm_comms=False,
    cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
# -O2: -O1 plus model-property-gated passes, full and piecewise cudagraphs.
OPTIMIZATION_LEVEL_02 = _make_optimization_level_config(
    eliminate_noops=True,
    fuse_norm_quant=enable_norm_fusion,
    fuse_act_quant=enable_act_fusion,
    fuse_attn_quant=IS_QUANTIZED,
    enable_sp=IS_DENSE,
    fuse_gemm_comms=IS_DENSE,
    cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
)
# -O3: currently identical to -O2 (built as a separate dict object).
OPTIMIZATION_LEVEL_03 = _make_optimization_level_config(
    eliminate_noops=True,
    fuse_norm_quant=enable_norm_fusion,
    fuse_act_quant=enable_act_fusion,
    fuse_attn_quant=IS_QUANTIZED,
    enable_sp=IS_DENSE,
    fuse_gemm_comms=IS_DENSE,
    cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
)

OPTIMIZATION_LEVEL_TO_CONFIG = {
    OptimizationLevel.O0: OPTIMIZATION_LEVEL_00,
    OptimizationLevel.O1: OPTIMIZATION_LEVEL_01,
    OptimizationLevel.O2: OPTIMIZATION_LEVEL_02,
    OptimizationLevel.O3: OPTIMIZATION_LEVEL_03,
}


172
173
174
175
176
177
178
179
180
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
    """Dataclass which contains all vllm-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    # TODO: use default_factory once default constructing ModelConfig doesn't
    # try to download a model
    model_config: ModelConfig = Field(default=None)
    """Model configuration."""
    cache_config: CacheConfig = Field(default_factory=CacheConfig)
    """Cache configuration."""
    parallel_config: ParallelConfig = Field(default_factory=ParallelConfig)
    """Parallel configuration."""
    scheduler_config: SchedulerConfig = Field(
        default_factory=SchedulerConfig.default_factory,
    )
    """Scheduler configuration."""
    device_config: DeviceConfig = Field(default_factory=DeviceConfig)
    """Device configuration."""
    load_config: LoadConfig = Field(default_factory=LoadConfig)
    """Load configuration."""
    lora_config: LoRAConfig | None = None
    """LoRA configuration."""
    speculative_config: SpeculativeConfig | None = None
    """Speculative decoding configuration."""
    structured_outputs_config: StructuredOutputsConfig = Field(
        default_factory=StructuredOutputsConfig
    )
    """Structured outputs configuration."""
    observability_config: ObservabilityConfig = Field(
        default_factory=ObservabilityConfig
    )
    """Observability configuration."""
    quant_config: QuantizationConfig | None = None
    """Quantization configuration."""
    compilation_config: CompilationConfig = Field(default_factory=CompilationConfig)
    """`torch.compile` and cudagraph capture configuration for the model.

    As a shorthand, one can append compilation arguments via
    -cc.parameter=argument such as `-cc.mode=3` (same as `-cc='{"mode":3}'`).

    You can specify the full compilation config like so:
    `{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
    """
    kv_transfer_config: KVTransferConfig | None = None
    """The configurations for distributed KV cache transfer."""
    kv_events_config: KVEventsConfig | None = None
    """The configurations for event publishing."""
    ec_transfer_config: ECTransferConfig | None = None
    """The configurations for distributed EC cache transfer."""
    # some opaque config, only used to provide additional information
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
    additional_config: dict | SupportsHash = Field(default_factory=dict)
    """Additional config for specified platform. Different platforms may
    support different configs. Make sure the configs are valid for the platform
    you are using. Contents must be hashable."""
    # NOTE: __post_init__ overwrites this with a time.time_ns()-based value.
    instance_id: str = ""
    """The ID of the vLLM instance."""
    optimization_level: OptimizationLevel = OptimizationLevel.O2
    """The optimization level. These levels trade startup time cost for
    performance, with -O0 having the best startup time and -O3 having the best
    performance. -O2 is used by default. See OptimizationLevel for full
    description."""
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255

    def compute_hash(self) -> str:
        """Return a short hash of the configs that shape the computation graph.

        WARNING: Whenever a new field is added to this config, ensure that it
        is included in the factors below if it affects the computation graph.

        The hash is meant to uniquely identify every config that affects the
        structure of the computation graph from input ids/embeddings to the
        final hidden states; anything before the input ids/embeddings or after
        the final hidden states is excluded.
        """
        from vllm import __version__

        def _sub_hash(sub_config: Any) -> str:
            # Missing (falsy) sub-configs contribute the literal string "None".
            return sub_config.compute_hash() if sub_config else "None"

        vllm_factors: list[Any] = [__version__]
        vllm_factors.extend(
            _sub_hash(sub_config)
            for sub_config in (
                self.model_config,
                self.cache_config,
                self.parallel_config,
                self.scheduler_config,
                self.device_config,
                self.load_config,
                self.lora_config,
                self.speculative_config,
                self.structured_outputs_config,
            )
        )
        # Declared with a default_factory; hashed unconditionally.
        vllm_factors.append(self.observability_config.compute_hash())
        # quant_config is deliberately not hashed: it should already be
        # captured by model_config.quantization.
        vllm_factors.extend(
            _sub_hash(sub_config)
            for sub_config in (
                self.compilation_config,
                self.kv_transfer_config,
                self.ec_transfer_config,
            )
        )

        extra = self.additional_config
        if not extra:
            vllm_factors.append("None")
        elif isinstance(extra, dict):
            # Plain dicts are hashed via a key-sorted JSON dump so that key
            # order does not affect the result.
            vllm_factors.append(
                safe_hash(
                    json.dumps(extra, sort_keys=True).encode(),
                    usedforsecurity=False,
                ).hexdigest()
            )
        else:
            vllm_factors.append(extra.compute_hash())

        factors: list[Any] = [vllm_factors]
        digest = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return digest[:10]

    def pad_for_cudagraph(self, batch_size: int) -> int:
328
        # if batch_size > self.compilation_config.max_cudagraph_capture_size,
329
330
        # it should raise an IndexError.
        # the caller should make sure the batch_size is within the range,
331
        # i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
332
333
        return self.compilation_config.bs_to_padded_graph_size[batch_size]

334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
    def enable_trace_function_call_for_thread(self) -> None:
        """Turn on function-call tracing for the calling thread.

        Does nothing unless ``envs.VLLM_TRACE_FUNCTION`` is truthy. When it is,
        the trace log is written under
        ``<tmpdir>/<user>/vllm/vllm-instance-<instance_id>/`` with a file name
        embedding the process id, thread id and current timestamp (spaces
        replaced by underscores).
        """
        if not envs.VLLM_TRACE_FUNCTION:
            return

        # A per-user subdirectory of the system temp dir avoids permission
        # clashes between users.
        user_tmp_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())
        log_name = (
            f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
            f"_thread_{threading.get_ident()}_at_{datetime.now()}.log"
        ).replace(" ", "_")
        log_dir = os.path.join(
            user_tmp_dir, "vllm", f"vllm-instance-{self.instance_id}"
        )
        log_path = os.path.join(log_dir, log_name)
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        enable_trace_function_call(log_path)

356
357
    @staticmethod
    def _get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> QuantizationConfig | None:
        """Build and validate the quantization config for ``model_config``.

        Returns None when ``model_config.quantization`` is None. Otherwise the
        config is loaded with ``get_quant_config`` and checked against the
        current device capability (when the platform reports one) and the
        model dtype. Finally, ``maybe_update_config`` is called with the model
        name before the config is returned.

        Raises:
            ValueError: if the device capability is below the method's minimum,
                or the model dtype is not a supported activation dtype.
        """
        from vllm.platforms import current_platform

        if model_config.quantization is None:
            return None

        from vllm.model_executor.model_loader.weight_utils import get_quant_config

        quant_config = get_quant_config(model_config, load_config)

        # NOTE: model_config.quantization is re-read below (not cached above)
        # in case get_quant_config updated it.
        device_capability = current_platform.get_device_capability()
        if device_capability is not None:
            capability = device_capability.to_int()
            min_capability = quant_config.get_min_capability()
            if capability < min_capability:
                raise ValueError(
                    f"The quantization method {model_config.quantization} "
                    "is not supported for the current GPU. Minimum "
                    f"capability: {min_capability}. "
                    f"Current capability: {capability}."
                )

        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}"
            )

        quant_config.maybe_update_config(model_config.model)
        return quant_config

    @staticmethod
    def get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> QuantizationConfig | None:
        """Return the validated quantization config without touching the input.

        The private ``_get_quantization_config`` is documented (by its callers)
        as modifying the ``model_config`` it receives, so it is run on a deep
        copy. The caller's ``model_config`` is therefore left unchanged.

        Args:
            model_config: Model configuration to derive quantization from.
            load_config: Load configuration forwarded to the loader.

        Returns:
            The quantization config, or None if the model is not quantized.
        """
        # The module-level ``copy`` import is used; no local re-import needed.
        isolated_model_config = copy.deepcopy(model_config)
        return VllmConfig._get_quantization_config(isolated_model_config, load_config)
400
401
402
403

    def with_hf_config(
        self,
        hf_config: PretrainedConfig,
404
        architectures: list[str] | None = None,
405
406
407
408
409
410
411
412
413
414
    ) -> "VllmConfig":
        if architectures is not None:
            hf_config = copy.deepcopy(hf_config)
            hf_config.architectures = architectures

        model_config = copy.deepcopy(self.model_config)
        model_config.hf_config = hf_config

        return replace(self, model_config=model_config)

415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
    def _set_config_default(self, config_obj: Any, key: str, value: Any) -> None:
        """Set config attribute to default if not already set by user.

        Args:
            config_obj: Configuration object to update.
            key: Attribute name.
            value: Default value (static or callable).
        """
        if getattr(config_obj, key) is None:
            # Some config values are known before initialization and are
            # hard coded.
            # Other values depend on the user given configuration, so they are
            # implemented with lambda functions and decided at run time.
            setattr(config_obj, key, value(self) if callable(value) else value)

    def _apply_optimization_level_defaults(self, defaults: dict[str, Any]) -> None:
        """Apply optimization level defaults using self as root.

        Recursively applies values from defaults into nested config objects.
        Only fields present in defaults are overwritten.

        If the user configuration does not specify a value for a default field
        and if the default field is still None after all user selections are
        applied, then default values will be applied to the field. User speciied
        fields will not be overridden by the default.

        Args:
            defaults: Dictionary of default values to apply.
        """

        def apply_recursive(config_obj: Any, config_defaults: dict[str, Any]) -> None:
            """Recursively apply defaults to config_obj, using self as root."""
            for key, value in config_defaults.items():
                if not hasattr(config_obj, key):
                    continue

                current = getattr(config_obj, key)
                if isinstance(value, dict) and is_dataclass(current):
                    apply_recursive(current, value)
                else:
                    self._set_config_default(config_obj, key, value)

        apply_recursive(self, defaults)

459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
    def _post_init_kv_transfer_config(self) -> None:
        """Update KVTransferConfig based on top-level configs in VllmConfig.

        Right now, this function reads the offloading settings from
        CacheConfig and configures the KVTransferConfig accordingly.
        """
        if (kv_offloading_backend := self.cache_config.kv_offloading_backend) is None:
            return

        # If no KVTransferConfig is provided, create a default one.
        if self.kv_transfer_config is None:
            self.kv_transfer_config = KVTransferConfig()

        if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
            raise ValueError(
                "You must set kv_offloading_size when kv_offloading_backend is set."
            )
        num_kv_ranks = (
            self.parallel_config.tensor_parallel_size
            * self.parallel_config.pipeline_parallel_size
        )

        if kv_offloading_backend == "native":
            self.kv_transfer_config.kv_connector = "OffloadingConnector"
            kv_bytes_per_rank = kv_offloading_size * (1 << 30) / num_kv_ranks

            # NOTE(ApostaC): the actual calculation for num_cpu_blocks should be
            # done after the model's KV cache is initialized
            self.kv_transfer_config.kv_connector_extra_config.update(
                {"kv_bytes_per_rank": kv_bytes_per_rank, "num_cpu_blocks": 0}
            )
        elif kv_offloading_backend == "lmcache":
            self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
            kv_gb_per_rank = kv_offloading_size / num_kv_ranks
            self.kv_transfer_config.kv_connector_extra_config = {
                "lmcache.local_cpu": True,
                "lmcache.max_local_cpu_size": kv_gb_per_rank,
            }

        # This is the same for all backends
        self.kv_transfer_config.kv_role = "kv_both"

501
    def __post_init__(self):
502
        """Verify configs are valid & consistent with each other."""
503

504
505
506
        # To give each torch profile run a unique instance name.
        self.instance_id = f"{time.time_ns()}"

507
508
509
510
        self.try_verify_and_update_config()

        if self.model_config is not None:
            self.model_config.verify_with_parallel_config(self.parallel_config)
511
            self.model_config.verify_dual_chunk_attention_config(self.load_config)
512
513
514
515
516
517
518
519

        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config is not None:
            self.lora_config.verify_with_model_config(self.model_config)

        if self.quant_config is None and self.model_config is not None:
            self.quant_config = VllmConfig._get_quantization_config(
520
521
                self.model_config, self.load_config
            )
522

523
524
525
526
527
528
529
530
531
532
533
534
535
536
        executor_backend = self.parallel_config.distributed_executor_backend
        executor_supports_async_sched = executor_backend in (
            "mp",
            "uni",
            "external_launcher",
        )

        if self.scheduler_config.async_scheduling:
            # Async scheduling explicitly enabled, hard fail any incompatibilities.
            if self.parallel_config.pipeline_parallel_size > 1:
                raise ValueError(
                    "Async scheduling is not yet compatible with "
                    "pipeline_parallel_size > 1."
                )
537
538
            # Currently, async scheduling only support eagle speculative
            # decoding.
539
            if self.speculative_config is not None:
540
541
542
543
544
545
546
547
548
549
550
551
552
                if self.speculative_config.method not in get_args(EagleModelTypes):
                    raise ValueError(
                        "Currently, async scheduling is only supported "
                        "with EAGLE/MTP kind of speculative decoding"
                    )
                if self.speculative_config.disable_padded_drafter_batch:
                    raise ValueError(
                        "async scheduling for EAGLE/MTP kind of speculative "
                        "decoding is enabled, but disable_padded_drafter_batch=True "
                        "disable_padded_drafter_batch=True is not supported for "
                        "this situation now. please set "
                        "disable_padded_drafter_batch=Fasle"
                    )
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
            if not executor_supports_async_sched:
                raise ValueError(
                    "Currently, async scheduling only supports `mp`, `uni`, or "
                    "`external_launcher` distributed executor backend, but you chose "
                    f"`{executor_backend}`."
                )
        elif self.scheduler_config.async_scheduling is None:
            # Enable async scheduling unless there is an incompatible option.
            # NOTE: we won't reach here until async scheduling is enabled by default.
            if (
                self.parallel_config.pipeline_parallel_size > 1
                or self.speculative_config is not None
            ):
                logger.warning(
                    "Async scheduling is not yet supported with speculative decoding "
                    " or pipeline_parallel_size > 1 and will be disabled."
                )
                self.scheduler_config.async_scheduling = False
            elif not executor_supports_async_sched:
                logger.warning(
                    "Async scheduling will be disabled because it is not supported "
                    "with the `%s` distributed executor backend (only `mp`, `uni`, and "
                    "`external_launcher` are supported).",
                    executor_backend,
                )
                self.scheduler_config.async_scheduling = False
            else:
                self.scheduler_config.async_scheduling = True

582
        from vllm.platforms import current_platform
583
584
585

        if (
            self.model_config is not None
586
            and self.scheduler_config.enable_chunked_prefill
587
588
589
            and self.model_config.dtype == torch.float32
            and current_platform.get_device_capability() == (7, 5)
        ):
590
591
592
            logger.warning_once(
                "Turing devices tensor cores do not support float32 matmul. "
                "To workaround this limitation, vLLM will set 'ieee' input "
593
594
                "precision for chunked prefill triton kernels."
            )
595

596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
        if (
            self.optimization_level > OptimizationLevel.O0
            and self.model_config is not None
            and self.model_config.enforce_eager
        ):
            logger.warning("Enforce eager set, overriding optimization level to -O0")
            self.optimization_level = OptimizationLevel.O0

        if self.compilation_config.backend == "eager" or (
            self.compilation_config.mode is not None
            and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
        ):
            logger.warning(
                "Inductor compilation was disabled by user settings,"
                "Optimizations settings that are only active during"
                "Inductor compilation will be ignored."
            )

        def has_blocked_weights():
            if self.quant_config is not None:
                if hasattr(self.quant_config, "weight_block_size"):
                    return self.quant_config.weight_block_size is not None
                elif hasattr(self.quant_config, "has_blocked_weights"):
                    return self.quant_config.has_blocked_weights()
            return False

        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
        # On H100 the CUDA kernel is faster than
        # native implementation
        # https://github.com/vllm-project/vllm/issues/25094
        if has_blocked_weights():
            custom_ops = self.compilation_config.custom_ops
            if "-quant_fp8" not in custom_ops:
                custom_ops.append("+quant_fp8")

631
        if self.compilation_config.mode is None:
632
            if self.optimization_level > OptimizationLevel.O0:
633
                self.compilation_config.mode = CompilationMode.VLLM_COMPILE
634
            else:
635
                self.compilation_config.mode = CompilationMode.NONE
636
637
638
639

        if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
            if (
                self.compilation_config.backend == "inductor"
640
                and self.compilation_config.mode != CompilationMode.NONE
641
642
643
644
            ):
                self.compilation_config.custom_ops.append("none")
            else:
                self.compilation_config.custom_ops.append("all")
645

646
647
648
649
650
651
652
653
654
655
656
657
658
659
        default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
        self._apply_optimization_level_defaults(default_config)
        if (
            self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
        ):
            logger.info(
                "Cudagraph mode %s is not compatible with compilation mode %s."
                "Overriding to NONE.",
                self.compilation_config.cudagraph_mode,
                self.compilation_config.mode,
            )
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

660
661
        # async tp is built on top of sequence parallelism
        # and requires it to be enabled.
662
663
664
        if self.compilation_config.pass_config.fuse_gemm_comms:
            self.compilation_config.pass_config.enable_sp = True
        if self.compilation_config.pass_config.enable_sp:
665
666
667
668
669
670
            if "-rms_norm" in self.compilation_config.custom_ops:
                logger.warning(
                    "RMS norm force disabled, sequence parallelism might break"
                )
            else:
                self.compilation_config.custom_ops.append("+rms_norm")
671
672

        if current_platform.support_static_graph_mode():
673
            # if cudagraph_mode has full cudagraphs, we need to check support
674
675
676
677
678
            if (
                self.compilation_config.cudagraph_mode.has_full_cudagraphs()
                and self.model_config is not None
            ):
                if self.model_config.pooler_config is not None:
679
                    logger.warning_once(
680
                        "Pooling models do not support full cudagraphs. "
681
682
683
                        "Overriding cudagraph_mode to PIECEWISE."
                    )
                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
684
                elif self.model_config.is_encoder_decoder:
685
                    logger.warning_once(
686
                        "Encoder-decoder models do not support full cudagraphs. "
687
688
689
                        "Overriding cudagraph_mode to PIECEWISE."
                    )
                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
690
691

            # disable cudagraph when enforce eager execution
692
            if self.model_config is not None and self.model_config.enforce_eager:
693
694
                logger.info("Cudagraph is disabled under eager mode")
                self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
695
696
697
                # override related settings when enforce eager
                self.compilation_config.max_cudagraph_capture_size = 0
                self.compilation_config.cudagraph_capture_sizes = []
698
            else:
699
700
701
702
703
704
705
                self.compilation_config.cudagraph_num_of_warmups = 1

            self._set_cudagraph_sizes()
        else:
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if self.cache_config.kv_sharing_fast_prefill:
706
707
708
709
            if (
                self.speculative_config is not None
                and self.speculative_config.use_eagle()
            ):
710
                raise ValueError(
711
712
713
                    "Fast prefill optimization for KV sharing is not "
                    "compatible with EAGLE as EAGLE requires correct logits "
                    "for all tokens while fast prefill gives incorrect logits "
714
715
                    "for prompt tokens."
                )
716
717
718

            logger.warning_once(
                "--kv-sharing-fast-prefill requires changes on model side for "
719
720
                "correctness and to realize prefill savings. "
            )
721

722
723
        if self.model_config and self.model_config.is_encoder_decoder:
            from vllm.multimodal import MULTIMODAL_REGISTRY
724

725
726
            self.scheduler_config.max_num_encoder_input_tokens = (
                MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
727
            )
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
            logger.debug(
                "Encoder-decoder model detected: setting "
                "`max_num_encoder_input_tokens` to encoder length (%s)",
                self.scheduler_config.max_num_encoder_input_tokens,
            )
            if (
                self.model_config.architecture == "WhisperForConditionalGeneration"
                and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
            ):
                logger.warning(
                    "Whisper is known to have issues with "
                    "forked workers. If startup is hanging, "
                    "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                    "to 'spawn'."
                )
743

744
745
746
747
748
        if (
            self.kv_events_config is not None
            and self.kv_events_config.enable_kv_cache_events
            and not self.cache_config.enable_prefix_caching
        ):
749
750
            logger.warning(
                "KV cache events are on, but prefix caching is not enabled."
751
752
753
754
755
756
757
758
759
760
761
762
763
                "Use --enable-prefix-caching to enable."
            )
        if (
            self.kv_events_config is not None
            and self.kv_events_config.publisher != "null"
            and not self.kv_events_config.enable_kv_cache_events
        ):
            logger.warning(
                "KV cache events are disabled,"
                "but the scheduler is configured to publish them."
                "Modify KVEventsConfig.enable_kv_cache_events"
                "to True to enable."
            )
764
765
        current_platform.check_and_update_config(self)

766
767
        # If DCP, ensure the block size is right.
        if self.parallel_config.decode_context_parallel_size > 1:
768
769
770
771
772
773
774
775
776
777
778
779
            if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
                self.parallel_config.cp_kv_cache_interleave_size
                != self.parallel_config.dcp_kv_cache_interleave_size
            ):
                self.parallel_config.cp_kv_cache_interleave_size = (
                    self.parallel_config.dcp_kv_cache_interleave_size
                )
                logger.warning_once(
                    "cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
                    "_interleave_size. And dcp-kv-cache-interleave-size will be "
                    "deprecated when PCP is fully supported."
                )
780
            assert (
781
                self.parallel_config.cp_kv_cache_interleave_size
782
783
                <= self.cache_config.block_size
                and self.cache_config.block_size
784
                % self.parallel_config.cp_kv_cache_interleave_size
785
786
787
                == 0
            ), (
                f"Block_size({self.cache_config.block_size}) should be greater "
788
789
                "than or equal to and divisible by cp_kv_cache_interleave_size "
                f"({self.parallel_config.cp_kv_cache_interleave_size})."
790
            )
791
792

        assert (
793
            self.parallel_config.cp_kv_cache_interleave_size == 1
794
            or self.speculative_config is None
795
        ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
796

797
        # Do this after all the updates to compilation_config.mode
798
        self.compilation_config.set_splitting_ops_for_v1()
799

800
        if self.compilation_config.pass_config.enable_sp:
801
802
803
804
805
            # With pipeline parallelism or dynamo partitioning,
            # native rms norm tracing errors due to incorrect residual shape.
            # Use custom rms norm to unblock. In the future,
            # the pass will operate on higher-level IR to avoid the issue.
            # TODO: https://github.com/vllm-project/vllm/issues/27894
806
807
808
809
810
811
812
            if self.compilation_config.mode != CompilationMode.VLLM_COMPILE:
                logger.warning(
                    "Sequence parallelism is enabled, but running in wrong "
                    "vllm compile mode: %s.",
                    self.compilation_config.mode,
                )

813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
            is_fullgraph = (
                self.compilation_config.use_inductor_graph_partition
                or len(self.compilation_config.splitting_ops) == 0
            )
            if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
                if "-rms_norm" not in self.compilation_config.custom_ops:
                    self.compilation_config.custom_ops.append("+rms_norm")
                else:
                    regime = (
                        "Dynamo partition"
                        if not is_fullgraph
                        else "pipeline parallelism"
                    )
                    logger.warning_once(
                        "Sequence parallelism not supported with"
                        "native rms_norm when using %s, "
                        "this will likely lead to an error.",
                        regime,
                    )

833
        # final check of cudagraph mode after all possible updates
834
        if current_platform.is_cuda_alike():
835
836
837
838
            if (
                self.compilation_config.cudagraph_mode.has_full_cudagraphs()
                and self.model_config is not None
                and not self.model_config.disable_cascade_attn
839
                and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()  # noqa: E501
840
            ):
841
842
843
                logger.warning_once(
                    "No piecewise cudagraph for executing cascade attention."
                    " Will fall back to eager execution if a batch runs "
844
845
846
847
                    "into cascade attentions"
                )

            if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
848
849
                assert self.compilation_config.mode == CompilationMode.VLLM_COMPILE, (
                    "Compilation mode should be CompilationMode.VLLM_COMPILE "
850
                    "when cudagraph_mode piecewise cudagraphs is used, "
851
                    f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
852
                )
853
854

        if self.parallel_config.enable_dbo:
855
            a2a_backend = self.parallel_config.all2all_backend
856
857
858
            assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
                "Microbatching currently only supports the deepep_low_latency and "
                f"deepep_high_throughput all2all backend. {a2a_backend} is not "
859
860
861
                "supported. To fix use --all2all-backend=deepep_low_latency or "
                "--all2all-backend=deepep_high_throughput and install the DeepEP"
                " kernels."
862
            )
863
864
865

            if not self.model_config.disable_cascade_attn:
                self.model_config.disable_cascade_attn = True
866
                logger.warning_once("Disabling cascade attention when DBO is enabled.")
867
868
869
870

        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

871
        if not self.scheduler_config.disable_hybrid_kv_cache_manager:
872
873
874
875
876
877
            # logger should only print warning message for hybrid models. As we
            # can't know whether the model is hybrid or not now, so we don't log
            # warning message here and will log it later.
            if not current_platform.support_hybrid_kv_cache():
                # Hybrid KV cache manager is not supported on non-GPU platforms.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
878
879
880
881
882
883
884
885
886
887
888
889
890
891
            if self.kv_transfer_config is not None:
                # NOTE(Kuntai): turn HMA off for connector for now.
                # TODO(Kuntai): have a more elegent solution to check and
                # turn off HMA for connector that does not support HMA.
                logger.warning(
                    "Turning off hybrid kv cache manager because "
                    "`--kv-transfer-config` is set. This will reduce the "
                    "performance of vLLM on LLMs with sliding window attention "
                    "or Mamba attention. If you are a developer of kv connector"
                    ", please consider supporting hybrid kv cache manager for "
                    "your connector by making sure your connector is a subclass"
                    " of `SupportsHMA` defined in kv_connector/v1/base.py."
                )
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
892
893
894
            if self.kv_events_config is not None:
                # Hybrid KV cache manager is not compatible with KV events.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
895
896
897
898
899
900
901
902
            if (
                self.model_config is not None
                and self.model_config.attention_chunk_size is not None
            ):
                if (
                    self.speculative_config is not None
                    and self.speculative_config.use_eagle()
                ):
903
904
905
                    # Hybrid KV cache manager is not yet supported with chunked
                    # local attention + eagle.
                    self.scheduler_config.disable_hybrid_kv_cache_manager = True
906
                elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
907
908
909
910
911
912
913
914
915
916
917
                    logger.warning(
                        "There is a latency regression when using chunked local"
                        " attention with the hybrid KV cache manager. Disabling"
                        " it, by default. To enable it, set the environment "
                        "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
                    )
                    # Hybrid KV cache manager is not yet supported with chunked
                    # local attention.
                    self.scheduler_config.disable_hybrid_kv_cache_manager = True

        if self.compilation_config.debug_dump_path:
918
            self.compilation_config.debug_dump_path = (
919
                self.compilation_config.debug_dump_path.absolute().expanduser()
920
            )
921
922
923
924
925
        if envs.VLLM_DEBUG_DUMP_PATH is not None:
            env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser()
            if self.compilation_config.debug_dump_path:
                logger.warning(
                    "Config-specified debug dump path is overridden"
926
927
928
                    " by VLLM_DEBUG_DUMP_PATH to %s",
                    env_path,
                )
929
930
            self.compilation_config.debug_dump_path = env_path

931
932
933
934
935
936
937
938
939
940
941
942
943
944
        def has_blocked_weights():
            if self.quant_config is not None:
                if hasattr(self.quant_config, "weight_block_size"):
                    return self.quant_config.weight_block_size is not None
                elif hasattr(self.quant_config, "has_blocked_weights"):
                    return self.quant_config.has_blocked_weights()
            return False

        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
        # On H100 the CUDA kernel is faster than
        # native implementation
        # https://github.com/vllm-project/vllm/issues/25094
        if has_blocked_weights():
            custom_ops = self.compilation_config.custom_ops
945
            if "-quant_fp8" not in custom_ops:
946
947
                custom_ops.append("+quant_fp8")

948
949
950
        # Handle the KV connector configs
        self._post_init_kv_transfer_config()

951
    def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
        """Drop batch sizes that sequence parallelism cannot use.

        With sequence parallelism enabled, only sizes that are an exact
        multiple of the tensor-parallel size are kept; all other sizes are
        removed and reported together in a single warning.

        Args:
            possible_sizes: Candidate batch sizes (e.g. CUDA graph capture
                sizes). Any iterable is accepted; it is consumed exactly once.

        Returns:
            The sizes from ``possible_sizes`` divisible by
            ``parallel_config.tensor_parallel_size``, in their original order.
        """
        tp_size = self.parallel_config.tensor_parallel_size
        kept_sizes: list = []
        removed_sizes: list = []
        # Partition in a single pass so iterators/generators are not exhausted
        # by a first filtering pass before the result is built.
        for size in possible_sizes:
            if size % tp_size == 0:
                kept_sizes.append(size)
            else:
                removed_sizes.append(size)
        if removed_sizes:
            logger.warning(
                "Batch sizes %s are removed because they are not "
                "multiple of tp_size %d when "
                "sequence parallelism is enabled",
                removed_sizes,
                tp_size,
            )
        return kept_sizes

    def _set_cudagraph_sizes(self):
        """
        vLLM defines the default candidate list of batch sizes for CUDA graph
        capture as:

        ```python
        max_graph_size = min(max_num_seqs * 2, 512)
981
982
983
984
        # 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
        # up to max_graph_size
        cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
            range(256, max_graph_size + 1, 16))
985
986

        In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
987
        will be the final sizes to capture cudagraph (in ascending order).
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013

        These sizes are used to capture and reuse CUDA graphs for
        performance-critical paths (e.g., decoding). Capturing enables
        significantly faster kernel dispatch by avoiding Python overhead. The
        list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
        most GPUs), which controls the total allowed number of tokens in a
        batch. Since each sequence may have a variable number of tokens, the
        maximum usable batch size will depend on actual sequence lengths.

        Example:
            With `max_num_batched_tokens = 8192`, and typical sequences
            averaging ~32 tokens, most practical batch sizes fall below 256.
            However, the system will still allow capture sizes up to 512 if
            shape and memory permit.

        Note:
            If users explicitly specify cudagraph capture sizes in the
            compilation config, those will override this default logic.
            At runtime:

            - If batch size <= one of the `cudagraph_capture_sizes`, the closest
            padded CUDA graph will be used.
            - If batch size > largest `cudagraph_capture_sizes`, cudagraph will
            not be used.
        """

1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
        if (
            self.model_config is not None
            and not self.model_config.enforce_eager
            and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
            # determine the initial max_cudagraph_capture_size
            max_cudagraph_capture_size = (
                self.compilation_config.max_cudagraph_capture_size
            )
            if max_cudagraph_capture_size is None:
                max_cudagraph_capture_size = min(
                    self.scheduler_config.max_num_seqs * 2, 512
1026
                )
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
            max_num_tokens = self.scheduler_config.max_num_batched_tokens
            max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)

            assert max_cudagraph_capture_size >= 1, (
                "Maximum cudagraph size should be greater than or equal to 1 "
                "when using cuda graph."
            )

            # determine the cudagraph_capture_sizes
            if self.compilation_config.cudagraph_capture_sizes is not None:
                assert len(self.compilation_config.cudagraph_capture_sizes) > 0, (
                    "cudagraph_capture_sizes should contain at least one element "
                    "when using cuda graph."
                )
                # de-duplicate the sizes provided by the config
                dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
1043
1044
1045
                cudagraph_capture_sizes = [
                    i for i in dedup_sizes if i <= max_num_tokens
                ]
1046
1047
                # sort to make sure the sizes are in ascending order
                cudagraph_capture_sizes.sort()
1048
            else:
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
                cudagraph_capture_sizes = [
                    i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
                ]
                if max_cudagraph_capture_size >= 8:
                    # Step size 8 for small batch sizes, up to 256(not included)
                    cudagraph_capture_sizes += list(
                        range(8, min(max_cudagraph_capture_size + 1, 256), 8)
                    )
                if max_cudagraph_capture_size >= 256:
                    # Step size 16 for larger batch sizes
                    cudagraph_capture_sizes += list(
                        range(256, max_cudagraph_capture_size + 1, 16)
                    )

1063
1064
            if (
                self.parallel_config.tensor_parallel_size > 1
1065
                and self.compilation_config.pass_config.enable_sp
1066
            ):
1067
1068
                cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
                    cudagraph_capture_sizes
1069
                )
1070

1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
            # user-specific compilation_config.max_cudagraph_capture_size get
            # truncated to valid_max_size when they are inconsistent.
            valid_max_size = (
                cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
            )
            if (
                self.compilation_config.max_cudagraph_capture_size is not None
                and self.compilation_config.max_cudagraph_capture_size != valid_max_size
            ):
                # raise error only when both two flags are user-specified
                # and they are inconsistent with each other
                if self.compilation_config.cudagraph_capture_sizes is not None:
                    raise ValueError(
                        "customized max_cudagraph_capture_size"
                        f"(={self.compilation_config.max_cudagraph_capture_size}) "
                        "should be consistent with the max value of "
                        f"cudagraph_capture_sizes(={valid_max_size})"
                    )

                logger.warning(
                    "Truncating max_cudagraph_capture_size to %d",
                    valid_max_size,
                )
            # always set the final max_cudagraph_capture_size
            self.compilation_config.max_cudagraph_capture_size = valid_max_size

            if self.compilation_config.cudagraph_capture_sizes is not None and len(
                cudagraph_capture_sizes
            ) < len(self.compilation_config.cudagraph_capture_sizes):
                # If users have specified capture sizes, we only need to
                # compare the lens before and after modification since the modified
                # list is only the subset of the original list.
                logger.warning(
                    (
                        "cudagraph_capture_sizes specified in compilation_config"
                        " %s is overridden by config %s"
                    ),
                    self.compilation_config.cudagraph_capture_sizes,
                    cudagraph_capture_sizes,
                )
            # always write back the final sizes
            self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes

        else:
            # no cudagraph in use
            self.compilation_config.max_cudagraph_capture_size = 0
            self.compilation_config.cudagraph_capture_sizes = []

        # complete the remaining process.
        self.compilation_config.post_init_cudagraph_sizes()
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141

    def recalculate_max_model_len(self, max_model_len: int):
        """Re-verify ``max_model_len`` and store the result on the model config.

        The candidate value is passed through
        ``model_config.get_and_verify_max_len`` and the returned value is
        written to ``model_config.max_model_len``.
        """
        # Can only be called in try_verify_and_update_config
        cfg = self.model_config
        cfg.max_model_len = cfg.get_and_verify_max_len(max_model_len)

    def try_verify_and_update_config(self):
        """Apply model-specific config fix-ups, at most once per model config.

        Returns early (doing nothing) when there is no model config, when this
        method already ran for it (tracked via ``model_config.config_updated``),
        or when ``model_config.architecture`` is ``None``. Otherwise it runs,
        in order: the architecture's ``verify_and_update_config`` hook from
        ``MODELS_CONFIG_MAP`` (if registered), the hybrid attention/Mamba hook
        for hybrid models, the sequence-classification hook when
        ``convert_type == "classify"``, and finally switches ``load_format``
        from ``"auto"`` to ``"runai_streamer"`` when ``is_runai_obj_uri``
        accepts ``model_config.model_weights``.

        Raises:
            ValueError: if ``model_weights`` is a Run:ai object URI but
                ``load_format`` is neither ``"auto"`` nor one of the Run:ai
                streamer formats.
        """
        if self.model_config is None:
            return

        # Avoid running try_verify_and_update_config multiple times
        if getattr(self.model_config, "config_updated", False):
            return
        self.model_config.config_updated = True

        architecture = self.model_config.architecture
        if architecture is None:
            return

        from vllm.model_executor.models.config import (
            MODELS_CONFIG_MAP,
            HybridAttentionMambaModelConfig,
        )

        cls = MODELS_CONFIG_MAP.get(architecture, None)
        if cls is not None:
            cls.verify_and_update_config(self)

        if self.model_config.is_hybrid:
            HybridAttentionMambaModelConfig.verify_and_update_config(self)

        if self.model_config.convert_type == "classify":
            # Maybe convert ForCausalLM into ForSequenceClassification model.
            from vllm.model_executor.models.adapters import SequenceClassificationConfig

            SequenceClassificationConfig.verify_and_update_config(self)

        if hasattr(self.model_config, "model_weights") and is_runai_obj_uri(
            self.model_config.model_weights
        ):
            if self.load_config.load_format == "auto":
                logger.info(
                    "Detected Run:ai model config. "
                    "Overriding `load_format` to 'runai_streamer'"
                )
                self.load_config.load_format = "runai_streamer"
            elif self.load_config.load_format not in (
                "runai_streamer",
                "runai_streamer_sharded",
            ):
                raise ValueError(
                    "To load a model from S3, 'load_format' "
                    "must be 'runai_streamer' or 'runai_streamer_sharded', "
                    f"but got '{self.load_config.load_format}'. "
                    f"Model: {self.model_config.model}"
                )

    def compile_debug_dump_path(self) -> Path | None:
1180
        """Returns a rank-aware path for dumping
1181
1182
1183
1184
1185
1186
1187
        torch.compile debug information.
        """
        if self.compilation_config.debug_dump_path is None:
            return None
        tp_rank = self.parallel_config.rank
        dp_rank = self.parallel_config.data_parallel_rank
        data_parallel_size = self.parallel_config.data_parallel_size
1188
1189
1190
        append_path = (
            f"rank_{tp_rank}"
            if data_parallel_size == 1
1191
            else f"rank_{tp_rank}_dp_{dp_rank}"
1192
        )
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
        path = self.compilation_config.debug_dump_path / append_path
        return path

    def __str__(self):
        return (
            f"model={self.model_config.model!r}, "
            f"speculative_config={self.speculative_config!r}, "
            f"tokenizer={self.model_config.tokenizer!r}, "
            f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, "
            f"tokenizer_mode={self.model_config.tokenizer_mode}, "
            f"revision={self.model_config.revision}, "
            f"tokenizer_revision={self.model_config.tokenizer_revision}, "
            f"trust_remote_code={self.model_config.trust_remote_code}, "
            f"dtype={self.model_config.dtype}, "
            f"max_seq_len={self.model_config.max_model_len}, "
            f"download_dir={self.load_config.download_dir!r}, "
            f"load_format={self.load_config.load_format}, "
            f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, "  # noqa
            f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, "  # noqa
            f"data_parallel_size={self.parallel_config.data_parallel_size}, "  # noqa
            f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, "  # noqa
            f"quantization={self.model_config.quantization}, "
            f"enforce_eager={self.model_config.enforce_eager}, "
            f"kv_cache_dtype={self.cache_config.cache_dtype}, "
            f"device_config={self.device_config.device}, "
            f"structured_outputs_config={self.structured_outputs_config!r}, "
            f"observability_config={self.observability_config!r}, "
            f"seed={self.model_config.seed}, "
            f"served_model_name={self.model_config.served_model_name}, "
            f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
1223
            f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, "  # noqa
1224
            f"pooler_config={self.model_config.pooler_config!r}, "
1225
1226
            f"compilation_config={self.compilation_config!r}"
        )
1227

1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
    @model_validator(mode="after")
    def validate_mamba_block_size(self) -> "VllmConfig":
        """Reject an explicit ``mamba_block_size`` when prefix caching is off.

        ``cache_config.mamba_block_size`` counts as explicitly set when it is
        neither ``None`` nor equal to ``model_config.max_model_len``. The check
        is skipped when there is no model config.

        Raises:
            ValueError: if the block size is explicitly set while
                ``cache_config.enable_prefix_caching`` is disabled.
        """
        # Nothing to validate without a model, or when prefix caching is on.
        if self.model_config is None or self.cache_config.enable_prefix_caching:
            return self
        block_size = self.cache_config.mamba_block_size
        if block_size is not None and block_size != self.model_config.max_model_len:
            raise ValueError(
                "--mamba-block-size can only be set with --enable-prefix-caching"
            )
        return self


# Module-level state managed by set_current_vllm_config(): the config and
# optional prefix installed for the active context (None when unset).
_current_vllm_config: VllmConfig | None = None
_current_prefix: str | None = None


@contextmanager
def set_current_vllm_config(
    vllm_config: VllmConfig, check_compile=False, prefix: str | None = None
):
    """
    Temporarily set the current vLLM config.
    Used during model initialization.
    We save the current vLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the vLLM config to determine how to dispatch.

    Args:
        vllm_config: Config installed as the current config for the duration
            of the ``with`` block.
        check_compile: When True and the block finishes without raising, run
            ``compilation_config.custom_op_log_check()`` and warn if the
            compilation mode is ``VLLM_COMPILE`` but
            ``compilation_counter.num_models_seen`` did not increase.
        prefix: Value stored in the module-level ``_current_prefix`` for the
            duration of the block.

    The previous config and prefix are restored, and the cached compilation
    config is cleared, on exit (including when the block raises).
    """
    global _current_vllm_config, _current_prefix
    old_vllm_config = _current_vllm_config
    old_prefix = _current_prefix
    from vllm.compilation.counter import compilation_counter

    num_models_seen = compilation_counter.num_models_seen
    try:
        _current_vllm_config = vllm_config
        _current_prefix = prefix
        # The context changed on entry too: drop any compilation config cached
        # before entering (from an enclosing context or the default config),
        # otherwise get_cached_compilation_config() would return a stale value.
        get_cached_compilation_config.cache_clear()
        yield
        # Reached only when the ``with`` body finished without raising.
        if check_compile:
            vllm_config.compilation_config.custom_op_log_check()

            if (
                vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
                and compilation_counter.num_models_seen == num_models_seen
            ):
                # If the model supports compilation,
                # compilation_counter.num_models_seen should be increased
                # by at least 1.
                # If it is not increased, it means the model does not support
                # compilation (does not have @support_torch_compile decorator).
                logger.warning(
                    "`torch.compile` is turned on, but the model %s"
                    " does not support it. Please open an issue on GitHub"
                    " if you want it to be supported.",
                    vllm_config.model_config.model,
                )
    finally:
        _current_vllm_config = old_vllm_config
        _current_prefix = old_prefix
        # Clear the compilation config cache when context changes
        get_cached_compilation_config.cache_clear()


@lru_cache(maxsize=1)
def get_cached_compilation_config():
    """Return ``get_current_vllm_config().compilation_config``, memoized.

    Caching avoids repeated calls to get_current_vllm_config(); the single
    cached entry is invalidated by ``set_current_vllm_config``.
    """
    current_config = get_current_vllm_config()
    return current_config.compilation_config


def get_current_vllm_config() -> VllmConfig:
    """Return the config installed by ``set_current_vllm_config``.

    Outside of any such context, a warning is logged and a freshly built
    default ``VllmConfig()`` is returned; it is not stored as the current
    config, so each such call builds a new one.
    """
    installed = _current_vllm_config
    if installed is not None:
        return installed
    # in ci, usually when we test custom ops/modules directly,
    # we don't set the vllm config. In that case, we set a default
    # config.
    logger.warning("Current vLLM config is not set.")
    return VllmConfig()


T = TypeVar("T")


def get_layers_from_vllm_config(
1317
1318
    vllm_config: VllmConfig,
    layer_type: type[T],
1319
    layer_names: list[str] | None = None,
1320
) -> dict[str, T]:
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
    """
    Get layers from the vLLM config.

    Args:
        vllm_config: The vLLM config.
        layer_type: The type of the layer to get.
        layer_names: The names of the layers to get. If None, return all layers.
    """

    if layer_names is None:
1331
        layer_names = list(vllm_config.compilation_config.static_forward_context.keys())
1332
1333
1334
1335
1336
1337
1338
1339

    forward_context = vllm_config.compilation_config.static_forward_context

    return {
        layer_name: forward_context[layer_name]
        for layer_name in layer_names
        if isinstance(forward_context[layer_name], layer_type)
    }