# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import hashlib
import json
import os
import time
from contextlib import contextmanager
from dataclasses import replace
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar

import torch
from pydantic import ConfigDict, Field
from pydantic.dataclasses import dataclass

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid

from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .parallel import ParallelConfig
from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
from .utils import SupportsHash, config

if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
else:
    # At runtime these names are only used in annotations, so alias them to
    # Any to avoid importing transformers / quantization modules eagerly.
    PretrainedConfig = Any

    QuantizationConfig = Any

logger = init_logger(__name__)


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
    """Dataclass which contains all vllm-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    # TODO: use default_factory once default constructing ModelConfig doesn't
    # try to download a model
    model_config: ModelConfig = Field(default=None)
    """Model configuration."""
    cache_config: CacheConfig = Field(default_factory=CacheConfig)
    """Cache configuration."""
    parallel_config: ParallelConfig = Field(default_factory=ParallelConfig)
    """Parallel configuration."""
    scheduler_config: SchedulerConfig = Field(default_factory=SchedulerConfig)
    """Scheduler configuration."""
    device_config: DeviceConfig = Field(default_factory=DeviceConfig)
    """Device configuration."""
    load_config: LoadConfig = Field(default_factory=LoadConfig)
    """Load configuration."""
    lora_config: LoRAConfig | None = None
    """LoRA configuration."""
    speculative_config: SpeculativeConfig | None = None
    """Speculative decoding configuration."""
    structured_outputs_config: StructuredOutputsConfig = Field(
        default_factory=StructuredOutputsConfig
    )
    """Structured outputs configuration."""
    observability_config: ObservabilityConfig | None = None
    """Observability configuration."""
    quant_config: QuantizationConfig | None = None
    """Quantization configuration."""
    compilation_config: CompilationConfig = Field(default_factory=CompilationConfig)
    """`torch.compile` and cudagraph capture configuration for the model.

    As a shorthand, one can append compilation arguments via
    -O.parameter=argument such as `-O.mode=3` (same as `-O='{"mode":3}'`).

    You can specify the full compilation config like so:
    `{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
    """
    kv_transfer_config: KVTransferConfig | None = None
    """The configurations for distributed KV cache transfer."""
    kv_events_config: KVEventsConfig | None = None
    """The configurations for event publishing."""
    # some opaque config, only used to provide additional information
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
    additional_config: dict | SupportsHash = Field(default_factory=dict)
    """Additional config for specified platform. Different platforms may
    support different configs. Make sure the configs are valid for the platform
    you are using. Contents must be hashable."""
    instance_id: str = ""
    """The ID of the vLLM instance."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Returns:
            The first 10 hex characters of an MD5 digest of ``str(factors)``.
        """
        factors: list[Any] = []

        # summarize vllm config
        vllm_factors: list[Any] = []
        from vllm import __version__

        vllm_factors.append(__version__)
        vllm_factors.append(envs.VLLM_USE_V1)

        def sub_config_hash(sub_config: Any) -> str:
            # Absent (falsy) sub-configs contribute the placeholder "None" so
            # every factor keeps a fixed position in the list.
            return sub_config.compute_hash() if sub_config else "None"

        vllm_factors.append(sub_config_hash(self.model_config))
        vllm_factors.append(sub_config_hash(self.cache_config))
        vllm_factors.append(sub_config_hash(self.parallel_config))
        vllm_factors.append(sub_config_hash(self.scheduler_config))
        vllm_factors.append(sub_config_hash(self.device_config))
        vllm_factors.append(sub_config_hash(self.load_config))
        if self.lora_config:
            vllm_factors.append(self.lora_config.compute_hash())
            # LoRA creates static buffers based on max_num_batched_tokens.
            # The tensor sizes and strides get captured in the torch.compile
            # graph explicitly.
            vllm_factors.append(str(self.scheduler_config.max_num_batched_tokens))
        else:
            vllm_factors.append("None")
        vllm_factors.append(sub_config_hash(self.speculative_config))
        vllm_factors.append(sub_config_hash(self.structured_outputs_config))
        vllm_factors.append(sub_config_hash(self.observability_config))
        # quant_config is intentionally not hashed here: it should be captured
        # by model_config.quantization.
        vllm_factors.append(sub_config_hash(self.compilation_config))
        vllm_factors.append(sub_config_hash(self.kv_transfer_config))
        if self.additional_config:
            if isinstance(additional_config := self.additional_config, dict):
                # sort_keys makes the digest independent of dict insertion order.
                additional_config_hash = hashlib.md5(
                    json.dumps(additional_config, sort_keys=True).encode(),
                    usedforsecurity=False,
                ).hexdigest()
            else:
                additional_config_hash = additional_config.compute_hash()
            vllm_factors.append(additional_config_hash)
        else:
            vllm_factors.append("None")
        factors.append(vllm_factors)

        hash_str = hashlib.md5(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()[:10]
        return hash_str

    def pad_for_cudagraph(self, batch_size: int) -> int:
        """Return the cudagraph capture size that ``batch_size`` is padded to.

        The value is looked up in
        ``self.compilation_config.bs_to_padded_graph_size``.
        """
        # if batch_size > self.compilation_config.max_cudagraph_capture_size,
        # it should raise an IndexError.
        # the caller should make sure the batch_size is within the range,
        # i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
        return self.compilation_config.bs_to_padded_graph_size[batch_size]

    @staticmethod
    def _get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> QuantizationConfig | None:
        """Get the quantization config.

        Returns ``None`` when ``model_config.quantization`` is ``None``.
        Otherwise the config is built with ``get_quant_config`` and validated
        against the current platform.

        Raises:
            ValueError: if the device capability is below the method's
                minimum, or ``model_config.dtype`` is not a supported
                activation dtype for the method.
        """
        from vllm.platforms import current_platform

        if model_config.quantization is not None:
            from vllm.model_executor.model_loader.weight_utils import get_quant_config

            quant_config = get_quant_config(model_config, load_config)
            capability_tuple = current_platform.get_device_capability()

            # Platforms that report no capability skip the minimum-capability check.
            if capability_tuple is not None:
                capability = capability_tuple.to_int()
                if capability < quant_config.get_min_capability():
                    raise ValueError(
                        f"The quantization method {model_config.quantization} "
                        "is not supported for the current GPU. Minimum "
                        f"capability: {quant_config.get_min_capability()}. "
                        f"Current capability: {capability}."
                    )
            supported_dtypes = quant_config.get_supported_act_dtypes()
            if model_config.dtype not in supported_dtypes:
                raise ValueError(
                    f"{model_config.dtype} is not supported for quantization "
                    f"method {model_config.quantization}. Supported dtypes: "
                    f"{supported_dtypes}"
                )
            quant_config.maybe_update_config(model_config.model)
            return quant_config
        return None

    @staticmethod
    def get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> QuantizationConfig | None:
        """Public wrapper around ``_get_quantization_config``.

        The caller's ``model_config`` is deep-copied first so it is never
        mutated by the lookup.
        """
        # For some reason, the _ version of this modifies the model_config
        # object, so using deepcopy to avoid this problem.
        # (`copy` is the module-level import; no local import is needed.)
        return VllmConfig._get_quantization_config(
            copy.deepcopy(model_config), load_config
        )

    def with_hf_config(
        self,
        hf_config: PretrainedConfig,
        architectures: list[str] | None = None,
    ) -> "VllmConfig":
        """Return a copy of this config whose model_config uses ``hf_config``.

        Args:
            hf_config: The HuggingFace config to install on a deep copy of
                ``self.model_config``.
            architectures: If given, ``hf_config`` is deep-copied first and its
                ``architectures`` attribute is overridden, so the caller's
                object is not mutated.

        Returns:
            A new ``VllmConfig`` built with ``dataclasses.replace``. That
            constructs a new instance, so ``__post_init__`` runs again.
        """
        # NOTE(review): assumes self.model_config is not None; with None the
        # attribute assignment below raises AttributeError.
        if architectures is not None:
            hf_config = copy.deepcopy(hf_config)
            hf_config.architectures = architectures

        model_config = copy.deepcopy(self.model_config)
        model_config.hf_config = hf_config

        return replace(self, model_config=model_config)

    def __post_init__(self):
        """Verify configs are valid & consistent with each other.

        Besides validation, this resolves defaults that depend on several
        sub-configs. That covers compilation mode, custom ops, cudagraph mode
        and sizes, chunked prefill / prefix caching, and the hybrid KV cache
        manager. It also resolves debug dump paths. Several sub-config
        attributes are updated in place.
        """

        # To give each torch profile run a unique instance name.
        # NOTE(review): this unconditionally overwrites any caller-supplied
        # instance_id, so the `random_uuid()` fallback further below is
        # effectively unreachable. Confirm whether that is intended.
        self.instance_id = f"{time.time_ns()}"

        self.try_verify_and_update_config()

        if self.model_config is not None:
            self.model_config.verify_with_parallel_config(self.parallel_config)
            self.model_config.verify_dual_chunk_attention_config(self.load_config)

        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config is not None:
            self.lora_config.verify_with_cache_config(self.cache_config)
            self.lora_config.verify_with_model_config(self.model_config)

        if self.quant_config is None and self.model_config is not None:
            self.quant_config = VllmConfig._get_quantization_config(
                self.model_config, self.load_config
            )

        from vllm.platforms import current_platform

        if (
            self.model_config is not None
            and self.scheduler_config.chunked_prefill_enabled
            and self.model_config.dtype == torch.float32
            and current_platform.get_device_capability() == (7, 5)
        ):
            logger.warning_once(
                "Turing devices tensor cores do not support float32 matmul. "
                "To workaround this limitation, vLLM will set 'ieee' input "
                "precision for chunked prefill triton kernels."
            )

        # If the user does not explicitly set a compilation mode, then
        # we use the default mode. The default mode depends on other
        # settings (see the below code).
        if self.compilation_config.mode is None:
            if envs.VLLM_USE_V1:
                if (
                    self.model_config is not None
                    and not self.model_config.enforce_eager
                ):
                    self.compilation_config.mode = CompilationMode.VLLM_COMPILE
                else:
                    self.compilation_config.mode = CompilationMode.NONE

            else:
                # NB: Passing both --enforce-eager and a compilation mode
                # in V0 means the compilation mode wins out.
                self.compilation_config.mode = CompilationMode.NONE
        else:
            assert self.compilation_config.mode >= CompilationMode.NONE
            assert self.compilation_config.mode <= CompilationMode.VLLM_COMPILE

        # If user does not set custom ops via none or all set it here based on
        # compilation mode and backend.
        if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
            if (
                self.compilation_config.backend == "inductor"
                and self.compilation_config.mode > CompilationMode.NONE
            ):
                self.compilation_config.custom_ops.append("none")
            else:
                self.compilation_config.custom_ops.append("all")

        # async tp is built on top of sequence parallelism
        # and requires it to be enabled.
        if self.compilation_config.pass_config.enable_async_tp:
            self.compilation_config.pass_config.enable_sequence_parallelism = True
        if self.compilation_config.pass_config.enable_sequence_parallelism:
            self.compilation_config.custom_ops.append("+rms_norm")

        if current_platform.support_static_graph_mode():
            # if cudagraph_mode is not explicitly set by users, set default
            # value
            if self.compilation_config.cudagraph_mode is None:
                if (
                    envs.VLLM_USE_V1
                    and self.compilation_config.mode == CompilationMode.VLLM_COMPILE
                ):
                    # default to full and piecewise for most models
                    self.compilation_config.cudagraph_mode = (
                        CUDAGraphMode.FULL_AND_PIECEWISE
                    )
                else:
                    self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

            # if cudagraph_mode has full cudagraphs, we need to check support
            if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
                # decode context parallel does not support full cudagraphs
                if self.parallel_config.decode_context_parallel_size > 1:
                    logger.warning_once(
                        "Decode context parallel (DCP) is enabled, which is "
                        "incompatible with full CUDA graphs. "
                        "Overriding cudagraph_mode to PIECEWISE."
                    )
                    self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
                elif self.model_config is not None:
                    if self.model_config.pooler_config is not None:
                        logger.warning_once(
                            "Pooling models do not support full cudagraphs. "
                            "Overriding cudagraph_mode to PIECEWISE."
                        )
                        self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
                    elif self.model_config.is_encoder_decoder:
                        logger.warning_once(
                            "Encoder-decoder models do not support full cudagraphs. "
                            "Overriding cudagraph_mode to PIECEWISE."
                        )
                        self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
                    elif (
                        current_platform.is_cuda()
                        and current_platform.is_device_capability(100)
                        and self.model_config.max_model_len > 131072
                        and not self.model_config.use_mla
                    ):
                        # Refer to vllm/utils/flashinfer.py::use_trtllm_attention()
                        # The message threshold matches the strict `>` check above.
                        logger.warning_once(
                            "NVIDIA Blackwell TRTLLM attention cannot support "
                            "max_model_len > 131072 (found "
                            f"{self.model_config.max_model_len}), causing dynamic "
                            "dispatching that breaks full cudagraphs. "
                            "Overriding cudagraph_mode to PIECEWISE."
                        )
                        self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

            # disable cudagraph when enforce eager execution
            if self.model_config is not None and self.model_config.enforce_eager:
                logger.info("Cudagraph is disabled under eager mode")
                self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
                # override related settings when enforce eager
                self.compilation_config.max_cudagraph_capture_size = 0
                self.compilation_config.cudagraph_capture_sizes = []
            elif envs.VLLM_USE_V1:
                self.compilation_config.cudagraph_num_of_warmups = 1

            self._set_cudagraph_sizes()
        else:
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if self.cache_config.kv_sharing_fast_prefill:
            if (
                self.speculative_config is not None
                and self.speculative_config.use_eagle()
            ):
                raise NotImplementedError(
                    "Fast prefill optimization for KV sharing is not "
                    "compatible with EAGLE as EAGLE requires correct logits "
                    "for all tokens while fast prefill gives incorrect logits "
                    "for prompt tokens."
                )

            logger.warning_once(
                "--kv-sharing-fast-prefill requires changes on model side for "
                "correctness and to realize prefill savings. "
            )

        # Reasons collected here are logged and force chunked prefill and
        # prefix caching off in the final off-switch below.
        disable_chunked_prefill_reasons: list[str] = []

        if self.model_config:
            if self.model_config.pooler_config:
                pooling_type = self.model_config.pooler_config.pooling_type
                if pooling_type is None or pooling_type.lower() != "last":
                    disable_chunked_prefill_reasons.append(
                        'Only "last" pooling supports chunked '
                        "prefill and prefix caching; disabling both."
                    )
                if not getattr(self.model_config.hf_config, "is_causal", True):
                    disable_chunked_prefill_reasons.append(
                        "Only models using causal attention supports chunked "
                        "prefill and prefix caching; disabling both."
                    )
            elif self.model_config.is_encoder_decoder:
                from vllm.multimodal import MULTIMODAL_REGISTRY

                self.scheduler_config.max_num_encoder_input_tokens = (
                    MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
                )
                logger.debug(
                    "Encoder-decoder model detected: setting "
                    "`max_num_encoder_input_tokens` to encoder length (%s)",
                    self.scheduler_config.max_num_encoder_input_tokens,
                )
                if (
                    self.model_config.architecture == "WhisperForConditionalGeneration"
                    and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
                ):
                    logger.warning(
                        "Whisper is known to have issues with "
                        "forked workers. If startup is hanging, "
                        "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                        "to 'spawn'."
                    )

        # Final off-switch for CP/APC:
        # Disable for (a) collected blockers, (b) encoder–decoder, or
        # (c) explicit CP=False when APC wasn't requested.
        # Do NOT disable merely because the resolved CP flag is False.
        apc_requested = (
            self.cache_config is not None and self.cache_config.enable_prefix_caching
        )
        if (
            disable_chunked_prefill_reasons
            or (self.model_config is not None and self.model_config.is_encoder_decoder)
            or (
                self.scheduler_config.enable_chunked_prefill is False
                and not apc_requested
            )
        ):
            for reason in disable_chunked_prefill_reasons:
                logger.info(reason)
            self.scheduler_config.chunked_prefill_enabled = False
            self.scheduler_config.long_prefill_token_threshold = 0

            if self.cache_config is not None:
                self.cache_config.enable_prefix_caching = False

        if (
            self.kv_events_config is not None
            and self.kv_events_config.enable_kv_cache_events
            and not self.cache_config.enable_prefix_caching
        ):
            logger.warning(
                "KV cache events are on, but prefix caching is not enabled. "
                "Use --enable-prefix-caching to enable."
            )
        if (
            self.kv_events_config is not None
            and self.kv_events_config.publisher != "null"
            and not self.kv_events_config.enable_kv_cache_events
        ):
            # Adjacent literals are concatenated, so each piece carries its
            # own separating space.
            logger.warning(
                "KV cache events are disabled, "
                "but the scheduler is configured to publish them. "
                "Modify KVEventsConfig.enable_kv_cache_events "
                "to True to enable."
            )

        current_platform.check_and_update_config(self)

        # Do this after all the updates to compilation_config.mode
        if (
            envs.VLLM_USE_V1
            and self.compilation_config.mode == CompilationMode.VLLM_COMPILE
        ):
            self.compilation_config.set_splitting_ops_for_v1()

        # final check of cudagraph mode after all possible updates
        if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
            if (
                self.compilation_config.cudagraph_mode.has_full_cudagraphs()
                and self.model_config is not None
                and not self.model_config.disable_cascade_attn
                and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()  # noqa: E501
            ):
                logger.warning_once(
                    "No piecewise cudagraph for executing cascade attention."
                    " Will fall back to eager execution if a batch runs "
                    "into cascade attentions"
                )

            if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
                assert self.compilation_config.mode == CompilationMode.VLLM_COMPILE, (
                    "Compilation mode should be CompilationMode.VLLM_COMPILE "
                    "when cudagraph_mode piecewise cudagraphs is used, "
                    f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
                )

            # final migrate the deprecated flags
            self.compilation_config.use_cudagraph = (
                self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            )
            self.compilation_config.full_cuda_graph = (
                self.compilation_config.cudagraph_mode.has_full_cudagraphs()
            )

        if self.parallel_config.enable_dbo:
            a2a_backend = self.parallel_config.all2all_backend
            assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
                "Microbatching currently only supports the deepep_low_latency and "
                f"deepep_high_throughput all2all backend. {a2a_backend} is not "
                "supported. To fix use --all2all-backend=deepep_low_latency or "
                "--all2all-backend=deepep_high_throughput and install the DeepEP"
                " kernels."
            )

            # model_config defaults to None (see the field declaration), so
            # guard before touching it.
            if (
                self.model_config is not None
                and not self.model_config.disable_cascade_attn
            ):
                self.model_config.disable_cascade_attn = True
                logger.warning_once("Disabling cascade attention when DBO is enabled.")

        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

        if (
            envs.VLLM_USE_V1
            and not self.scheduler_config.disable_hybrid_kv_cache_manager
        ):
            # logger should only print warning message for hybrid models. As we
            # can't know whether the model is hybrid or not now, so we don't log
            # warning message here and will log it later.
            if not current_platform.support_hybrid_kv_cache():
                # Hybrid KV cache manager is not supported on non-GPU platforms.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
            if self.kv_transfer_config is not None:
                # Hybrid KV cache manager is not compatible with KV transfer.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
            if self.kv_events_config is not None:
                # Hybrid KV cache manager is not compatible with KV events.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
            if (
                self.model_config is not None
                and self.model_config.attention_chunk_size is not None
            ):
                if (
                    self.speculative_config is not None
                    and self.speculative_config.use_eagle()
                ):
                    # Hybrid KV cache manager is not yet supported with chunked
                    # local attention + eagle.
                    self.scheduler_config.disable_hybrid_kv_cache_manager = True
                elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
                    logger.warning(
                        "There is a latency regression when using chunked local"
                        " attention with the hybrid KV cache manager. Disabling"
                        " it, by default. To enable it, set the environment "
                        "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
                    )
                    # Hybrid KV cache manager is not yet supported with chunked
                    # local attention.
                    self.scheduler_config.disable_hybrid_kv_cache_manager = True

        # Expand "~" BEFORE making the path absolute. `Path.absolute()` on
        # "~/x" yields "<cwd>/~/x", after which `expanduser()` is a no-op.
        if self.compilation_config.debug_dump_path:
            self.compilation_config.debug_dump_path = (
                self.compilation_config.debug_dump_path.expanduser().absolute()
            )
        if envs.VLLM_DEBUG_DUMP_PATH is not None:
            env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).expanduser().absolute()
            if self.compilation_config.debug_dump_path:
                logger.warning(
                    "Config-specified debug dump path is overridden"
                    " by VLLM_DEBUG_DUMP_PATH to %s",
                    env_path,
                )
            self.compilation_config.debug_dump_path = env_path

        def has_blocked_weights():
            # Duck-typed probe: quantization configs expose either a
            # `weight_block_size` attribute or a `has_blocked_weights()` method.
            if self.quant_config is not None:
                if hasattr(self.quant_config, "weight_block_size"):
                    return self.quant_config.weight_block_size is not None
                elif hasattr(self.quant_config, "has_blocked_weights"):
                    return self.quant_config.has_blocked_weights()
            return False

        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
        # On H100 the CUDA kernel is faster than
        # native implementation
        # https://github.com/vllm-project/vllm/issues/25094
        if has_blocked_weights():
            custom_ops = self.compilation_config.custom_ops
            if "-quant_fp8" not in custom_ops:
                custom_ops.append("+quant_fp8")

    def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
        """Keep only the batch sizes that are multiples of the TP size.

        When sequence parallelism is enabled, sizes that are not a multiple of
        ``parallel_config.tensor_parallel_size`` are removed, and a warning
        lists them.

        Args:
            possible_sizes: Candidate batch sizes. Any iterable is accepted and
                consumed exactly once.

        Returns:
            The retained sizes as a new list, in their original order.
        """
        tp_size = self.parallel_config.tensor_parallel_size
        kept_sizes: list = []
        removed_sizes: list = []
        # A single pass partitions the input, so one-shot iterators behave
        # the same as lists.
        for size in possible_sizes:
            if size % tp_size == 0:
                kept_sizes.append(size)
            else:
                removed_sizes.append(size)

        if removed_sizes:
            logger.warning(
                "Batch sizes %s are removed because they are not "
                "multiple of tp_size %d when "
                "sequence parallelism is enabled",
                removed_sizes,
                tp_size,
            )

        return kept_sizes

    def _set_cudagraph_sizes(self):
        """
        vLLM defines the default candidate list of batch sizes for CUDA graph
        capture as:

        ```python
        max_graph_size = min(max_num_seqs * 2, 512)
        # 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
        # up to max_graph_size
        cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
            range(256, max_graph_size + 1, 16))
        ```

        (`max_graph_size` is additionally capped by `max_num_batched_tokens`,
        and a user-provided `compilation_config.max_cudagraph_capture_size`
        replaces the `min(max_num_seqs * 2, 512)` default.)

        In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
        will be the final sizes to capture cudagraph (in ascending order).

        These sizes are used to capture and reuse CUDA graphs for
        performance-critical paths (e.g., decoding). Capturing enables
        significantly faster kernel dispatch by avoiding Python overhead. The
        list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on
        most GPUs), which controls the total allowed number of tokens in a
        batch. Since each sequence may have a variable number of tokens, the
        maximum usable batch size will depend on actual sequence lengths.

        Example:
            With `max_num_batched_tokens = 8192`, and typical sequences
            averaging ~32 tokens, most practical batch sizes fall below 256.
            However, the system will still allow capture sizes up to 512 if
            shape and memory permit.

        Note:
            If users explicitly specify cudagraph capture sizes in the
            compilation config, those will override this default logic
            (after de-duplication, sorting, and dropping any size larger than
            `max_num_batched_tokens`).
            At runtime:

            - If batch size <= one of the `cudagraph_capture_sizes`, the closest
            padded CUDA graph will be used.
            - If batch size > largest `cudagraph_capture_sizes`, cudagraph will
            not be used.

        When CUDA graphs are not in use (no model config, `enforce_eager`, or
        `cudagraph_mode == NONE`), `max_cudagraph_capture_size` is set to 0 and
        `cudagraph_capture_sizes` to `[]`. In every case the method finishes by
        calling `compilation_config.post_init_cudagraph_sizes()`.

        Raises:
            ValueError: if both `max_cudagraph_capture_size` and
                `cudagraph_capture_sizes` are user-specified and the former
                does not equal the largest surviving capture size.
        """
        if (
            self.model_config is not None
            and not self.model_config.enforce_eager
            and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
            # determine the initial max_cudagraph_capture_size
            max_cudagraph_capture_size = (
                self.compilation_config.max_cudagraph_capture_size
            )
            if max_cudagraph_capture_size is None:
                max_cudagraph_capture_size = min(
                    self.scheduler_config.max_num_seqs * 2, 512
                )
            # A batch never holds more than max_num_batched_tokens tokens, so
            # that is also the upper bound for any capture size.
            max_num_tokens = self.scheduler_config.max_num_batched_tokens
            max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)

            assert max_cudagraph_capture_size >= 1, (
                "Maximum cudagraph size should be greater than or equal to 1 "
                "when using cuda graph."
            )

            # determine the cudagraph_capture_sizes
            if self.compilation_config.cudagraph_capture_sizes is not None:
                assert len(self.compilation_config.cudagraph_capture_sizes) > 0, (
                    "cudagraph_capture_sizes should contain at least one element "
                    "when using cuda graph."
                )
                # De-duplicate the user-provided sizes, drop every size larger
                # than max_num_batched_tokens (as documented above), and sort
                # the result in ascending order.
                cudagraph_capture_sizes = sorted(
                    {
                        size
                        for size in self.compilation_config.cudagraph_capture_sizes
                        if size <= max_num_tokens
                    }
                )
            else:
                cudagraph_capture_sizes = [
                    i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
                ]
                if max_cudagraph_capture_size >= 8:
                    # Step size 8 for small batch sizes, up to 256(not included)
                    cudagraph_capture_sizes += list(
                        range(8, min(max_cudagraph_capture_size + 1, 256), 8)
                    )
                if max_cudagraph_capture_size >= 256:
                    # Step size 16 for larger batch sizes
                    cudagraph_capture_sizes += list(
                        range(256, max_cudagraph_capture_size + 1, 16)
                    )

            if (
                self.parallel_config.tensor_parallel_size > 1
                and self.compilation_config.pass_config.enable_sequence_parallelism
            ):
                cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
                    cudagraph_capture_sizes
                )

            # A user-specified compilation_config.max_cudagraph_capture_size is
            # truncated to valid_max_size when the two are inconsistent.
            valid_max_size = (
                cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
            )
            if (
                self.compilation_config.max_cudagraph_capture_size is not None
                and self.compilation_config.max_cudagraph_capture_size != valid_max_size
            ):
                # raise error only when both two flags are user-specified
                # and they are inconsistent with each other
                if self.compilation_config.cudagraph_capture_sizes is not None:
                    raise ValueError(
                        "customized max_cudagraph_capture_size"
                        f"(={self.compilation_config.max_cudagraph_capture_size}) "
                        "should be consistent with the max value of "
                        f"cudagraph_capture_sizes(={valid_max_size})"
                    )

                logger.warning(
                    "Truncating max_cudagraph_capture_size to %d",
                    valid_max_size,
                )
            # always set the final max_cudagraph_capture_size
            self.compilation_config.max_cudagraph_capture_size = valid_max_size

            if self.compilation_config.cudagraph_capture_sizes is not None and len(
                cudagraph_capture_sizes
            ) < len(self.compilation_config.cudagraph_capture_sizes):
                # If users have specified capture sizes, we only need to
                # compare the lens before and after modification since the modified
                # list is only the subset of the original list.
                logger.warning(
                    (
                        "cudagraph_capture_sizes specified in compilation_config"
                        " %s is overridden by config %s"
                    ),
                    self.compilation_config.cudagraph_capture_sizes,
                    cudagraph_capture_sizes,
                )
            # always write back the final sizes
            self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes

        else:
            # no cudagraph in use
            self.compilation_config.max_cudagraph_capture_size = 0
            self.compilation_config.cudagraph_capture_sizes = []

        # complete the remaining process.
        self.compilation_config.post_init_cudagraph_sizes()

    def recalculate_max_model_len(self, max_model_len: int):
        """Re-verify ``max_model_len`` and store it on the dependent configs.

        The requested length goes through
        ``model_config.get_and_verify_max_len``. The value that call returns is
        written to both ``model_config.max_model_len`` and
        ``scheduler_config.max_model_len``.

        Note: only meant to be called from ``try_verify_and_update_config``.
        """
        verified_len = self.model_config.get_and_verify_max_len(max_model_len)
        for target in (self.model_config, self.scheduler_config):
            target.max_model_len = verified_len

    def try_verify_and_update_config(self):
        """Run the model-specific config verification/update hooks once.

        Returns early when there is no model config, when the model config
        is already flagged ``config_updated``, or when its architecture is
        None. The flag is set before the architecture check, so later calls
        are no-ops either way.

        Otherwise the following run in order:
        1. the per-architecture hook from ``MODELS_CONFIG_MAP`` (if any);
        2. the hybrid attention/Mamba hook (if ``is_hybrid``);
        3. the sequence-classification hook (if ``convert_type == "classify"``);
        4. the Run:ai object-URI ``load_format`` check.

        Raises:
            ValueError: if the model weights are a Run:ai object URI and
                ``load_format`` is not "auto", "runai_streamer" or
                "runai_streamer_sharded".
        """
        if self.model_config is None:
            return

        # Avoid running try_verify_and_update_config multiple times
        if getattr(self.model_config, "config_updated", False):
            return
        self.model_config.config_updated = True

        architecture = self.model_config.architecture
        if architecture is None:
            return

        # Imported at call time rather than at module import (presumably to
        # avoid an import cycle with the model registry).
        from vllm.model_executor.models.config import (
            MODELS_CONFIG_MAP,
            HybridAttentionMambaModelConfig,
        )

        arch_config_cls = MODELS_CONFIG_MAP.get(architecture)
        if arch_config_cls is not None:
            arch_config_cls.verify_and_update_config(self)

        if self.model_config.is_hybrid:
            HybridAttentionMambaModelConfig.verify_and_update_config(self)

        if self.model_config.convert_type == "classify":
            # Maybe convert ForCausalLM into ForSequenceClassification model.
            from vllm.model_executor.models.adapters import SequenceClassificationConfig

            SequenceClassificationConfig.verify_and_update_config(self)

        weights_on_runai = hasattr(
            self.model_config, "model_weights"
        ) and is_runai_obj_uri(self.model_config.model_weights)
        if not weights_on_runai:
            return

        requested_format = self.load_config.load_format
        if requested_format == "auto":
            logger.info(
                "Detected Run:ai model config. "
                "Overriding `load_format` to 'runai_streamer'"
            )
            self.load_config.load_format = "runai_streamer"
        elif requested_format not in ("runai_streamer", "runai_streamer_sharded"):
            raise ValueError(
                f"To load a model from S3, 'load_format' "
                f"must be 'runai_streamer' or 'runai_streamer_sharded', "
                f"but got '{requested_format}'. "
                f"Model: {self.model_config.model}"
            )

857
    def compile_debug_dump_path(self) -> Path | None:
858
        """Returns a rank-aware path for dumping
859
860
861
862
863
864
865
        torch.compile debug information.
        """
        if self.compilation_config.debug_dump_path is None:
            return None
        tp_rank = self.parallel_config.rank
        dp_rank = self.parallel_config.data_parallel_rank
        data_parallel_size = self.parallel_config.data_parallel_size
866
867
868
        append_path = (
            f"rank_{tp_rank}"
            if data_parallel_size == 1
869
            else f"rank_{tp_rank}_dp_{dp_rank}"
870
        )
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
        path = self.compilation_config.debug_dump_path / append_path
        return path

    def __str__(self):
        return (
            f"model={self.model_config.model!r}, "
            f"speculative_config={self.speculative_config!r}, "
            f"tokenizer={self.model_config.tokenizer!r}, "
            f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, "
            f"tokenizer_mode={self.model_config.tokenizer_mode}, "
            f"revision={self.model_config.revision}, "
            f"tokenizer_revision={self.model_config.tokenizer_revision}, "
            f"trust_remote_code={self.model_config.trust_remote_code}, "
            f"dtype={self.model_config.dtype}, "
            f"max_seq_len={self.model_config.max_model_len}, "
            f"download_dir={self.load_config.download_dir!r}, "
            f"load_format={self.load_config.load_format}, "
            f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, "  # noqa
            f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, "  # noqa
            f"data_parallel_size={self.parallel_config.data_parallel_size}, "  # noqa
            f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, "  # noqa
            f"quantization={self.model_config.quantization}, "
            f"enforce_eager={self.model_config.enforce_eager}, "
            f"kv_cache_dtype={self.cache_config.cache_dtype}, "
            f"device_config={self.device_config.device}, "
            f"structured_outputs_config={self.structured_outputs_config!r}, "
            f"observability_config={self.observability_config!r}, "
            f"seed={self.model_config.seed}, "
            f"served_model_name={self.model_config.served_model_name}, "
            f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
            f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, "  # noqa
            f"pooler_config={self.model_config.pooler_config!r}, "
903
904
            f"compilation_config={self.compilation_config!r}"
        )
905
906


907
908
# Config installed by the innermost active `set_current_vllm_config(...)`
# context; None when no such context is active. Read it through
# `get_current_vllm_config()`.
_current_vllm_config: VllmConfig | None = None
# The `prefix` argument of that same context (None when unset).
_current_prefix: str | None = None
909
910
911


@contextmanager
def set_current_vllm_config(
    vllm_config: VllmConfig, check_compile=False, prefix: str | None = None
):
    """
    Temporarily set the current vLLM config.
    Used during model initialization.
    We save the current vLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the vLLM config to determine how to dispatch.

    Args:
        vllm_config: The config to make current inside the ``with`` block.
        check_compile: If True and the block exits without an exception,
            run ``compilation_config.custom_op_log_check()`` and log a warning
            when the compilation mode is ``VLLM_COMPILE`` but
            ``compilation_counter.num_models_seen`` did not increase inside
            the block.
        prefix: Value exposed as the current prefix inside the block.

    On exit (normal or exceptional) the previous config and prefix are
    restored. The ``get_cached_compilation_config`` cache is cleared both on
    entry and on exit, so it never returns a compilation config that belongs
    to a different context.
    """
    global _current_vllm_config, _current_prefix
    old_vllm_config = _current_vllm_config
    old_prefix = _current_prefix
    from vllm.compilation.counter import compilation_counter

    num_models_seen = compilation_counter.num_models_seen
    try:
        _current_vllm_config = vllm_config
        _current_prefix = prefix
        # The cache may still hold the compilation config of the enclosing
        # context (or of a default config built outside any context). Drop it
        # so lookups inside this block see `vllm_config`.
        get_cached_compilation_config.cache_clear()
        yield
        # Reached only when the `with` body finished without raising; an
        # exception skips these checks and propagates through `finally`.
        if check_compile:
            vllm_config.compilation_config.custom_op_log_check()

        if (
            check_compile
            and vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
            and compilation_counter.num_models_seen == num_models_seen
        ):
            # If the model supports compilation,
            # compilation_counter.num_models_seen should be increased
            # by at least 1.
            # If it is not increased, it means the model does not support
            # compilation (does not have @support_torch_compile decorator).
            logger.warning(
                "`torch.compile` is turned on, but the model %s"
                " does not support it. Please open an issue on GitHub"
                " if you want it to be supported.",
                vllm_config.model_config.model,
            )
    finally:
        _current_vllm_config = old_vllm_config
        _current_prefix = old_prefix
        # Clear the compilation config cache when context changes
        get_cached_compilation_config.cache_clear()


@lru_cache(maxsize=1)
def get_cached_compilation_config():
    """Return the current vLLM config's ``compilation_config``, memoized.

    The single-entry cache avoids repeated ``get_current_vllm_config()``
    calls; ``set_current_vllm_config`` clears it when its context exits.
    """
    current = get_current_vllm_config()
    return current.compilation_config


def get_current_vllm_config() -> VllmConfig:
    """Return the config installed by ``set_current_vllm_config``.

    When no such context is active (usually in CI, when custom ops or modules
    are tested directly), a warning is logged and a newly default-constructed
    ``VllmConfig`` is returned on every call.
    """
    if _current_vllm_config is not None:
        return _current_vllm_config
    logger.warning("Current vLLM config is not set.")
    return VllmConfig()


# Type variable for `get_layers_from_vllm_config`: the values of the returned
# dict are typed as the requested `layer_type`.
T = TypeVar("T")


def get_layers_from_vllm_config(
    vllm_config: VllmConfig,
    layer_type: type[T],
    layer_names: list[str] | None = None,
) -> dict[str, T]:
    """
    Collect layers of a given type from the config's static forward context.

    Args:
        vllm_config: The vLLM config whose
            ``compilation_config.static_forward_context`` is searched.
        layer_type: Only entries that are instances of this type are kept.
        layer_names: Names to look up. If None, every name in the static
            forward context is considered.

    Returns:
        A dict mapping each selected name to its layer, in the order of
        ``layer_names`` (or of the context when ``layer_names`` is None).

    Names are looked up by plain indexing, so a name that is missing from
    the context raises (``KeyError`` for a dict).
    """
    forward_context = vllm_config.compilation_config.static_forward_context
    names = list(forward_context.keys()) if layer_names is None else layer_names

    selected: dict[str, T] = {}
    for name in names:
        layer = forward_context[name]
        if isinstance(layer, layer_type):
            selected[name] = layer
    return selected