vllm.py 57.2 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
5
import getpass
6
7
import json
import os
8
9
import tempfile
import threading
10
import time
11
from contextlib import contextmanager
12
from dataclasses import is_dataclass, replace
13
from datetime import datetime
14
from enum import IntEnum
15
16
from functools import lru_cache
from pathlib import Path
17
from typing import TYPE_CHECKING, Any, TypeVar, get_args
18
19

import torch
20
from pydantic import ConfigDict, Field, model_validator
21
22
23
from pydantic.dataclasses import dataclass

import vllm.envs as envs
24
from vllm.config.speculative import EagleModelTypes
25
from vllm.logger import enable_trace_function_call, init_logger
26
27
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
28
from vllm.utils.hashing import safe_hash
29
30

from .cache import CacheConfig
31
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
32
from .device import DeviceConfig
33
from .ec_transfer import ECTransferConfig
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .parallel import ParallelConfig
from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
from .utils import SupportsHash, config

if TYPE_CHECKING:
    from transformers import PretrainedConfig

49
    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
50
    from vllm.v1.kv_cache_interface import KVCacheConfig
51
52
53
54
55
else:
    PretrainedConfig = Any

    QuantizationConfig = Any

56
57
    KVCacheConfig = Any

58
59
60
logger = init_logger(__name__)


61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class OptimizationLevel(IntEnum):
    """Optimization level enum.

    Higher levels trade longer startup time for better runtime performance.
    """

    O0 = 0
    """O0: No optimization. No compilation, no cudagraphs, no other
    optimizations; starts up immediately."""
    O1 = 1
    """O1: Quick optimizations. Dynamo+Inductor compilation and piecewise
    cudagraphs."""
    O2 = 2
    """O2: Full optimizations. Everything in -O1, plus full and piecewise
    cudagraphs."""
    O3 = 3
    """O3: Currently the same as -O2."""


# Model-property flags consulted by the -O2/-O3 defaults below. Both are
# pinned to False for now; the intended per-model logic, kept for reference
# and applicable only when model_config is not None, is:
#     IS_QUANTIZED = lambda c: c.model_config.is_quantized()
#     IS_DENSE = lambda c: not c.model_config.is_model_moe()
# See https://github.com/vllm-project/vllm/issues/25689.
IS_QUANTIZED: bool = False
IS_DENSE: bool = False


def enable_fusion(cfg: "VllmConfig") -> bool:
    """Report whether either the ``rms_norm`` or ``quant_fp8`` custom op is on.

    The ``quant_fp8`` check is skipped when ``rms_norm`` is already enabled.
    """
    compilation = cfg.compilation_config
    rms_norm_on = compilation.is_custom_op_enabled("rms_norm")
    return rms_norm_on or compilation.is_custom_op_enabled("quant_fp8")


def _level_defaults(
    *,
    noop: Any,
    fusion: Any,
    attn_fusion: Any,
    sequence_parallelism: Any,
    async_tp: Any,
    cudagraph_mode: CUDAGraphMode,
) -> dict[str, Any]:
    """Build the nested defaults mapping for a single optimization level.

    Each value is either a static default or a callable taking the VllmConfig
    (e.g. ``enable_fusion``). The keys mirror the attribute path
    ``compilation_config.pass_config.*`` / ``compilation_config.*``.
    """
    return {
        "compilation_config": {
            "pass_config": {
                "enable_noop": noop,
                "enable_fusion": fusion,
                "enable_fi_allreduce_fusion": False,
                "enable_attn_fusion": attn_fusion,
                "enable_sequence_parallelism": sequence_parallelism,
                "enable_async_tp": async_tp,
            },
            "cudagraph_mode": cudagraph_mode,
            "use_inductor_graph_partition": False,
        },
    }


