vllm.py 38 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import hashlib
import json
import os
from contextlib import contextmanager
from dataclasses import field, replace
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union

import torch
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid

from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationLevel, CUDAGraphMode
from .device import DeviceConfig
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .parallel import ParallelConfig
from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
from .utils import SupportsHash, config

if TYPE_CHECKING:
    from transformers import PretrainedConfig

41
    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
else:
    PretrainedConfig = Any

    QuantizationConfig = Any

# Module-level logger for this file, created through vLLM's `init_logger`.
logger = init_logger(__name__)


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
    """Dataclass which contains all vllm-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    # TODO: use default_factory once default constructing ModelConfig doesn't
    # try to download a model
    model_config: ModelConfig = None  # type: ignore
    """Model configuration."""
    cache_config: CacheConfig = field(default_factory=CacheConfig)
    """Cache configuration."""
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig)
    """Parallel configuration."""
    scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)
    """Scheduler configuration."""
    device_config: DeviceConfig = field(default_factory=DeviceConfig)
    """Device configuration."""
    load_config: LoadConfig = field(default_factory=LoadConfig)
    """Load configuration."""
    lora_config: Optional[LoRAConfig] = None
    """LoRA configuration."""
    speculative_config: Optional[SpeculativeConfig] = None
    """Speculative decoding configuration."""
    structured_outputs_config: StructuredOutputsConfig = field(
76
77
        default_factory=StructuredOutputsConfig
    )
78
79
80
81
82
    """Structured outputs configuration."""
    observability_config: Optional[ObservabilityConfig] = None
    """Observability configuration."""
    quant_config: Optional[QuantizationConfig] = None
    """Quantization configuration."""
83
    compilation_config: CompilationConfig = field(default_factory=CompilationConfig)
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
    """`torch.compile` and cudagraph capture configuration for the model.

    As a shorthand, `-O<n>` can be used to directly specify the compilation
    level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
    Currently, -O <n> and -O=<n> are supported as well but this will likely be
    removed in favor of clearer -O<n> syntax in the future.

    NOTE: level 0 is the default level without any optimization. level 1 and 2
    are for internal testing only. level 3 is the recommended level for
    production, also default in V1.

    You can specify the full compilation config like so:
    `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
    """
    kv_transfer_config: Optional[KVTransferConfig] = None
    """The configurations for distributed KV cache transfer."""
    kv_events_config: Optional[KVEventsConfig] = None
    """The configurations for event publishing."""
    # some opaque config, only used to provide additional information
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
    additional_config: Union[dict, SupportsHash] = field(default_factory=dict)
    """Additional config for specified platform. Different platforms may
    support different configs. Make sure the configs are valid for the platform
    you are using. Contents must be hashable."""
    instance_id: str = ""
    """The ID of the vLLM instance."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Return a short hex digest identifying all configs that affect the
        structure of the computation graph from input ids/embeddings to the
        final hidden states. Anything before the input ids/embeddings or
        after the final hidden states is excluded.

        Every sub-config contributes its own ``compute_hash()`` when it is
        set (truthy) and the literal string ``"None"`` otherwise, so each
        factor keeps a fixed position in the list. The result is the first
        10 hex characters of an MD5 over the stringified factor list.
        """
        from vllm import __version__

        def _sub_hash(sub_config: Any) -> Any:
            # Unset (falsy) sub-configs are represented by a placeholder.
            return sub_config.compute_hash() if sub_config else "None"

        vllm_factors: list[Any] = [__version__, envs.VLLM_USE_V1]

        vllm_factors.extend(
            _sub_hash(sub_config)
            for sub_config in (
                self.model_config,
                self.cache_config,
                self.parallel_config,
                self.scheduler_config,
                self.device_config,
                self.load_config,
            )
        )

        if not self.lora_config:
            vllm_factors.append("None")
        else:
            vllm_factors.append(self.lora_config.compute_hash())
            # LoRA creates static buffers based on max_num_batched_tokens;
            # their tensor sizes and strides are captured explicitly in the
            # torch.compile graph, so the token budget is part of the hash.
            vllm_factors.append(str(self.scheduler_config.max_num_batched_tokens))

        # quant_config is intentionally not hashed: it should be captured by
        # model_config.quantization.
        vllm_factors.extend(
            _sub_hash(sub_config)
            for sub_config in (
                self.speculative_config,
                self.structured_outputs_config,
                self.observability_config,
                self.compilation_config,
                self.kv_transfer_config,
            )
        )

        # additional_config is either a plain dict (hashed through its
        # key-sorted JSON form) or an object exposing compute_hash().
        extra_config = self.additional_config
        if not extra_config:
            extra_hash = "None"
        elif isinstance(extra_config, dict):
            extra_hash = hashlib.md5(
                json.dumps(extra_config, sort_keys=True).encode(),
                usedforsecurity=False,
            ).hexdigest()
        else:
            extra_hash = extra_config.compute_hash()
        vllm_factors.append(extra_hash)

        factors: list[Any] = [vllm_factors]
        digest = hashlib.md5(str(factors).encode(), usedforsecurity=False)
        return digest.hexdigest()[:10]

    def pad_for_cudagraph(self, batch_size: int) -> int:
        """Return the CUDA graph size that ``batch_size`` is padded up to.

        The value is looked up in
        ``compilation_config.bs_to_padded_graph_size``. Callers must keep
        ``batch_size <= compilation_config.max_capture_size``. Larger values
        index past the table, and the original notes say an IndexError is
        expected in that case.
        """
        padded_size_table = self.compilation_config.bs_to_padded_graph_size
        return padded_size_table[batch_size]

    @staticmethod
    def _get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> Optional[QuantizationConfig]:
        """Get the quantization config.

        Returns None when ``model_config.quantization`` is None. Otherwise
        the config is built with ``get_quant_config`` and validated against
        the current platform and the model dtype.
        ``quant_config.maybe_update_config(model_config.model)`` is called
        before the config is returned.

        Raises:
            ValueError: if the platform reports a device capability below
                the method's minimum, or if ``model_config.dtype`` is not
                one of the method's supported activation dtypes.
        """
        from vllm.platforms import current_platform

        if model_config.quantization is None:
            return None

        from vllm.model_executor.model_loader.weight_utils import get_quant_config

        quant_config = get_quant_config(model_config, load_config)

        # A None capability means the platform cannot report one; the
        # minimum-capability check is skipped in that case.
        device_capability = current_platform.get_device_capability()
        if device_capability is not None:
            current_capability = device_capability.to_int()
            if current_capability < quant_config.get_min_capability():
                raise ValueError(
                    f"The quantization method {model_config.quantization} "
                    "is not supported for the current GPU. Minimum "
                    f"capability: {quant_config.get_min_capability()}. "
                    f"Current capability: {current_capability}."
                )

        supported_act_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_act_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_act_dtypes}"
            )

        quant_config.maybe_update_config(model_config.model)
        return quant_config

    @staticmethod
    def get_quantization_config(
        model_config: ModelConfig, load_config: LoadConfig
    ) -> Optional[QuantizationConfig]:
        """Public, non-mutating wrapper around ``_get_quantization_config``.

        The private helper has been observed to modify the ``model_config``
        it receives. It is therefore handed a deep copy, so the caller's
        ``model_config`` is left untouched.
        """
        isolated_model_config = copy.deepcopy(model_config)
        return VllmConfig._get_quantization_config(
            isolated_model_config, load_config
        )
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270

    def with_hf_config(
        self,
        hf_config: PretrainedConfig,
        architectures: Optional[list[str]] = None,
    ) -> "VllmConfig":
        """Return a new VllmConfig whose model config uses ``hf_config``.

        Args:
            hf_config: The Hugging Face config to attach to the model config.
            architectures: If given, a deep copy of ``hf_config`` gets its
                ``architectures`` replaced by this list. The caller's object
                is never mutated.

        Returns:
            A new instance built with ``dataclasses.replace``, holding a
            deep copy of ``self.model_config`` that points at the chosen
            ``hf_config``. ``self`` itself is unchanged.
        """
        effective_hf_config = hf_config
        if architectures is not None:
            effective_hf_config = copy.deepcopy(hf_config)
            effective_hf_config.architectures = architectures

        new_model_config = copy.deepcopy(self.model_config)
        new_model_config.hf_config = effective_hf_config
        return replace(self, model_config=new_model_config)

    def __post_init__(self):
        """Verify configs are valid & consistent with each other.

        Besides cross-validating the sub-configs, this fills in derived
        settings that depend on several of them:

        - the quantization config, when not supplied explicitly;
        - the default compilation level and custom-op selection;
        - the cudagraph mode and capture sizes (``_set_cudagraph_sizes``);
        - whether chunked prefill / prefix caching stay enabled;
        - hybrid-KV-cache-manager enablement, instance id, and the debug
          dump path.

        Raises:
            NotImplementedError: if ``kv_sharing_fast_prefill`` is combined
                with EAGLE speculative decoding.
            AssertionError: on an out-of-range explicit compilation level,
                a piecewise cudagraph mode without piecewise compilation, or
                DBO with an unsupported all2all backend.
        """

        self.try_verify_and_update_config()

        if self.model_config is not None:
            self.model_config.verify_with_parallel_config(self.parallel_config)
            self.model_config.verify_dual_chunk_attention_config(self.load_config)

        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config is not None:
            self.lora_config.verify_with_cache_config(self.cache_config)
            self.lora_config.verify_with_model_config(self.model_config)

        if self.quant_config is None and self.model_config is not None:
            self.quant_config = VllmConfig._get_quantization_config(
                self.model_config, self.load_config
            )

        from vllm.platforms import current_platform

        if (
            self.model_config is not None
            and self.scheduler_config.chunked_prefill_enabled
            and self.model_config.dtype == torch.float32
            and current_platform.get_device_capability() == (7, 5)
        ):
            logger.warning_once(
                "Turing devices tensor cores do not support float32 matmul. "
                "To workaround this limitation, vLLM will set 'ieee' input "
                "precision for chunked prefill triton kernels."
            )

        # If the user does not explicitly set a compilation level, then
        # we use the default level. The default level depends on other
        # settings (see the below code).
        if self.compilation_config.level is None:
            if envs.VLLM_USE_V1:
                if (
                    self.model_config is not None
                    and not self.model_config.enforce_eager
                ):
                    self.compilation_config.level = CompilationLevel.PIECEWISE
                else:
                    self.compilation_config.level = CompilationLevel.NO_COMPILATION

            else:
                # NB: Passing both --enforce-eager and a compilation level
                # in V0 means the compilation level wins out.
                self.compilation_config.level = CompilationLevel.NO_COMPILATION
        else:
            assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
            assert self.compilation_config.level <= CompilationLevel.PIECEWISE
            assert self.compilation_config.level <= 3

        # If user does not set custom ops via none or all set it here based on
        # compilation level and backend.
        if (
            self.compilation_config.custom_ops.count("none")
            + self.compilation_config.custom_ops.count("all")
            == 0
        ):
            if (
                self.compilation_config.level > 0
                and self.compilation_config.backend != "eager"
            ):
                self.compilation_config.custom_ops.append("none")
            else:
                self.compilation_config.custom_ops.append("all")

        # async tp is built on top of sequence parallelism
        # and requires it to be enabled.
        if self.compilation_config.pass_config.enable_async_tp:
            self.compilation_config.pass_config.enable_sequence_parallelism = True
        if self.compilation_config.pass_config.enable_sequence_parallelism:
            self.compilation_config.custom_ops.append("+rms_norm")

        if current_platform.support_static_graph_mode():
            # if cudagraph_mode is not explicitly set by users, set default
            # value
            if self.compilation_config.cudagraph_mode is None:
                if (
                    envs.VLLM_USE_V1
                    and self.compilation_config.level == CompilationLevel.PIECEWISE
                ):
                    # default to full and piecewise for most models
                    self.compilation_config.cudagraph_mode = (
                        CUDAGraphMode.FULL_AND_PIECEWISE
                    )

                    # pooling models and encoder-decoder models
                    # do not support full cudagraphs
                    if self.model_config is not None and (
                        self.model_config.pooler_config is not None
                        or self.model_config.is_encoder_decoder
                    ):
                        self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
                else:
                    self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

            # disable cudagraph when enforce eager execution
            if self.model_config is not None and self.model_config.enforce_eager:
                logger.info("Cudagraph is disabled under eager mode")
                self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            elif envs.VLLM_USE_V1:
                self.compilation_config.cudagraph_num_of_warmups = 1

            self._set_cudagraph_sizes()
        else:
            self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        if self.cache_config.kv_sharing_fast_prefill:
            if (
                self.speculative_config is not None
                and self.speculative_config.use_eagle()
            ):
                raise NotImplementedError(
                    "Fast prefill optimization for KV sharing is not "
                    "compatible with EAGLE as EAGLE requires correct logits "
                    "for all tokens while fast prefill gives incorrect logits "
                    "for prompt tokens."
                )

            logger.warning_once(
                "--kv-sharing-fast-prefill requires changes on model side for "
                "correctness and to realize prefill savings. "
            )

        disable_chunked_prefill_reasons: list[str] = []

        if self.model_config:
            if self.model_config.pooler_config:
                pooling_type = self.model_config.pooler_config.pooling_type
                if pooling_type is None or pooling_type.lower() != "last":
                    disable_chunked_prefill_reasons.append(
                        'Only "last" pooling supports chunked '
                        "prefill and prefix caching; disabling both."
                    )
                if not getattr(self.model_config.hf_config, "is_causal", True):
                    disable_chunked_prefill_reasons.append(
                        "Only models using causal attention supports chunked "
                        "prefill and prefix caching; disabling both."
                    )
            elif self.model_config.is_encoder_decoder:
                from vllm.multimodal import MULTIMODAL_REGISTRY

                self.scheduler_config.max_num_encoder_input_tokens = (
                    MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
                )
                logger.debug(
                    "Encoder-decoder model detected: setting "
                    "`max_num_encoder_input_tokens` to encoder length (%s)",
                    self.scheduler_config.max_num_encoder_input_tokens,
                )
                if (
                    self.model_config.architecture == "WhisperForConditionalGeneration"
                    and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
                ):
                    logger.warning(
                        "Whisper is known to have issues with "
                        "forked workers. If startup is hanging, "
                        "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                        "to 'spawn'."
                    )

        # Final off-switch for CP/APC:
        # Disable for (a) collected blockers, (b) encoder–decoder, or
        # (c) explicit CP=False when APC wasn't requested.
        # Do NOT disable merely because the resolved CP flag is False.
        apc_requested = (
            self.cache_config is not None and self.cache_config.enable_prefix_caching
        )
        if (
            disable_chunked_prefill_reasons
            or (self.model_config is not None and self.model_config.is_encoder_decoder)
            or (
                self.scheduler_config.enable_chunked_prefill is False
                and not apc_requested
            )
        ):
            for reason in disable_chunked_prefill_reasons:
                logger.info(reason)
            self.scheduler_config.chunked_prefill_enabled = False
            self.scheduler_config.long_prefill_token_threshold = 0

            if self.cache_config is not None:
                self.cache_config.enable_prefix_caching = False

        # NOTE: adjacent string literals are concatenated verbatim, so each
        # fragment below carries its own trailing space.
        if (
            self.kv_events_config is not None
            and self.kv_events_config.enable_kv_cache_events
            and not self.cache_config.enable_prefix_caching
        ):
            logger.warning(
                "KV cache events are on, but prefix caching is not enabled. "
                "Use --enable-prefix-caching to enable."
            )
        if (
            self.kv_events_config is not None
            and self.kv_events_config.publisher != "null"
            and not self.kv_events_config.enable_kv_cache_events
        ):
            logger.warning(
                "KV cache events are disabled, "
                "but the scheduler is configured to publish them. "
                "Modify KVEventsConfig.enable_kv_cache_events "
                "to True to enable."
            )
        current_platform.check_and_update_config(self)

        # Do this after all the updates to compilation_config.level
        if (
            envs.VLLM_USE_V1
            and self.compilation_config.level == CompilationLevel.PIECEWISE
        ):
            self.compilation_config.set_splitting_ops_for_v1()

        # final check of cudagraph mode after all possible updates
        if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
            if (
                self.compilation_config.cudagraph_mode.has_full_cudagraphs()
                and self.model_config is not None
                and not self.model_config.disable_cascade_attn
                and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()  # noqa: E501
            ):
                logger.warning_once(
                    "No piecewise cudagraph for executing cascade attention."
                    " Will fall back to eager execution if a batch runs "
                    "into cascade attentions"
                )

            if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
                assert self.compilation_config.level == CompilationLevel.PIECEWISE, (
                    "Compilation level should be CompilationLevel.PIECEWISE "
                    "when cudagraph_mode piecewise cudagraphs is used, "
                    f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
                )

            # final migrate the deprecated flags
            self.compilation_config.use_cudagraph = (
                self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            )
            self.compilation_config.full_cuda_graph = (
                self.compilation_config.cudagraph_mode.has_full_cudagraphs()
            )

        if self.parallel_config.enable_dbo:
            a2a_backend = envs.VLLM_ALL2ALL_BACKEND
            assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
                "Microbatching currently only supports the deepep_low_latency and "
                f"deepep_high_throughput all2all backend. {a2a_backend} is not "
                "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "
                "variable to deepep_low_latency or deepep_high_throughput and "
                "install the DeepEP kernels."
            )

            # model_config may be None (see the None checks above); only
            # touch it when present.
            if (
                self.model_config is not None
                and not self.model_config.disable_cascade_attn
            ):
                self.model_config.disable_cascade_attn = True
                logger.warning_once("Disabling cascade attention when DBO is enabled.")

        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

        if (
            envs.VLLM_USE_V1
            and not self.scheduler_config.disable_hybrid_kv_cache_manager
        ):
            # logger should only print warning message for hybrid models. As we
            # can't know whether the model is hybrid or not now, so we don't log
            # warning message here and will log it later.
            if not current_platform.support_hybrid_kv_cache():
                # Hybrid KV cache manager is not supported on non-GPU platforms.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
            if self.kv_transfer_config is not None:
                # Hybrid KV cache manager is not compatible with KV transfer.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
            if self.kv_events_config is not None:
                # Hybrid KV cache manager is not compatible with KV events.
                self.scheduler_config.disable_hybrid_kv_cache_manager = True
            if (
                self.model_config is not None
                and self.model_config.attention_chunk_size is not None
            ):
                if (
                    self.speculative_config is not None
                    and self.speculative_config.use_eagle()
                ):
                    # Hybrid KV cache manager is not yet supported with chunked
                    # local attention + eagle.
                    self.scheduler_config.disable_hybrid_kv_cache_manager = True
                elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
                    logger.warning(
                        "There is a latency regression when using chunked local"
                        " attention with the hybrid KV cache manager. Disabling"
                        " it, by default. To enable it, set the environment "
                        "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
                    )
                    # Hybrid KV cache manager is not yet supported with chunked
                    # local attention.
                    self.scheduler_config.disable_hybrid_kv_cache_manager = True

        # Expand "~" BEFORE making the path absolute. Path.absolute() on
        # "~/dump" yields "<cwd>/~/dump", after which expanduser() no longer
        # sees a leading "~" and leaves the bogus path unchanged.
        if self.compilation_config.debug_dump_path:
            self.compilation_config.debug_dump_path = (
                self.compilation_config.debug_dump_path.expanduser().absolute()
            )
        if envs.VLLM_DEBUG_DUMP_PATH is not None:
            env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).expanduser().absolute()
            if self.compilation_config.debug_dump_path:
                logger.warning(
                    "Config-specified debug dump path is overridden"
                    " by VLLM_DEBUG_DUMP_PATH to %s",
                    env_path,
                )
            self.compilation_config.debug_dump_path = env_path

        def has_blocked_weights():
            # True when the quant config reports block-wise weights, via
            # either a `weight_block_size` attribute or a
            # `has_blocked_weights()` method.
            if self.quant_config is not None:
                if hasattr(self.quant_config, "weight_block_size"):
                    return self.quant_config.weight_block_size is not None
                elif hasattr(self.quant_config, "has_blocked_weights"):
                    return self.quant_config.has_blocked_weights()
            return False

        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
        # On H100 the CUDA kernel is faster than
        # native implementation
        # https://github.com/vllm-project/vllm/issues/25094
        if has_blocked_weights():
            custom_ops = self.compilation_config.custom_ops
            if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
                custom_ops.append("+quant_fp8")

603
    def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
        """Keep only the batch sizes divisible by the tensor-parallel size.

        With sequence parallelism enabled, every capture size must be a
        multiple of ``parallel_config.tensor_parallel_size``. Sizes that are
        not are dropped and reported in a single warning.

        Args:
            possible_sizes: Candidate batch sizes.

        Returns:
            A new list with the divisible sizes, in their original order.
        """
        tp_size = self.parallel_config.tensor_parallel_size
        kept_sizes: list = []
        dropped_sizes: list = []
        for candidate in possible_sizes:
            if candidate % tp_size == 0:
                kept_sizes.append(candidate)
            else:
                dropped_sizes.append(candidate)

        if dropped_sizes:
            logger.warning(
                "Batch sizes %s are removed because they are not "
                "multiple of tp_size %d when "
                "sequence parallelism is enabled",
                dropped_sizes,
                tp_size,
            )

        return kept_sizes

    def _set_cudagraph_sizes(self):
        """
        Compute the batch sizes for CUDA graph capture and pass them to
        `compilation_config.init_with_cudagraph_sizes`.

        The candidate list is derived from `scheduler_config.cuda_graph_sizes`:

        - A single value `max_graph_size` (must be >= 1) expands to 1, 2, 4
          (those not exceeding it) followed by every multiple of 8 up to and
          including `max_graph_size`:

          ```python
          # scheduler_config.cuda_graph_sizes == [max_graph_size]
          cuda_graph_sizes = [1, 2, 4, 8, 16, 24, 32, 40, ..., max_graph_size]
          ```

          By default `max_graph_size` is described as
          `min(max_num_seqs * 2, 512)`; that default is computed outside
          this method.
        - Two or more values are used as given, sorted in ascending order.

        If tensor parallelism is > 1 and sequence parallelism is enabled,
        sizes that are not multiples of `tensor_parallel_size` are removed.
        Finally, sizes larger than `max_num_batched_tokens` are removed.
        When there is no model config, or eager execution is enforced, an
        empty list is passed on.

        In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
        will be the final sizes to capture cudagraph (in descending order).

        These sizes are used to capture and reuse CUDA graphs for
        performance-critical paths (e.g., decoding). Capturing enables
        significantly faster kernel dispatch by avoiding Python overhead.
        Since each sequence may have a variable number of tokens, the maximum
        usable batch size depends on actual sequence lengths.

        Example:
            With `max_num_batched_tokens = 8192`, and typical sequences
            averaging ~32 tokens, most practical batch sizes fall below 256.
            However, the system will still allow capture sizes up to 512 if
            shape and memory permit.

        Note:
            If users explicitly specify cudagraph capture sizes in the
            compilation config, those will override this default logic.
            At runtime:

            - If batch size <= one of the `cudagraph_capture_sizes`, the closest
            padded CUDA graph will be used.
            - If batch size > largest `cudagraph_capture_sizes`, cudagraph will
            not be used.

        Raises:
            AssertionError: if a single configured size is smaller than 1.
            TypeError: if `scheduler_config.cuda_graph_sizes` is empty.
        """

        # calculate the default `batch_size_capture_list`
        batch_size_capture_list = []
        if self.model_config is not None and not self.model_config.enforce_eager:
            cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
            if len(cuda_graph_sizes) == 1:
                max_graph_size = cuda_graph_sizes[0]
                assert max_graph_size >= 1, (
                    "Maximum cudagraph size should be greater than or equal to 1."
                )
                batch_size_capture_list = [
                    i for i in [1, 2, 4] if i <= max_graph_size
                ] + list(range(8, max_graph_size + 1, 8))
            elif len(cuda_graph_sizes) > 1:
                batch_size_capture_list = sorted(cuda_graph_sizes)
            else:
                raise TypeError(f"Invalid value for {cuda_graph_sizes=}.")
            if (
                self.parallel_config.tensor_parallel_size > 1
                and self.compilation_config.pass_config.enable_sequence_parallelism
            ):
                batch_size_capture_list = self.update_sizes_for_sequence_parallelism(
                    batch_size_capture_list
                )
            # A batch can never hold more sequences than the token budget.
            max_num_tokens = self.scheduler_config.max_num_batched_tokens
            batch_size_capture_list = [
                size for size in batch_size_capture_list if size <= max_num_tokens
            ]

        self.compilation_config.init_with_cudagraph_sizes(batch_size_capture_list)
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714

    def recalculate_max_model_len(self, max_model_len: int):
        # Can only be called in try_verify_and_update_config
        model_config = self.model_config
        max_model_len = model_config.get_and_verify_max_len(max_model_len)
        self.model_config.max_model_len = max_model_len
        self.scheduler_config.max_model_len = max_model_len

    def try_verify_and_update_config(self):
        if self.model_config is None:
            return

        # Avoid running try_verify_and_update_config multiple times
        if getattr(self.model_config, "config_updated", False):
            return
        self.model_config.config_updated = True

        architecture = self.model_config.architecture
        if architecture is None:
            return

        from vllm.model_executor.models.config import (
715
716
717
718
            MODELS_CONFIG_MAP,
            HybridAttentionMambaModelConfig,
        )

719
720
721
722
723
724
725
726
727
        cls = MODELS_CONFIG_MAP.get(architecture, None)
        if cls is not None:
            cls.verify_and_update_config(self)

        if self.model_config.is_hybrid:
            HybridAttentionMambaModelConfig.verify_and_update_config(self)

        if self.model_config.convert_type == "classify":
            # Maybe convert ForCausalLM into ForSequenceClassification model.
728
729
            from vllm.model_executor.models.adapters import SequenceClassificationConfig

730
731
732
            SequenceClassificationConfig.verify_and_update_config(self)

        if hasattr(self.model_config, "model_weights") and is_runai_obj_uri(
733
734
            self.model_config.model_weights
        ):
735
            if self.load_config.load_format == "auto":
736
737
738
739
                logger.info(
                    "Detected Run:ai model config. "
                    "Overriding `load_format` to 'runai_streamer'"
                )
740
741
                self.load_config.load_format = "runai_streamer"
            elif self.load_config.load_format != "runai_streamer":
742
743
744
745
746
747
                raise ValueError(
                    f"To load a model from S3, 'load_format' "
                    f"must be 'runai_streamer', "
                    f"but got '{self.load_config.load_format}'. "
                    f"Model: {self.model_config.model}"
                )
748
749

    def compile_debug_dump_path(self) -> Optional[Path]:
750
        """Returns a rank-aware path for dumping
751
752
753
754
755
756
757
        torch.compile debug information.
        """
        if self.compilation_config.debug_dump_path is None:
            return None
        tp_rank = self.parallel_config.rank
        dp_rank = self.parallel_config.data_parallel_rank
        data_parallel_size = self.parallel_config.data_parallel_size
758
759
760
        append_path = (
            f"rank_{tp_rank}"
            if data_parallel_size == 1
761
            else f"rank_{tp_rank}_dp_{dp_rank}"
762
        )
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
        path = self.compilation_config.debug_dump_path / append_path
        return path

    def __str__(self):
        return (
            f"model={self.model_config.model!r}, "
            f"speculative_config={self.speculative_config!r}, "
            f"tokenizer={self.model_config.tokenizer!r}, "
            f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, "
            f"tokenizer_mode={self.model_config.tokenizer_mode}, "
            f"revision={self.model_config.revision}, "
            f"tokenizer_revision={self.model_config.tokenizer_revision}, "
            f"trust_remote_code={self.model_config.trust_remote_code}, "
            f"dtype={self.model_config.dtype}, "
            f"max_seq_len={self.model_config.max_model_len}, "
            f"download_dir={self.load_config.download_dir!r}, "
            f"load_format={self.load_config.load_format}, "
            f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, "  # noqa
            f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, "  # noqa
            f"data_parallel_size={self.parallel_config.data_parallel_size}, "  # noqa
            f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, "  # noqa
            f"quantization={self.model_config.quantization}, "
            f"enforce_eager={self.model_config.enforce_eager}, "
            f"kv_cache_dtype={self.cache_config.cache_dtype}, "
            f"device_config={self.device_config.device}, "
            f"structured_outputs_config={self.structured_outputs_config!r}, "
            f"observability_config={self.observability_config!r}, "
            f"seed={self.model_config.seed}, "
            f"served_model_name={self.model_config.served_model_name}, "
            f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
            f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, "  # noqa
            f"pooler_config={self.model_config.pooler_config!r}, "
795
796
            f"compilation_config={self.compilation_config!r}"
        )
797
798
799
800
801
802
803


# Module-level "current" state. set_current_vllm_config() installs both values
# for the duration of its context and restores the previous values on exit;
# get_current_vllm_config() reads _current_vllm_config.
_current_vllm_config: Optional[VllmConfig] = None
_current_prefix: Optional[str] = None


@contextmanager
def set_current_vllm_config(
    vllm_config: VllmConfig, check_compile=False, prefix: Optional[str] = None
):
    """
    Temporarily set the current vLLM config.
    Used during model initialization.
    We save the current vLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the vLLM config to determine how to dispatch.

    Args:
        vllm_config: Config installed as the current config while the
            context is active.
        check_compile: If True and the ``with`` body completes without
            raising, run ``compilation_config.custom_op_log_check()`` and log
            a warning when the compilation level is PIECEWISE but no model
            incremented ``compilation_counter.num_models_seen``.
        prefix: Value installed as the current prefix while the context is
            active.

    On exit, whether normal or via an exception, the previous config and
    prefix are restored and the ``get_cached_compilation_config`` cache is
    cleared.
    """
    global _current_vllm_config, _current_prefix
    old_vllm_config = _current_vllm_config
    old_prefix = _current_prefix
    from vllm.compilation.counter import compilation_counter

    # Snapshot the counter so we can tell whether the body compiled a model.
    num_models_seen = compilation_counter.num_models_seen
    try:
        _current_vllm_config = vllm_config
        _current_prefix = prefix
        yield
        # Reached only when the `with` body finished without raising; an
        # exception skips these checks and goes straight to `finally`.
        if check_compile:
            vllm_config.compilation_config.custom_op_log_check()

            if (
                vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
                and compilation_counter.num_models_seen == num_models_seen
            ):
                # If the model supports compilation,
                # compilation_counter.num_models_seen should be increased
                # by at least 1.
                # If it is not increased, it means the model does not support
                # compilation (does not have @support_torch_compile decorator).
                logger.warning(
                    "`torch.compile` is turned on, but the model %s"
                    " does not support it. Please open an issue on GitHub"
                    " if you want it to be supported.",
                    vllm_config.model_config.model,
                )
    finally:
        _current_vllm_config = old_vllm_config
        _current_prefix = old_prefix
        # Clear the compilation config cache when context changes
        get_cached_compilation_config.cache_clear()


@lru_cache(maxsize=1)
def get_cached_compilation_config():
    """Return the compilation config of the current vLLM config, memoized.

    Caching avoids repeated calls to ``get_current_vllm_config()``.
    ``set_current_vllm_config`` clears this cache when its context exits.
    """
    current_config = get_current_vllm_config()
    return current_config.compilation_config


def get_current_vllm_config() -> VllmConfig:
    """Return the config installed by ``set_current_vllm_config``.

    If none is installed, log a warning and return a freshly
    default-constructed ``VllmConfig``. This typically happens in CI, when
    custom ops/modules are tested directly without setting a config.
    """
    current = _current_vllm_config
    if current is not None:
        return current
    logger.warning("Current vLLM config is not set.")
    return VllmConfig()


T = TypeVar("T")


def get_layers_from_vllm_config(
    vllm_config: VllmConfig,
    layer_type: type[T],
    layer_names: Optional[list[str]] = None,
) -> dict[str, T]:
    """
    Collect layers of a given type from the static forward context.

    Args:
        vllm_config: The vLLM config whose
            ``compilation_config.static_forward_context`` is searched.
        layer_type: Only entries that are instances of this type are kept.
        layer_names: The names to look up. If None, every entry of the
            static forward context is considered.

    Returns:
        A dict mapping each matching name to its layer, in the order the
        names were visited.

    Raises:
        KeyError: If a name in ``layer_names`` is absent from the static
            forward context.
    """
    forward_context = vllm_config.compilation_config.static_forward_context
    candidate_names = (
        list(forward_context.keys()) if layer_names is None else layer_names
    )

    matching: dict[str, T] = {}
    for name in candidate_names:
        layer = forward_context[name]
        if isinstance(layer, layer_type):
            matching[name] = layer
    return matching