# -O0: every pass off, no cudagraphs.
OPTIMIZATION_LEVEL_00 = _level_defaults(
    noop=False,
    fusion=False,
    attn_fusion=False,
    sequence_parallelism=False,
    async_tp=False,
    cudagraph_mode=CUDAGraphMode.NONE,
)
# -O1: no-op elimination, conditional fusion, piecewise cudagraphs.
OPTIMIZATION_LEVEL_01 = _level_defaults(
    noop=True,
    fusion=enable_fusion,
    attn_fusion=False,
    sequence_parallelism=False,
    async_tp=False,
    cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
# -O2: additionally gates attention fusion / SP / async TP on model properties
# and uses full + piecewise cudagraphs.
OPTIMIZATION_LEVEL_02 = _level_defaults(
    noop=True,
    fusion=enable_fusion,
    attn_fusion=IS_QUANTIZED,
    sequence_parallelism=IS_DENSE,
    async_tp=IS_DENSE,
    cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
)
# -O3: currently identical to -O2 (a separate dict, not an alias).
OPTIMIZATION_LEVEL_03 = _level_defaults(
    noop=True,
    fusion=enable_fusion,
    attn_fusion=IS_QUANTIZED,
    sequence_parallelism=IS_DENSE,
    async_tp=IS_DENSE,
    cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
)

OPTIMIZATION_LEVEL_TO_CONFIG = {
    OptimizationLevel.O0: OPTIMIZATION_LEVEL_00,
    OptimizationLevel.O1: OPTIMIZATION_LEVEL_01,
    OptimizationLevel.O2: OPTIMIZATION_LEVEL_02,
    OptimizationLevel.O3: OPTIMIZATION_LEVEL_03,
}


158
159
160
161
162
163
164
165
166
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
    """Dataclass which contains all vllm-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    # TODO: use default_factory once default constructing ModelConfig doesn't
    # try to download a model
    model_config: ModelConfig = Field(default=None)
    """Model configuration."""
    cache_config: CacheConfig = Field(default_factory=CacheConfig)
    """Cache configuration."""
    parallel_config: ParallelConfig = Field(default_factory=ParallelConfig)
    """Parallel configuration."""
    scheduler_config: SchedulerConfig = Field(default_factory=SchedulerConfig)
    """Scheduler configuration."""
    device_config: DeviceConfig = Field(default_factory=DeviceConfig)
    """Device configuration."""
    load_config: LoadConfig = Field(default_factory=LoadConfig)
    """Load configuration."""
    lora_config: LoRAConfig | None = None
    """LoRA configuration."""
    speculative_config: SpeculativeConfig | None = None
    """Speculative decoding configuration."""
    structured_outputs_config: StructuredOutputsConfig = Field(
        default_factory=StructuredOutputsConfig
    )
    """Structured outputs configuration."""
    observability_config: ObservabilityConfig = Field(
        default_factory=ObservabilityConfig
    )
    """Observability configuration."""
    quant_config: QuantizationConfig | None = None
    """Quantization configuration."""
    compilation_config: CompilationConfig = Field(default_factory=CompilationConfig)
    """`torch.compile` and cudagraph capture configuration for the model.

    As a shorthand, one can append compilation arguments via
    -cc.parameter=argument such as `-cc.mode=3` (same as `-cc='{"mode":3}'`).

    You can specify the full compilation config like so:
    `{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
    """
    kv_transfer_config: KVTransferConfig | None = None
    """The configurations for distributed KV cache transfer."""
    kv_events_config: KVEventsConfig | None = None
    """The configurations for event publishing."""
    ec_transfer_config: ECTransferConfig | None = None
    """The configurations for distributed EC cache transfer."""
    # some opaque config, only used to provide additional information
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
    additional_config: dict | SupportsHash = Field(default_factory=dict)
    """Additional config for specified platform. Different platforms may
    support different configs. Make sure the configs are valid for the platform
    you are using. Contents must be hashable."""
    instance_id: str = ""
    """The ID of the vLLM instance."""
    optimization_level: OptimizationLevel = OptimizationLevel.O2
    """The optimization level. These levels trade startup time cost for
    performance, with -O0 having the best startup time and -O3 having the best
    performance. -O2 is used by default. See OptimizationLevel for the full
    description."""
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239

    def compute_hash(self) -> str:
        """Return a short digest of every config that shapes the model graph.

        WARNING: Whenever a new field is added to this config, ensure that it
        is folded into ``vllm_factors`` below if it affects the computation
        graph.

        The digest is meant to uniquely identify the configs that influence the
        structure of the computation graph from input ids/embeddings up to the
        final hidden states; anything before the input ids/embeddings or after
        the final hidden states is out of scope.

        Returns:
            The first 10 hex characters of ``safe_hash`` over the collected
            factors.
        """
        from vllm import __version__

        def sub_hash(sub_config: Any) -> str:
            # A falsy (typically None) sub-config contributes the "None" marker.
            return sub_config.compute_hash() if sub_config else "None"

        vllm_factors: list[Any] = [
            __version__,
            sub_hash(self.model_config),
            sub_hash(self.cache_config),
            sub_hash(self.parallel_config),
            sub_hash(self.scheduler_config),
            sub_hash(self.device_config),
            sub_hash(self.load_config),
        ]

        if self.lora_config:
            # LoRA creates static buffers sized by max_num_batched_tokens, and
            # those tensor sizes/strides are captured explicitly in the
            # torch.compile graph, so the token budget is part of the hash.
            vllm_factors.extend(
                [
                    self.lora_config.compute_hash(),
                    str(self.scheduler_config.max_num_batched_tokens),
                ]
            )
        else:
            vllm_factors.append("None")

        vllm_factors.extend(
            [
                sub_hash(self.speculative_config),
                sub_hash(self.structured_outputs_config),
                # Hashed without a None guard, unlike the other sub-configs.
                self.observability_config.compute_hash(),
                # quant_config is deliberately not hashed here: it should be
                # captured by model_config.quantization.
                sub_hash(self.compilation_config),
                sub_hash(self.kv_transfer_config),
                sub_hash(self.ec_transfer_config),
            ]
        )

        extra = self.additional_config
        if not extra:
            extra_hash = "None"
        elif isinstance(extra, dict):
            # Plain dicts are hashed via a key-sorted JSON dump so the digest
            # does not depend on insertion order.
            extra_hash = safe_hash(
                json.dumps(extra, sort_keys=True).encode(),
                usedforsecurity=False,
            ).hexdigest()
        else:
            extra_hash = extra.compute_hash()
        vllm_factors.append(extra_hash)

        # The factors are nested one level deep ([vllm_factors]) before being
        # stringified; changing that nesting would change every hash.
        digest = safe_hash(str([vllm_factors]).encode(), usedforsecurity=False)
        return digest.hexdigest()[:10]

    def pad_for_cudagraph(self, batch_size: int) -> int:
        """Look up the padded cudagraph size for ``batch_size``.

        Callers must keep ``batch_size`` within range, i.e.
        ``batch_size <= compilation_config.max_cudagraph_capture_size``;
        a larger value is expected to raise ``IndexError`` from the lookup.
        """
        padded_sizes = self.compilation_config.bs_to_padded_graph_size
        return padded_sizes[batch_size]

322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
    def enable_trace_function_call_for_thread(self) -> None:
        """Turn on function-call tracing for the calling thread.

        Does nothing unless the ``VLLM_TRACE_FUNCTION`` environment variable is
        enabled. The trace log is written under
        ``<tempdir>/<user>/vllm/vllm-instance-<instance_id>/``, with a file name
        that encodes the process id, thread id and timestamp.
        """
        if not envs.VLLM_TRACE_FUNCTION:
            return

        # A per-user subdirectory of the temp dir avoids permission issues.
        user_tmp_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())
        log_name = (
            f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
            f"_thread_{threading.get_ident()}_at_{datetime.now()}.log"
        ).replace(" ", "_")
        log_dir = os.path.join(
            user_tmp_dir, "vllm", f"vllm-instance-{self.instance_id}"
        )
        os.makedirs(log_dir, exist_ok=True)
        enable_trace_function_call(os.path.join(log_dir, log_name))

344
345
    @staticmethod
    def _get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> QuantizationConfig | None:
        """Build and validate the quantization config for ``model_config``.

        Returns None when ``model_config.quantization`` is None.

        Raises:
            ValueError: If the current device capability is below the
                method's minimum capability, or if ``model_config.dtype`` is
                not one of the method's supported activation dtypes.
        """
        from vllm.platforms import current_platform

        if model_config.quantization is None:
            return None

        from vllm.model_executor.model_loader.weight_utils import get_quant_config

        quant_config = get_quant_config(model_config, load_config)

        # The capability check is skipped when the platform reports none.
        device_capability = current_platform.get_device_capability()
        if device_capability is not None:
            capability = device_capability.to_int()
            min_capability = quant_config.get_min_capability()
            if capability < min_capability:
                raise ValueError(
                    f"The quantization method {model_config.quantization} "
                    "is not supported for the current GPU. Minimum "
                    f"capability: {min_capability}. "
                    f"Current capability: {capability}."
                )

        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}"
            )

        # Let the quant config refine itself from the model identifier.
        quant_config.maybe_update_config(model_config.model)
        return quant_config

    @staticmethod
    def get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> QuantizationConfig | None:
        """Same as ``_get_quantization_config`` but leaves ``model_config`` intact.

        The private helper has been observed to modify the ``model_config``
        it receives, so it is handed a deep copy instead of the caller's object.
        """
        # The module-level ``copy`` import is used; no local re-import needed.
        model_config_copy = copy.deepcopy(model_config)
        return VllmConfig._get_quantization_config(model_config_copy, load_config)
388
389
390
391

    def with_hf_config(
        self,
        hf_config: PretrainedConfig,
        architectures: list[str] | None = None,
    ) -> "VllmConfig":
        """Return a new VllmConfig whose model config uses ``hf_config``.

        If ``architectures`` is given, ``hf_config`` is deep-copied first and
        the copy's ``architectures`` attribute is overwritten, so the caller's
        object is not mutated. ``self.model_config`` is deep-copied as well,
        and the result is produced with ``dataclasses.replace``.
        """
        new_hf_config = hf_config
        if architectures is not None:
            new_hf_config = copy.deepcopy(hf_config)
            new_hf_config.architectures = architectures

        new_model_config = copy.deepcopy(self.model_config)
        new_model_config.hf_config = new_hf_config
        return replace(self, model_config=new_model_config)

403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
    def _set_config_default(self, config_obj: Any, key: str, value: Any) -> None:
        """Fill ``config_obj.<key>`` with a default only if it is still None.

        Args:
            config_obj: Configuration object to update.
            key: Attribute name.
            value: Either a static default, or a callable that is invoked with
                this VllmConfig to compute a default that depends on the
                user-given configuration.
        """
        if getattr(config_obj, key) is not None:
            # Already set (e.g. by the user); keep the existing value.
            return
        resolved = value(self) if callable(value) else value
        setattr(config_obj, key, resolved)

    def _apply_optimization_level_defaults(self, defaults: dict[str, Any]) -> None:
        """Apply optimization-level defaults onto the config tree rooted at self.

        ``defaults`` mirrors the config structure. A dict value whose matching
        attribute is a dataclass is descended into. Any other value is a leaf
        default handed to ``_set_config_default``, which only fills attributes
        that are still None, so user-specified fields are never overridden.
        Keys that do not name an existing attribute are ignored.

        Args:
            defaults: Nested mapping of attribute names to default values
                (static values or callables taking this VllmConfig).
        """

        def visit(target: Any, level_defaults: dict[str, Any]) -> None:
            for name, default in level_defaults.items():
                if not hasattr(target, name):
                    continue
                child = getattr(target, name)
                if isinstance(default, dict) and is_dataclass(child):
                    visit(child, default)
                else:
                    self._set_config_default(target, name, default)

        visit(self, defaults)

447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
    def _post_init_kv_transfer_config(self) -> None:
        """Update KVTransferConfig based on top-level configs in VllmConfig.

        Reads the KV offloading settings from CacheConfig. When an offloading
        backend is selected, configures KVTransferConfig for the matching
        connector, creating a default KVTransferConfig if none was provided.

        Raises:
            ValueError: If ``kv_offloading_backend`` is set but
                ``kv_offloading_size`` is not.
        """
        if (kv_offloading_backend := self.cache_config.kv_offloading_backend) is None:
            return

        # Validate before touching self.kv_transfer_config so an invalid
        # configuration does not leave a half-initialized KVTransferConfig.
        if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
            raise ValueError(
                "You must set kv_offloading_size when kv_offloading_backend is set."
            )

        # If no KVTransferConfig is provided, create a default one.
        if self.kv_transfer_config is None:
            self.kv_transfer_config = KVTransferConfig()

        # The offloading budget is split evenly across all TP x PP ranks.
        num_kv_ranks = (
            self.parallel_config.tensor_parallel_size
            * self.parallel_config.pipeline_parallel_size
        )

        if kv_offloading_backend == "native":
            self.kv_transfer_config.kv_connector = "OffloadingConnector"
            # kv_offloading_size is scaled by 2**30 here, i.e. treated as GiB.
            kv_bytes_per_rank = kv_offloading_size * (1 << 30) / num_kv_ranks

            # NOTE(ApostaC): the actual calculation for num_cpu_blocks should be
            # done after the model's KV cache is initialized
            self.kv_transfer_config.kv_connector_extra_config.update(
                {"kv_bytes_per_rank": kv_bytes_per_rank, "num_cpu_blocks": 0}
            )
        elif kv_offloading_backend == "lmcache":
            self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
            kv_gb_per_rank = kv_offloading_size / num_kv_ranks
            # Unlike the native backend, this replaces any existing extra config.
            self.kv_transfer_config.kv_connector_extra_config = {
                "lmcache.local_cpu": True,
                "lmcache.max_local_cpu_size": kv_gb_per_rank,
            }

        # This is the same for all backends
        self.kv_transfer_config.kv_role = "kv_both"

489
    def __post_init__(self):
490
        """Verify configs are valid & consistent with each other."""
491

492
493
494
        # To give each torch profile run a unique instance name.
        self.instance_id = f"{time.time_ns()}"

495
496
497
498
        self.try_verify_and_update_config()

        if self.model_config is not None:
            self.model_config.verify_with_parallel_config(self.parallel_config)
499
            self.model_config.verify_dual_chunk_attention_config(self.load_config)
500
501
502
503
504
505
506
507

        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config is not None:
            self.lora_config.verify_with_model_config(self.model_config)

        if self.quant_config is None and self.model_config is not None:
            self.quant_config = VllmConfig._get_quantization_config(
508
509
                self.model_config, self.load_config
            )
510

511
512
513
514
515
516
517
518
519
520
521
522
523
524
        executor_backend = self.parallel_config.distributed_executor_backend
        executor_supports_async_sched = executor_backend in (
            "mp",
            "uni",
            "external_launcher",
        )

        if self.scheduler_config.async_scheduling:
            # Async scheduling explicitly enabled, hard fail any incompatibilities.
            if self.parallel_config.pipeline_parallel_size > 1:
                raise ValueError(
                    "Async scheduling is not yet compatible with "
                    "pipeline_parallel_size > 1."
                )
525
526
            # Currently, async scheduling only support eagle speculative
            # decoding.
527
            if self.speculative_config is not None:
528
529
530
531
532
533
534
535
536
537
538
539
540
                if self.speculative_config.method not in get_args(EagleModelTypes):
                    raise ValueError(
                        "Currently, async scheduling is only supported "
                        "with EAGLE/MTP kind of speculative decoding"
                    )
                if self.speculative_config.disable_padded_drafter_batch:
                    raise ValueError(
                        "async scheduling for EAGLE/MTP kind of speculative "
                        "decoding is enabled, but disable_padded_drafter_batch=True "
                        "disable_padded_drafter_batch=True is not supported for "
                        "this situation now. please set "
                        "disable_padded_drafter_batch=Fasle"
                    )
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
            if not executor_supports_async_sched:
                raise ValueError(
                    "Currently, async scheduling only supports `mp`, `uni`, or "
                    "`external_launcher` distributed executor backend, but you chose "
                    f"`{executor_backend}`."
                )
        elif self.scheduler_config.async_scheduling is None:
            # Enable async scheduling unless there is an incompatible option.
            # NOTE: we won't reach here until async scheduling is enabled by default.
            if (
                self.parallel_config.pipeline_parallel_size > 1
                or self.speculative_config is not None
            ):
                logger.warning(
                    "Async scheduling is not yet supported with speculative decoding "
                    " or pipeline_parallel_size > 1 and will be disabled."
                )
                self.scheduler_config.async_scheduling = False
            elif not executor_supports_async_sched:
                logger.warning(
                    "Async scheduling will be disabled because it is not supported "
                    "with the `%s` distributed executor backend (only `mp`, `uni`, and "
                    "`external_launcher` are supported).",
                    executor_backend,
                )
                self.scheduler_config.async_scheduling = False
            else:
                self.scheduler_config.async_scheduling = True

570
        from vllm.platforms import current_platform
571
572
573

        if (
            self.model_config is not None
574
            and self.scheduler_config.enable_chunked_prefill
575
576
577
            and self.model_config.dtype == torch.float32
            and current_platform.get_device_capability() == (7, 5)
        ):
578
579
580
            logger.warning_once(
                "Turing devices tensor cores do not support float32 matmul. "
                "To workaround this limitation, vLLM will set 'ieee' input "
581
582
                "precision for chunked prefill triton kernels."
            )
583

584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
        if (
            self.optimization_level > OptimizationLevel.O0
            and self.model_config is not None
            and self.model_config.enforce_eager
        ):
            logger.warning("Enforce eager set, overriding optimization level to -O0")
            self.optimization_level = OptimizationLevel.O0

        if self.compilation_config.backend == "eager" or (
            self.compilation_config.mode is not None
            and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
        ):
            logger.warning(
                "Inductor compilation was disabled by user settings,"
                "Optimizations settings that are only active during"
                "Inductor compilation will be ignored."
            )

        def has_blocked_weights():
            if self.quant_config is not None:
                if hasattr(self.quant_config, "weight_block_size"):
                    return self.quant_config.weight_block_size is not None
                elif hasattr(self.quant_config, "has_blocked_weights"):
                    return self.quant_config.has_blocked_weights()
            return False

        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
        # On H100 the CUDA kernel is faster than
        # native implementation
        # https://github.com/vllm-project/vllm/issues/25094
        if has_blocked_weights():
            custom_ops = self.compilation_config.custom_ops
            if "-quant_fp8" not in custom_ops:
                custom_ops.append("+quant_fp8")

619
        if self.compilation_config.mode is None:
620
            if self.optimization_level > OptimizationLevel.O0:
621
                self.compilation_config.mode = CompilationMode.VLLM_COMPILE
622
            else:
623
                self.compilation_config.mode = CompilationMode.NONE
624
625
626
627

        if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
            if (
                self.compilation_config.backend == "inductor"
628
                and self.compilation_config.mode != CompilationMode.NONE
629
630
631
632
            ):
                self.compilation_config.custom_ops.append("none")
            else:
                self.compilation_config.custom_ops.append("all")
633

634
635
636
637
638
639
640
641
642
643
644
645
646
647
        default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
        self._apply_optimization_level_defaults(default_config)
        if (
            self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
        ):
            logger.info(
                "Cudagraph mode %s is not compatible with compilation mode %s."
                "Overriding to NONE.",
                self.compilation_config.cudagraph_mode,
                self.compilation_config.mode,
            )
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

648
649
650
        # async tp is built on top of sequence parallelism
        # and requires it to be enabled.
        if self.compilation_config.pass_config.enable_async_tp:
651
            self.compilation_config.pass_config.enable_sequence_parallelism = True
652
653
654
655
656
657
658
        if self.compilation_config.pass_config.enable_sequence_parallelism:
            if "-rms_norm" in self.compilation_config.custom_ops:
                logger.warning(
                    "RMS norm force disabled, sequence parallelism might break"
                )
            else:
                self.compilation_config.custom_ops.append("+rms_norm")
659
660

        if current_platform.support_static_graph_mode():
661
662
663
664
665
666
667
668
669
670
            # if cudagraph_mode has full cudagraphs, we need to check support
            if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
                # decode context parallel does not support full cudagraphs
                if self.parallel_config.decode_context_parallel_size > 1:
                    logger.warning_once(
                        "Decode context parallel (DCP) is enabled, which is "
                        "incompatible with full CUDA graphs. "
                        "Overriding cudagraph_mode to PIECEWISE."
                    )
                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
671
672
673
674
675
676
677
678
                # prefill context parallel do not support full cudagraphs
                elif self.parallel_config.prefill_context_parallel_size > 1:
                    logger.warning_once(
                        "Prefill context parallel (PCP) is enabled, which is "
                        "incompatible with full CUDA graphs. "
                        "Overriding cudagraph_mode to PIECEWISE."
                    )
                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
679
680
681
682
683
684
                elif self.model_config is not None:
                    if self.model_config.pooler_config is not None:
                        logger.warning_once(
                            "Pooling models do not support full cudagraphs. "
                            "Overriding cudagraph_mode to PIECEWISE."
                        )
685
                        self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
686
687
688
689
690
691
                    elif self.model_config.is_encoder_decoder:
                        logger.warning_once(
                            "Encoder-decoder models do not support full cudagraphs. "
                            "Overriding cudagraph_mode to PIECEWISE."
                        )
                        self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
692
693

            # disable cudagraph when enforce eager execution
694
            if self.model_config is not None and self.model_config.enforce_eager:
695
696
                logger.info("Cudagraph is disabled under eager mode")
                self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
697
698
699
                # override related settings when enforce eager
                self.compilation_config.max_cudagraph_capture_size = 0
                self.compilation_config.cudagraph_capture_sizes = []
700
            else:
701
702
703
704
705
706
707
                self.compilation_config.cudagraph_num_of_warmups = 1

            self._set_cudagraph_sizes()
        else:
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if self.cache_config.kv_sharing_fast_prefill:
708
709
710
711
            if (
                self.speculative_config is not None
                and self.speculative_config.use_eagle()
            ):
712
                raise ValueError(
713
714
715
                    "Fast prefill optimization for KV sharing is not "
                    "compatible with EAGLE as EAGLE requires correct logits "
                    "for all tokens while fast prefill gives incorrect logits "
716
717
                    "for prompt tokens."
                )
718
719
720

            logger.warning_once(
                "--kv-sharing-fast-prefill requires changes on model side for "
721
722
                "correctness and to realize prefill savings. "
            )
723

724
725
        if self.model_config and self.model_config.is_encoder_decoder:
            from vllm.multimodal import MULTIMODAL_REGISTRY
726

727
728
            self.scheduler_config.max_num_encoder_input_tokens = (
                MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
729
            )
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
            logger.debug(
                "Encoder-decoder model detected: setting "
                "`max_num_encoder_input_tokens` to encoder length (%s)",
                self.scheduler_config.max_num_encoder_input_tokens,
            )
            if (
                self.model_config.architecture == "WhisperForConditionalGeneration"
                and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
            ):
                logger.warning(
                    "Whisper is known to have issues with "
                    "forked workers. If startup is hanging, "
                    "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                    "to 'spawn'."
                )
745

746
747
748
749
750
        if (
            self.kv_events_config is not None
            and self.kv_events_config.enable_kv_cache_events
            and not self.cache_config.enable_prefix_caching
        ):
751
752
            logger.warning(
                "KV cache events are on, but prefix caching is not enabled."
753
754
755
756
757
758
759
760
761
762
763
764
765
                "Use --enable-prefix-caching to enable."
            )
        if (
            self.kv_events_config is not None
            and self.kv_events_config.publisher != "null"
            and not self.kv_events_config.enable_kv_cache_events
        ):
            logger.warning(
                "KV cache events are disabled,"
                "but the scheduler is configured to publish them."
                "Modify KVEventsConfig.enable_kv_cache_events"
                "to True to enable."
            )
766
767
        current_platform.check_and_update_config(self)

768
769
        # If DCP, ensure the block size is right.
        if self.parallel_config.decode_context_parallel_size > 1:
770
771
772
773
774
775
776
777
778
779
780
781
            if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
                self.parallel_config.cp_kv_cache_interleave_size
                != self.parallel_config.dcp_kv_cache_interleave_size
            ):
                self.parallel_config.cp_kv_cache_interleave_size = (
                    self.parallel_config.dcp_kv_cache_interleave_size
                )
                logger.warning_once(
                    "cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
                    "_interleave_size. And dcp-kv-cache-interleave-size will be "
                    "deprecated when PCP is fully supported."
                )
782
            assert (
783
                self.parallel_config.cp_kv_cache_interleave_size
784
785
                <= self.cache_config.block_size
                and self.cache_config.block_size
786
                % self.parallel_config.cp_kv_cache_interleave_size
787
788
789
                == 0
            ), (
                f"Block_size({self.cache_config.block_size}) should be greater "
790
791
                "than or equal to and divisible by cp_kv_cache_interleave_size "
                f"({self.parallel_config.cp_kv_cache_interleave_size})."
792
            )
793
794

        assert (
795
            self.parallel_config.cp_kv_cache_interleave_size == 1
796
            or self.speculative_config is None
797
        ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
798

799
        # Do this after all the updates to compilation_config.mode
800
        if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
801
802
            self.compilation_config.set_splitting_ops_for_v1()

803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
        if self.compilation_config.pass_config.enable_sequence_parallelism:
            # With pipeline parallelism or dynamo partitioning,
            # native rms norm tracing errors due to incorrect residual shape.
            # Use custom rms norm to unblock. In the future,
            # the pass will operate on higher-level IR to avoid the issue.
            # TODO: https://github.com/vllm-project/vllm/issues/27894
            is_fullgraph = (
                self.compilation_config.use_inductor_graph_partition
                or len(self.compilation_config.splitting_ops) == 0
            )
            if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
                if "-rms_norm" not in self.compilation_config.custom_ops:
                    self.compilation_config.custom_ops.append("+rms_norm")
                else:
                    regime = (
                        "Dynamo partition"
                        if not is_fullgraph
                        else "pipeline parallelism"
                    )
                    logger.warning_once(
                        "Sequence parallelism not supported with"
                        "native rms_norm when using %s, "
                        "this will likely lead to an error.",
                        regime,
                    )

829
        # final check of cudagraph mode after all possible updates
830
        if current_platform.is_cuda_alike():
831
832
833
834
            if (
                self.compilation_config.cudagraph_mode.has_full_cudagraphs()
                and self.model_config is not None
                and not self.model_config.disable_cascade_attn
835
                and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()  # noqa: E501
836
            ):
837
838
839
                logger.warning_once(
                    "No piecewise cudagraph for executing cascade attention."
                    " Will fall back to eager execution if a batch runs "
840
841
842
843
                    "into cascade attentions"
                )

            if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
844
845
                assert self.compilation_config.mode == CompilationMode.VLLM_COMPILE, (
                    "Compilation mode should be CompilationMode.VLLM_COMPILE "
846
                    "when cudagraph_mode piecewise cudagraphs is used, "
847
                    f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
848
                )
849
850

        if self.parallel_config.enable_dbo:
851
            a2a_backend = self.parallel_config.all2all_backend
852
853
854
            assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
                "Microbatching currently only supports the deepep_low_latency and "
                f"deepep_high_throughput all2all backend. {a2a_backend} is not "
855
856
857
                "supported. To fix use --all2all-backend=deepep_low_latency or "
                "--all2all-backend=deepep_high_throughput and install the DeepEP"
                " kernels."
858
            )
859
860
861

            if not self.model_config.disable_cascade_attn:
                self.model_config.disable_cascade_attn = True
862
                logger.warning_once("Disabling cascade attention when DBO is enabled.")
863
864
865
866

        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

867
        if not self.scheduler_config.disable_hybrid_kv_cache_manager:
868
869
870
871
872
873
            # logger should only print warning message for hybrid models. As we
            # can't know whether the model is hybrid or not now, so we don't log
            # warning message here and will log it later.
            if not current_platform.support_hybrid_kv_cache():
                # Hybrid KV cache manager is not supported on non-GPU platforms.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
874
875
876
877
878
879
880
881
882
883
884
885
886
887
            if self.kv_transfer_config is not None:
                # NOTE(Kuntai): turn HMA off for connector for now.
                # TODO(Kuntai): have a more elegent solution to check and
                # turn off HMA for connector that does not support HMA.
                logger.warning(
                    "Turning off hybrid kv cache manager because "
                    "`--kv-transfer-config` is set. This will reduce the "
                    "performance of vLLM on LLMs with sliding window attention "
                    "or Mamba attention. If you are a developer of kv connector"
                    ", please consider supporting hybrid kv cache manager for "
                    "your connector by making sure your connector is a subclass"
                    " of `SupportsHMA` defined in kv_connector/v1/base.py."
                )
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
888
889
890
            if self.kv_events_config is not None:
                # Hybrid KV cache manager is not compatible with KV events.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
891
892
893
894
895
896
897
898
            if (
                self.model_config is not None
                and self.model_config.attention_chunk_size is not None
            ):
                if (
                    self.speculative_config is not None
                    and self.speculative_config.use_eagle()
                ):
899
900
901
                    # Hybrid KV cache manager is not yet supported with chunked
                    # local attention + eagle.
                    self.scheduler_config.disable_hybrid_kv_cache_manager = True
902
                elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
903
904
905
906
907
908
909
910
911
912
913
                    logger.warning(
                        "There is a latency regression when using chunked local"
                        " attention with the hybrid KV cache manager. Disabling"
                        " it, by default. To enable it, set the environment "
                        "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
                    )
                    # Hybrid KV cache manager is not yet supported with chunked
                    # local attention.
                    self.scheduler_config.disable_hybrid_kv_cache_manager = True

        if self.compilation_config.debug_dump_path:
914
            self.compilation_config.debug_dump_path = (
915
                self.compilation_config.debug_dump_path.absolute().expanduser()
916
            )
917
918
919
920
921
        if envs.VLLM_DEBUG_DUMP_PATH is not None:
            env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser()
            if self.compilation_config.debug_dump_path:
                logger.warning(
                    "Config-specified debug dump path is overridden"
922
923
924
                    " by VLLM_DEBUG_DUMP_PATH to %s",
                    env_path,
                )
925
926
            self.compilation_config.debug_dump_path = env_path

927
928
929
930
931
932
933
934
935
936
937
938
939
940
        def has_blocked_weights():
            if self.quant_config is not None:
                if hasattr(self.quant_config, "weight_block_size"):
                    return self.quant_config.weight_block_size is not None
                elif hasattr(self.quant_config, "has_blocked_weights"):
                    return self.quant_config.has_blocked_weights()
            return False

        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
        # On H100 the CUDA kernel is faster than
        # native implementation
        # https://github.com/vllm-project/vllm/issues/25094
        if has_blocked_weights():
            custom_ops = self.compilation_config.custom_ops
941
            if "-quant_fp8" not in custom_ops:
942
943
                custom_ops.append("+quant_fp8")

944
945
946
        # Handle the KV connector configs
        self._post_init_kv_transfer_config()

947
    def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
        """Keep only the batch sizes that are a multiple of the TP size.

        Used when sequence parallelism is enabled (presumably because the
        token dimension must split evenly across tensor-parallel ranks --
        confirm against the SP pass). Sizes that get dropped are reported
        with a single warning.

        Args:
            possible_sizes: Candidate batch sizes, in any order.

        Returns:
            A new list with the surviving sizes, in their original order.
        """
        tp_size = self.parallel_config.tensor_parallel_size
        kept: list = []
        dropped: list = []
        for candidate in possible_sizes:
            bucket = kept if candidate % tp_size == 0 else dropped
            bucket.append(candidate)

        if dropped:
            logger.warning(
                "Batch sizes %s are removed because they are not "
                "multiple of tp_size %d when "
                "sequence parallelism is enabled",
                dropped,
                tp_size,
            )
        return kept

    def _set_cudagraph_sizes(self):
        """
        vLLM defines the default candidate list of batch sizes for CUDA graph
        capture as:

        ```python
        max_graph_size = min(max_num_seqs * 2, 512)
        # 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
        # up to max_graph_size
        cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
            range(256, max_graph_size + 1, 16))
        ```

        In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
        will be the final sizes to capture cudagraph (in ascending order).

        These sizes are used to capture and reuse CUDA graphs for
        performance-critical paths (e.g., decoding). Capturing enables
        significantly faster kernel dispatch by avoiding Python overhead. The
        list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
        most GPUs), which controls the total allowed number of tokens in a
        batch. Since each sequence may have a variable number of tokens, the
        maximum usable batch size will depend on actual sequence lengths.

        How the final values are resolved:

        - A user-provided `max_cudagraph_capture_size` replaces the
          `min(max_num_seqs * 2, 512)` default. In either case it is capped at
          `max_num_batched_tokens`.
        - User-provided `cudagraph_capture_sizes` are de-duplicated, filtered
          to sizes `<= max_num_batched_tokens` and sorted. Without them, the
          default list above is generated up to the (capped) maximum.
        - With `tensor_parallel_size > 1` and sequence parallelism enabled,
          sizes that are not a multiple of the TP size are dropped.
        - `max_cudagraph_capture_size` is then set to the largest remaining
          size (0 if none remain). If the user specified both options and
          they disagree, a `ValueError` is raised.
        - When there is no model config, the model runs eagerly, or
          `cudagraph_mode` is `NONE`, both options are reset to `0` / `[]`.

        In all cases `compilation_config.post_init_cudagraph_sizes()` is
        called at the end.

        Example:
            With `max_num_batched_tokens = 8192`, and typical sequences
            averaging ~32 tokens, most practical batch sizes fall below 256.
            However, the system will still allow capture sizes up to 512 if
            shape and memory permit.

        Note:
            If users explicitly specify cudagraph capture sizes in the
            compilation config, those will override this default logic.
            At runtime:

            - If batch size <= one of the `cudagraph_capture_sizes`, the closest
            padded CUDA graph will be used.
            - If batch size > largest `cudagraph_capture_sizes`, cudagraph will
            not be used.

        Raises:
            ValueError: if both `max_cudagraph_capture_size` and
                `cudagraph_capture_sizes` were user-specified and the former
                does not equal the largest surviving capture size.
        """

        if (
            self.model_config is not None
            and not self.model_config.enforce_eager
            and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
            # determine the initial max_cudagraph_capture_size
            max_cudagraph_capture_size = (
                self.compilation_config.max_cudagraph_capture_size
            )
            if max_cudagraph_capture_size is None:
                max_cudagraph_capture_size = min(
                    self.scheduler_config.max_num_seqs * 2, 512
                )
            # A graph can never be replayed for more tokens than a single
            # scheduler step may contain, so cap by max_num_batched_tokens.
            max_num_tokens = self.scheduler_config.max_num_batched_tokens
            max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)

            assert max_cudagraph_capture_size >= 1, (
                "Maximum cudagraph size should be greater than or equal to 1 "
                "when using cuda graph."
            )

            # determine the cudagraph_capture_sizes
            if self.compilation_config.cudagraph_capture_sizes is not None:
                assert len(self.compilation_config.cudagraph_capture_sizes) > 0, (
                    "cudagraph_capture_sizes should contain at least one element "
                    "when using cuda graph."
                )
                # De-duplicate the user sizes, drop those above
                # max_num_batched_tokens, and return them in ascending order.
                cudagraph_capture_sizes = sorted(
                    size
                    for size in set(self.compilation_config.cudagraph_capture_sizes)
                    if size <= max_num_tokens
                )
            else:
                cudagraph_capture_sizes = [
                    i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
                ]
                if max_cudagraph_capture_size >= 8:
                    # Step size 8 for small batch sizes, up to 256(not included)
                    cudagraph_capture_sizes += list(
                        range(8, min(max_cudagraph_capture_size + 1, 256), 8)
                    )
                if max_cudagraph_capture_size >= 256:
                    # Step size 16 for larger batch sizes
                    cudagraph_capture_sizes += list(
                        range(256, max_cudagraph_capture_size + 1, 16)
                    )

            if (
                self.parallel_config.tensor_parallel_size > 1
                and self.compilation_config.pass_config.enable_sequence_parallelism
            ):
                cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
                    cudagraph_capture_sizes
                )

            # A user-specified compilation_config.max_cudagraph_capture_size is
            # truncated to valid_max_size when the two are inconsistent.
            valid_max_size = (
                cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
            )
            if (
                self.compilation_config.max_cudagraph_capture_size is not None
                and self.compilation_config.max_cudagraph_capture_size != valid_max_size
            ):
                # Raise only when both values were user-specified and are
                # inconsistent with each other.
                if self.compilation_config.cudagraph_capture_sizes is not None:
                    raise ValueError(
                        "customized max_cudagraph_capture_size"
                        f"(={self.compilation_config.max_cudagraph_capture_size}) "
                        "should be consistent with the max value of "
                        f"cudagraph_capture_sizes(={valid_max_size})"
                    )

                logger.warning(
                    "Truncating max_cudagraph_capture_size to %d",
                    valid_max_size,
                )
            # always set the final max_cudagraph_capture_size
            self.compilation_config.max_cudagraph_capture_size = valid_max_size

            if self.compilation_config.cudagraph_capture_sizes is not None and len(
                cudagraph_capture_sizes
            ) < len(self.compilation_config.cudagraph_capture_sizes):
                # If users have specified capture sizes, we only need to
                # compare the lens before and after modification since the modified
                # list is only the subset of the original list.
                logger.warning(
                    (
                        "cudagraph_capture_sizes specified in compilation_config"
                        " %s is overridden by config %s"
                    ),
                    self.compilation_config.cudagraph_capture_sizes,
                    cudagraph_capture_sizes,
                )
            # always write back the final sizes
            self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes

        else:
            # no cudagraph in use
            self.compilation_config.max_cudagraph_capture_size = 0
            self.compilation_config.cudagraph_capture_sizes = []

        # complete the remaining process.
        self.compilation_config.post_init_cudagraph_sizes()

    def recalculate_max_model_len(self, max_model_len: int):
        """Re-derive ``model_config.max_model_len`` from a requested value.

        The request is passed through ``model_config.get_and_verify_max_len``
        and the verified result is stored back on the model config. Only
        meant to be called from ``try_verify_and_update_config``.
        """
        verified_len = self.model_config.get_and_verify_max_len(max_model_len)
        self.model_config.max_model_len = verified_len

    def try_verify_and_update_config(self):
        """Run model-specific config verification/update hooks, at most once.

        The method returns without doing anything when there is no model
        config, or when ``model_config.config_updated`` is already set from an
        earlier call. Otherwise it marks the model config as updated and, if
        the architecture is known, applies in order:

        - the hook registered for the architecture in ``MODELS_CONFIG_MAP``;
        - ``HybridAttentionMambaModelConfig`` when ``model_config.is_hybrid``;
        - ``SequenceClassificationConfig`` when ``convert_type == "classify"``;
        - the ``load_format`` reconciliation for weights that
          ``is_runai_obj_uri`` recognises.

        Raises:
            ValueError: if the weights are a Run:ai object URI but
                ``load_format`` is neither "auto" nor one of the Run:ai
                streamer formats.
        """
        model_config = self.model_config
        if model_config is None or getattr(model_config, "config_updated", False):
            # No model, or this config was already processed.
            return
        model_config.config_updated = True

        architecture = model_config.architecture
        if architecture is None:
            return

        from vllm.model_executor.models.config import (
            MODELS_CONFIG_MAP,
            HybridAttentionMambaModelConfig,
        )

        arch_config_cls = MODELS_CONFIG_MAP.get(architecture)
        if arch_config_cls is not None:
            arch_config_cls.verify_and_update_config(self)

        # Re-read self.model_config after each hook: the hooks receive the
        # whole VllmConfig and are free to modify it.
        if self.model_config.is_hybrid:
            HybridAttentionMambaModelConfig.verify_and_update_config(self)

        if self.model_config.convert_type == "classify":
            # Maybe convert ForCausalLM into ForSequenceClassification model.
            from vllm.model_executor.models.adapters import SequenceClassificationConfig

            SequenceClassificationConfig.verify_and_update_config(self)

        # Reconcile load_format with Run:ai object-storage weights.
        if not hasattr(self.model_config, "model_weights"):
            return
        if not is_runai_obj_uri(self.model_config.model_weights):
            return

        load_format = self.load_config.load_format
        if load_format == "auto":
            logger.info(
                "Detected Run:ai model config. "
                "Overriding `load_format` to 'runai_streamer'"
            )
            self.load_config.load_format = "runai_streamer"
        elif load_format not in ("runai_streamer", "runai_streamer_sharded"):
            raise ValueError(
                "To load a model from S3, 'load_format' "
                "must be 'runai_streamer' or 'runai_streamer_sharded', "
                f"but got '{load_format}'. "
                f"Model: {self.model_config.model}"
            )

1175
    def compile_debug_dump_path(self) -> Path | None:
1176
        """Returns a rank-aware path for dumping
1177
1178
1179
1180
1181
1182
1183
        torch.compile debug information.
        """
        if self.compilation_config.debug_dump_path is None:
            return None
        tp_rank = self.parallel_config.rank
        dp_rank = self.parallel_config.data_parallel_rank
        data_parallel_size = self.parallel_config.data_parallel_size
1184
1185
1186
        append_path = (
            f"rank_{tp_rank}"
            if data_parallel_size == 1
1187
            else f"rank_{tp_rank}_dp_{dp_rank}"
1188
        )
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
        path = self.compilation_config.debug_dump_path / append_path
        return path

    def __str__(self):
        return (
            f"model={self.model_config.model!r}, "
            f"speculative_config={self.speculative_config!r}, "
            f"tokenizer={self.model_config.tokenizer!r}, "
            f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, "
            f"tokenizer_mode={self.model_config.tokenizer_mode}, "
            f"revision={self.model_config.revision}, "
            f"tokenizer_revision={self.model_config.tokenizer_revision}, "
            f"trust_remote_code={self.model_config.trust_remote_code}, "
            f"dtype={self.model_config.dtype}, "
            f"max_seq_len={self.model_config.max_model_len}, "
            f"download_dir={self.load_config.download_dir!r}, "
            f"load_format={self.load_config.load_format}, "
            f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, "  # noqa
            f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, "  # noqa
            f"data_parallel_size={self.parallel_config.data_parallel_size}, "  # noqa
            f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, "  # noqa
            f"quantization={self.model_config.quantization}, "
            f"enforce_eager={self.model_config.enforce_eager}, "
            f"kv_cache_dtype={self.cache_config.cache_dtype}, "
            f"device_config={self.device_config.device}, "
            f"structured_outputs_config={self.structured_outputs_config!r}, "
            f"observability_config={self.observability_config!r}, "
            f"seed={self.model_config.seed}, "
            f"served_model_name={self.model_config.served_model_name}, "
            f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
1219
            f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, "  # noqa
1220
            f"pooler_config={self.model_config.pooler_config!r}, "
1221
1222
            f"compilation_config={self.compilation_config!r}"
        )
1223

1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
    @model_validator(mode="after")
    def validate_mamba_block_size(self) -> "VllmConfig":
        """Reject a custom Mamba block size unless prefix caching is enabled.

        Skipped when there is no model config. A ``mamba_block_size`` of
        ``None``, or one equal to ``model_config.max_model_len``, counts as
        "not set".

        Raises:
            ValueError: if a custom ``mamba_block_size`` is set while
                ``cache_config.enable_prefix_caching`` is false.
        """
        model_config = self.model_config
        if model_config is None:
            return self

        cache_config = self.cache_config
        block_size = cache_config.mamba_block_size
        explicitly_set = (
            block_size is not None and block_size != model_config.max_model_len
        )
        if explicitly_set and not cache_config.enable_prefix_caching:
            raise ValueError(
                "--mamba-block-size can only be set with --enable-prefix-caching"
            )
        return self

1238

1239
1240
# Config installed by the innermost active `set_current_vllm_config` context
# (None outside any such context); read by `get_current_vllm_config`.
_current_vllm_config: VllmConfig | None = None
# Prefix passed to the innermost active `set_current_vllm_config` context;
# restored to the previous value when that context exits.
_current_prefix: str | None = None
1241
1242
1243


@contextmanager
def set_current_vllm_config(
    vllm_config: VllmConfig, check_compile: bool = False, prefix: str | None = None
):
    """
    Temporarily set the current vLLM config.
    Used during model initialization.
    We save the current vLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the vLLM config to determine how to dispatch.

    Args:
        vllm_config: Config to install for the duration of the context.
        check_compile: When True and the body finishes without raising, run
            ``compilation_config.custom_op_log_check()``. Also warn if
            ``VLLM_COMPILE`` mode is on but ``compilation_counter.num_models_seen``
            did not increase while the context was active, which indicates the
            model lacks ``@support_torch_compile``.
        prefix: Stored in the module-level ``_current_prefix`` for the
            duration of the context.

    The previous config and prefix are restored on exit, including when the
    body raises. The ``get_cached_compilation_config`` cache is cleared both on
    entry and on exit, so it always reflects the config that is active.
    """
    global _current_vllm_config, _current_prefix
    old_vllm_config = _current_vllm_config
    old_prefix = _current_prefix
    from vllm.compilation.counter import compilation_counter

    num_models_seen = compilation_counter.num_models_seen
    try:
        _current_vllm_config = vllm_config
        _current_prefix = prefix
        # Invalidate on entry too. Otherwise a value cached before entering
        # (from an outer context, or from the default config built when none
        # was set) would be returned for the whole duration of this context.
        get_cached_compilation_config.cache_clear()
        yield

        # Only reached when the body completed without raising.
        if check_compile:
            vllm_config.compilation_config.custom_op_log_check()

        if (
            check_compile
            and vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
            and compilation_counter.num_models_seen == num_models_seen
        ):
            # If the model supports compilation,
            # compilation_counter.num_models_seen should be increased
            # by at least 1.
            # If it is not increased, it means the model does not support
            # compilation (does not have @support_torch_compile decorator).
            logger.warning(
                "`torch.compile` is turned on, but the model %s"
                " does not support it. Please open an issue on GitHub"
                " if you want it to be supported.",
                vllm_config.model_config.model,
            )
    finally:
        _current_vllm_config = old_vllm_config
        _current_prefix = old_prefix
        # Clear the compilation config cache when context changes
        get_cached_compilation_config.cache_clear()


@lru_cache(maxsize=1)
def get_cached_compilation_config():
    """Return the current config's ``compilation_config``, memoized.

    Memoizing avoids repeated calls to ``get_current_vllm_config()``. The
    single cached entry must be invalidated with ``cache_clear()`` whenever
    the current config changes. ``set_current_vllm_config`` does this when its
    context exits.
    """
    active_config = get_current_vllm_config()
    return active_config.compilation_config


def get_current_vllm_config() -> VllmConfig:
    """Return the config installed by ``set_current_vllm_config``.

    Outside any such context, a warning is logged and a default-constructed
    ``VllmConfig()`` is returned. This is typical in CI when custom
    ops/modules are tested directly. The default is not stored, so every such
    call builds a new instance.
    """
    active = _current_vllm_config
    if active is not None:
        return active
    logger.warning("Current vLLM config is not set.")
    return VllmConfig()


# Type variable tying `get_layers_from_vllm_config`'s `layer_type` argument to
# the value type of the dict it returns.
T = TypeVar("T")


def get_layers_from_vllm_config(
    vllm_config: VllmConfig,
    layer_type: type[T],
    layer_names: list[str] | None = None,
) -> dict[str, T]:
    """
    Get layers from the vLLM config.

    Layers are looked up in
    ``vllm_config.compilation_config.static_forward_context``. Only those that
    are instances of ``layer_type`` are returned.

    Args:
        vllm_config: The vLLM config.
        layer_type: The type of the layer to get.
        layer_names: The names of the layers to get. If None, return all layers.

    Returns:
        Mapping of layer name to layer. It follows the order of
        ``layer_names``, or the order of the static forward context when
        ``layer_names`` is None.

    Raises:
        KeyError: if a requested name is missing from the static forward
            context.
    """
    forward_context = vllm_config.compilation_config.static_forward_context
    names = list(forward_context.keys()) if layer_names is None else layer_names

    selected: dict[str, T] = {}
    for name in names:
        layer = forward_context[name]
        if isinstance(layer, layer_type):
            selected[name] = layer
    return selected