config.py 147 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import ast
4
import copy
5
import enum
6
import hashlib
7
import json
8
import sys
9
import warnings
10
from contextlib import contextmanager
11
from dataclasses import dataclass, field, replace
12
from pathlib import Path
13
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
14
15
                    Final, List, Literal, Mapping, Optional, Protocol, Set,
                    Tuple, Type, Union)
16
17

import torch
18
from pydantic import BaseModel, Field, PrivateAttr
19
from transformers import PretrainedConfig
20

21
import vllm.envs as envs
22
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.logger import init_logger
24
25
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                     get_quantization_config)
26
from vllm.model_executor.models import ModelRegistry
27
from vllm.platforms import CpuArchEnum
28
from vllm.tracing import is_otel_available, otel_import_error_traceback
29
30
31
from vllm.transformers_utils.config import (
    ConfigFormat, get_config, get_hf_image_processor_config,
    get_hf_text_config, get_pooling_config,
32
33
    get_sentence_transformer_tokenizer_config, is_encoder_decoder,
    try_get_generation_config, uses_mrope)
34
from vllm.transformers_utils.s3_utils import S3Model
35
from vllm.transformers_utils.utils import is_s3
36
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
37
                        get_cpu_memory, random_uuid, resolve_obj_by_qualname)
38

39
40
41
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

42
    from vllm.executor.executor_base import ExecutorBase
43
44
    from vllm.model_executor.layers.quantization.base_config import (
        QuantizationConfig)
45
    from vllm.model_executor.model_loader.loader import BaseModelLoader
46
47
    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
        BaseTokenizerGroup)
48
49
else:
    QuantizationConfig = None
50

51
52
logger = init_logger(__name__)

# Default batched-token budgets for specific model families.
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32_768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5_120

# Values a user may pass as `task`; "embedding" is a legacy alias that
# `ModelConfig._resolve_task` maps onto a concrete pooling task.
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                     "score", "reward"]

# Concrete tasks a ModelConfig can end up with after resolution.
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
                        "draft"]

# Kind of model runner that executes a resolved task.
RunnerType = Literal["generate", "pooling", "draft"]

# Runner type -> tasks it can serve (dict order is used as priority order).
_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
    "generate": ["generate"],
    "pooling": ["embed", "classify", "score", "reward"],
    "draft": ["draft"],
}

# Inverse of `_RUNNER_TASKS`: resolved task -> the runner type serving it.
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
    task_name: runner_name
    for runner_name, task_names in _RUNNER_TASKS.items()
    for task_name in task_names
}

# Either keyword overrides merged into the HF config, or a callable that
# receives the loaded HF config and returns the (possibly new) config.
HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
                                             PretrainedConfig]]


class SupportsHash(Protocol):
    """Structural type for objects that expose a `compute_hash` method."""

    def compute_hash(self) -> str:
        """Return a string hash identifying this object."""
        ...


class ModelImpl(str, enum.Enum):
    """Which model implementation to load (see `ModelConfig.model_impl`)."""

    # Try the vLLM implementation first, fall back to Transformers.
    AUTO = "auto"
    # Always use the vLLM implementation.
    VLLM = "vllm"
    # Always use the Transformers implementation.
    TRANSFORMERS = "transformers"


class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
            It is also used as the content for `model_name` tag in metrics
            output when `served_model_name` is not specified.
        task: The task to use the model for. Each vLLM instance only supports
            one task, even if the same model can be used for multiple tasks.
            With "auto", the only supported task is selected; if the model
            supports several tasks, a preferred one is chosen from the model
            architecture and logged. Pass the task explicitly to override.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, "slow" will always use the slow tokenizer,
            "mistral" will always use the tokenizer from `mistral_common`, and
            "custom" will use --tokenizer to select the preregistered tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images or
            videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        rope_scaling: Deprecated. RoPE scaling settings merged into the
            HuggingFace config; use `hf_overrides` instead.
        rope_theta: Deprecated. RoPE theta merged into the HuggingFace config;
            use `hf_overrides` instead.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        spec_target_max_model_len: Specify the maximum length for spec
            decoding draft models.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If None, the user did not specify, so default to False.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        max_logprobs: Maximum number of log probabilities. Defaults to 20.
        disable_sliding_window: Whether to disable sliding window. If True,
            we will disable the sliding window functionality of the model.
            If the model does not support sliding window, this argument is
            ignored.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer.
        served_model_name: The model name used in metrics tag `model_name`,
            matches the model name exposed via the APIs. If multiple model
            names provided, the first name will be used. If not specified,
            the model name will be the same as `model`.
        limit_mm_per_prompt: Maximum number of data items per modality
            per prompt. Only applicable for multimodal models.
        use_async_output_proc: Whether to use async output processor.
            Defaults to True.
        config_format: The config format which shall be loaded.
            Defaults to 'auto' which defaults to 'hf'.
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor.
        disable_mm_preprocessor_cache: If true, then disables caching of the
            multi-modal preprocessor/mapper. (not recommended)
        override_neuron_config: Initialize non-default neuron config or
            override default neuron config that are specific to Neuron devices,
            this argument will be used to configure the neuron config that
            cannot be gathered from the vllm arguments.
        override_pooler_config: Initialize non-default pooling config or
            override default pooling config for the pooling model.
        logits_processor_pattern: Optional regex pattern specifying valid
            logits processor qualified names that can be passed with the
            `logits_processors` extra completion argument. Defaults to None,
            which allows no processors.
        generation_config: Configuration parameter file for generation.
        enable_sleep_mode: Whether to enable sleep mode. Only supported on
            CUDA devices; a ValueError is raised on other platforms.
        model_impl: Which implementation of the model to use:
            "auto" will try to use the vLLM implementation if it exists and
                fall back to the Transformers implementation if no vLLM
                implementation is available.
            "vllm" will use the vLLM model implementation.
            "transformers" will use the Transformers model implementation.
        override_generation_config: Override the generation config with the
            given config.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Return a SHA-256 hex digest over the config fields that shape the
        computation graph from input ids/embeddings to the final hidden
        states. Anything before the input ids/embeddings or after the final
        hidden states is deliberately excluded.
        """
        # Order matters: the digest is taken over `str()` of this list.
        factors: List[Any] = [
            self.model,
            self.dtype,
            self.quantization,
            self.revision,
            self.code_revision,
            self.trust_remote_code,
            self.rope_scaling,
            self.rope_theta,
        ]
        digest = hashlib.sha256(str(factors).encode())
        return digest.hexdigest()

    def __init__(
        self,
        model: str,
        task: Union[TaskOption, Literal["draft"]],
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        dtype: Union[str, torch.dtype],
        seed: int,
        allowed_local_media_path: str = "",
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_theta: Optional[float] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        spec_target_max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: Optional[int] = None,
        max_logprobs: int = 20,
        disable_sliding_window: bool = False,
        skip_tokenizer_init: bool = False,
        served_model_name: Optional[Union[str, List[str]]] = None,
        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
        use_async_output_proc: bool = True,
        config_format: ConfigFormat = ConfigFormat.AUTO,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        disable_mm_preprocessor_cache: bool = False,
        override_neuron_config: Optional[Dict[str, Any]] = None,
        override_pooler_config: Optional["PoolerConfig"] = None,
        logits_processor_pattern: Optional[str] = None,
        generation_config: Optional[str] = None,
        enable_sleep_mode: bool = False,
        override_generation_config: Optional[Dict[str, Any]] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
    ) -> None:
        """Load the HF config for `model`, apply overrides and validate.

        See the class docstring for the meaning of each argument.

        Raises:
            ValueError: If sleep mode is requested on a non-CUDA platform, or
                if any of the task / tokenizer-mode / quantization checks
                invoked below reject the configuration.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.allowed_local_media_path = allowed_local_media_path
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        self.model_impl = model_impl

        if hf_overrides is None:
            hf_overrides = {}

        # Split `hf_overrides` into keyword overrides (merged via
        # `hf_config.update`) and an optional callable applied afterwards.
        if callable(hf_overrides):
            hf_overrides_kw: Dict[str, Any] = {}
            hf_overrides_fn = hf_overrides
        else:
            # Copy so the deprecated rope options below never mutate the
            # dict object the caller passed in.
            hf_overrides_kw = dict(hf_overrides)
            hf_overrides_fn = None

        def _apply_deprecated_override(flag: str,
                                       override: Dict[str, Any]) -> None:
            """Merge a deprecated CLI option into the HF overrides + warn."""
            hf_overrides_kw.update(override)
            try:
                # NOTE(review): `--hf-overrides` is presumably parsed as JSON
                # by the CLI (verify in arg_utils); a Python repr with single
                # quotes would not be valid JSON, so suggest JSON instead.
                suggestion = json.dumps(override)
            except (TypeError, ValueError):
                suggestion = repr(override)
            msg = (f"`{flag}` will be removed in a future release. "
                   f"Please instead use `--hf-overrides '{suggestion}'`")
            # stacklevel=3: this helper -> __init__ -> caller of __init__.
            warnings.warn(DeprecationWarning(msg), stacklevel=3)

        if rope_scaling is not None:
            _apply_deprecated_override("--rope-scaling",
                                       {"rope_scaling": rope_scaling})
        if rope_theta is not None:
            _apply_deprecated_override("--rope-theta",
                                       {"rope_theta": rope_theta})

        self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)

        # The tokenizer version is consistent with the model version by default.
        if tokenizer_revision is None:
            self.tokenizer_revision = revision
        else:
            self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.enforce_eager = enforce_eager
        self.max_seq_len_to_capture = max_seq_len_to_capture
        self.max_logprobs = max_logprobs
        self.disable_sliding_window = disable_sliding_window
        self.skip_tokenizer_init = skip_tokenizer_init
        self.enable_sleep_mode = enable_sleep_mode

        from vllm.platforms import current_platform

        if self.enable_sleep_mode and not current_platform.is_cuda():
            raise ValueError("Sleep mode is only supported on CUDA devices.")

        hf_config = get_config(self.model, trust_remote_code, revision,
                               code_revision, config_format)

        if hf_overrides_kw:
            logger.info("Overriding HF config with %s", hf_overrides_kw)
            hf_config.update(hf_overrides_kw)
        if hf_overrides_fn:
            logger.info("Overriding HF config with %s", hf_overrides_fn)
            hf_config = hf_overrides_fn(hf_config)

        self.hf_config = hf_config

        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.encoder_config = self._get_encoder_config()
        self.hf_image_processor_config = get_hf_image_processor_config(
            self.model, revision)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.use_async_output_proc = use_async_output_proc
        self.mm_processor_kwargs = mm_processor_kwargs
        self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache

        # Set enforce_eager to False if the value is unset.
        if self.enforce_eager is None:
            self.enforce_eager = False

        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
        has_interleaved_attention = (sliding_window is not None) and (
            isinstance(sliding_window, list) or
            (self.hf_text_config.model_type in ["gemma2", "cohere2"]))

        if (not self.disable_sliding_window and has_interleaved_attention):
            if (backend :=
                    envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
                sliding_window_len_min = get_min_sliding_window(
                    self.hf_text_config.sliding_window)

                logger.warning_once(
                    f"{self.hf_text_config.model_type} has interleaved "
                    "attention, which is currently not supported by the "
                    f"{backend} backend. Disabling sliding window and capping "
                    "the max length to the sliding window size "
                    f"({sliding_window_len_min}).")
                self.disable_sliding_window = True
            else:
                # for a model with interleaved attention,
                # the scheduler and the model treat it as full attention
                # (i.e., not dropping any tokens outside the window).
                # only the attention layer itself is aware of the sliding
                # window, and use the window size to compute the attention.
                self.hf_text_config.interleaved_sliding_window = sliding_window
                delattr(self.hf_text_config, "sliding_window")
                sliding_window = None

        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.hf_text_config,
            max_model_len=max_model_len,
            disable_sliding_window=self.disable_sliding_window,
            sliding_window_len=self.get_hf_config_sliding_window(),
            spec_target_max_model_len=spec_target_max_model_len,
            encoder_config=self.encoder_config)
        self.served_model_name = get_served_model_name(model,
                                                       served_model_name)
        self.multimodal_config = self._init_multimodal_config(
            limit_mm_per_prompt)
        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()

        self.is_attention_free = self._init_attention_free()
        self.is_hybrid = self._init_is_hybrid()
        self.has_inner_state = self._init_has_inner_state()

        # The neuron override only has meaning on Neuron devices.
        if current_platform.is_neuron():
            self.override_neuron_config = override_neuron_config
        else:
            self.override_neuron_config = None

        supported_tasks, task = self._resolve_task(task, self.hf_config)
        self.supported_tasks = supported_tasks
        self.task: Final = task

        if self.task in ("draft", "generate"):
            self.truncation_side = "left"
        else:
            self.truncation_side = "right"

        self.pooler_config = self._init_pooler_config(override_pooler_config)
        self.logits_processor_pattern = logits_processor_pattern

        self.generation_config = generation_config
        self.override_generation_config = override_generation_config or {}

        self._verify_quantization()
        self._verify_cuda_graph()
        self._verify_bnb_config()

    def maybe_pull_model_tokenizer_for_s3(self, model: str,
                                          tokenizer: str) -> None:
        """
        Pull the model config or tokenizer to a temporary
        directory in case of S3.

        Args:
            model: The model name or path.
            tokenizer: The tokenizer name or path.

        Side effects:
            If `model` is an S3 path, `self.model_weights` receives the
            previous value of `self.model` and `self.model` is repointed to
            the local directory holding the pulled `*config.json` files.
            If `tokenizer` is an S3 path, `self.tokenizer` is repointed to a
            local directory holding the tokenizer files.
        """
        if is_s3(model):
            s3_model = S3Model()
            s3_model.pull_files(model, allow_pattern=["*config.json"])
            self.model_weights = self.model
            self.model = s3_model.dir

        if is_s3(tokenizer):
            s3_tokenizer = S3Model()
            # Pull from the tokenizer URI itself: previously this pulled from
            # `model`, which is wrong whenever the tokenizer lives elsewhere.
            # Weight files are skipped since only tokenizer files are needed.
            s3_tokenizer.pull_files(
                tokenizer, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
            self.tokenizer = s3_tokenizer.dir

    def _init_multimodal_config(
        self, limit_mm_per_prompt: Optional[Mapping[str, int]]
    ) -> Optional["MultiModalConfig"]:
        """Build a MultiModalConfig for multimodal architectures.

        Returns:
            A MultiModalConfig when the registry reports the architecture as
            multimodal, otherwise None.

        Raises:
            ValueError: If a non-empty `limit_mm_per_prompt` is given for a
                model that is not multimodal.
        """
        archs = getattr(self.hf_config, "architectures", [])
        if not ModelRegistry.is_multimodal_model(archs):
            if limit_mm_per_prompt:
                raise ValueError("`limit_mm_per_prompt` is only supported for "
                                 "multimodal models.")
            return None

        return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})

    def _get_encoder_config(self):
        """Return `get_sentence_transformer_tokenizer_config` for this model."""
        model_name, model_revision = self.model, self.revision
        return get_sentence_transformer_tokenizer_config(model_name,
                                                         model_revision)

    def _init_pooler_config(
        self,
        override_pooler_config: Optional["PoolerConfig"],
    ) -> Optional["PoolerConfig"]:
        """Return the effective pooler config, or None for non-pooling runs.

        The user's override (or a fresh PoolerConfig) is used as the base.
        Any field left as None on it is then filled from the model's own
        pooling config, when one exists.
        """
        if self.runner_type != "pooling":
            return None

        pooler_config = override_pooler_config or PoolerConfig()

        model_defaults = get_pooling_config(self.model, self.revision)
        if model_defaults is not None:
            # User-specified values win; only unset (None) fields are filled.
            for field_name, default_value in model_defaults.items():
                if getattr(pooler_config, field_name) is None:
                    setattr(pooler_config, field_name, default_value)

        return pooler_config

    def _init_attention_free(self) -> bool:
        """Ask the model registry whether the architecture is attention-free."""
        return ModelRegistry.is_attention_free_model(
            getattr(self.hf_config, "architectures", []))

    def _init_is_hybrid(self) -> bool:
        """Ask the model registry whether the architecture is hybrid."""
        return ModelRegistry.is_hybrid_model(
            getattr(self.hf_config, "architectures", []))

    def _init_has_inner_state(self) -> bool:
        """Ask the model registry whether the architecture has inner state."""
        return ModelRegistry.model_has_inner_state(
            getattr(self.hf_config, "architectures", []))

    def _verify_tokenizer_mode(self) -> None:
        """Lower-case `self.tokenizer_mode` and reject unknown modes.

        Raises:
            ValueError: If the mode is not auto/slow/mistral/custom.
        """
        normalized_mode = self.tokenizer_mode.lower()
        valid_modes = ("auto", "slow", "mistral", "custom")
        if normalized_mode not in valid_modes:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto', 'slow', 'mistral' or 'custom'.")
        self.tokenizer_mode = normalized_mode

    def _get_preferred_task(
        self,
        architectures: List[str],
        supported_tasks: Set[_ResolvedTask],
    ) -> Optional[_ResolvedTask]:
        """Guess the most natural task for this model, or None if unsure.

        A model shipping a pooling config prefers "embed"; a cross-encoder
        prefers "score". Otherwise the resolved architecture class name's
        suffix is matched, in order, against the table below. Only a task
        in `supported_tasks` is accepted from that table.
        """
        if get_pooling_config(self.model, self.revision):
            return "embed"
        if ModelRegistry.is_cross_encoder_model(architectures):
            return "score"

        # First matching suffix wins, so order is significant.
        suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [
            # Other models follow this pattern
            ("ForCausalLM", "generate"),
            ("ForConditionalGeneration", "generate"),
            ("ForSequenceClassification", "classify"),
            ("ChatModel", "generate"),
            ("LMHeadModel", "generate"),
            ("EmbeddingModel", "embed"),
            ("RewardModel", "reward"),
        ]
        _, arch_name = ModelRegistry.inspect_model_cls(architectures)

        return next(
            (candidate for suffix, candidate in suffix_to_preferred_task
             if arch_name.endswith(suffix) and candidate in supported_tasks),
            None,
        )

    def _resolve_task(
        self,
        task_option: Union[TaskOption, Literal["draft"]],
        hf_config: PretrainedConfig,
    ) -> Tuple[Set[_ResolvedTask], _ResolvedTask]:
        """Map the user's task option onto a concrete, supported task.

        Args:
            task_option: The `task` argument given to the config ("auto",
                a concrete task, the legacy "embedding" alias, or "draft").
            hf_config: The loaded HF config; its `architectures` attribute
                is used to query the model registry.

        Returns:
            A tuple `(supported_tasks, selected_task)`.

        Raises:
            ValueError: If "auto" is requested but the model supports no
                runner at all, or if an explicit task is not supported.
        """
        if task_option == "draft":
            return {"draft"}, "draft"

        architectures = getattr(hf_config, "architectures", [])

        runner_support: Dict[RunnerType, bool] = {
            # NOTE: Listed from highest to lowest priority,
            # in case the model supports multiple of them
            "generate": ModelRegistry.is_text_generation_model(architectures),
            "pooling": ModelRegistry.is_pooling_model(architectures),
        }
        supported_runner_types_lst: List[RunnerType] = [
            runner_type
            for runner_type, is_supported in runner_support.items()
            if is_supported
        ]

        supported_tasks_lst: List[_ResolvedTask] = [
            task for runner_type in supported_runner_types_lst
            for task in _RUNNER_TASKS[runner_type]
        ]
        supported_tasks = set(supported_tasks_lst)

        if task_option == "auto":
            if not supported_tasks_lst:
                # Without this guard `next(iter([]))` would leak a bare
                # StopIteration out of the constructor.
                raise ValueError(
                    "Unable to resolve a task for this model: it supports "
                    "neither generation nor pooling "
                    f"(architectures: {architectures}).")

            # Default to the highest-priority supported task.
            selected_task = supported_tasks_lst[0]

            if len(supported_tasks_lst) > 1:
                preferred_task = self._get_preferred_task(
                    architectures, supported_tasks)
                if preferred_task is not None:
                    selected_task = preferred_task

                logger.info(
                    "This model supports multiple tasks: %s. "
                    "Defaulting to '%s'.", supported_tasks, selected_task)
        else:
            # Aliases
            if task_option == "embedding":
                preferred_task = self._get_preferred_task(
                    architectures, supported_tasks)
                if preferred_task != "embed":
                    msg = ("The 'embedding' task will be restricted to "
                           "embedding models in a future release. Please "
                           "pass `--task classify`, `--task score`, or "
                           "`--task reward` explicitly for other pooling "
                           "models.")
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)

                task_option = preferred_task or "embed"

            if task_option not in supported_tasks:
                msg = (
                    f"This model does not support the '{task_option}' task. "
                    f"Supported tasks: {supported_tasks}")
                raise ValueError(msg)

            selected_task = task_option

        return supported_tasks, selected_task

    def _parse_quant_hf_config(self):
        """Return the quantization section of the HF config, or None.

        `quantization_config` is checked first; compressed-tensors
        checkpoints use a "compression_config" key instead, which is used
        as the fallback.
        """
        for attr_name in ("quantization_config", "compression_config"):
            section = getattr(self.hf_config, attr_name, None)
            if section is not None:
                return section
        return None

    def _verify_quantization(self) -> None:
        """Reconcile `self.quantization` with the checkpoint and validate it.

        The user-provided method is lower-cased. If the HF config carries a
        quantization section, each registered method may override the
        choice, and the checkpoint's method must agree with the user's (or
        is adopted when the user gave none). The final method must be
        known and is checked against the current platform. A warning is
        logged for methods not in the optimized list.

        Raises:
            ValueError: On a user/checkpoint mismatch or an unknown method.
        """
        supported_quantization = QUANTIZATION_METHODS
        optimized_quantization_methods = [
            "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
            "awq_marlin", "fbgemm_fp8", "compressed_tensors",
            "compressed-tensors", "experts_int8", "quark"
        ]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        hf_quant_cfg = self._parse_quant_hf_config()

        if hf_quant_cfg is not None:
            ckpt_method = hf_quant_cfg.get("quant_method", "").lower()

            # Detect which checkpoint it is: the first registered method
            # that claims an override wins.
            for candidate_name in QUANTIZATION_METHODS:
                candidate_cls = get_quantization_config(candidate_name)
                override = candidate_cls.override_quantization_method(
                    hf_quant_cfg, self.quantization)
                if override:
                    ckpt_method = override
                    self.quantization = override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = ckpt_method
            elif self.quantization != ckpt_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({ckpt_method}) does not match the quantization "
                    "method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is None:
            return

        if self.quantization not in supported_quantization:
            raise ValueError(
                f"Unknown quantization method: {self.quantization}. Must "
                f"be one of {supported_quantization}.")
        from vllm.platforms import current_platform
        current_platform.verify_quantization(self.quantization)
        if self.quantization not in optimized_quantization_methods:
            logger.warning(
                "%s quantization is not fully "
                "optimized yet. The speed can be slower than "
                "non-quantized models.", self.quantization)

    def _verify_cuda_graph(self) -> None:
        """Clamp `max_seq_len_to_capture` and force eager mode if needed.

        `max_seq_len_to_capture` defaults to, and never exceeds,
        `max_model_len`. Model types listed below cannot use CUDA graphs, so
        `enforce_eager` is switched on for them (with a warning).
        """
        capture_len = self.max_seq_len_to_capture
        if capture_len is None:
            capture_len = self.max_model_len
        self.max_seq_len_to_capture = min(capture_len, self.max_model_len)

        models_without_cuda_graph = ['mllama']
        model_type = self.hf_config.model_type
        if self.enforce_eager or model_type not in models_without_cuda_graph:
            return

        logger.warning(
            "CUDA graph is not supported for %s yet, fallback to the eager "
            "mode.", self.hf_config.model_type)
        self.enforce_eager = True

    def _verify_bnb_config(self) -> None:
        """
        Force eager mode for 8-bit bitsandbytes checkpoints.

        The current version of bitsandbytes (0.44.0) with 8-bit models does not
        yet support CUDA graph.
        """
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        loads_in_8bit = (quant_cfg.get("load_in_8bit", False)
                         if quant_cfg is not None else False)

        needs_eager = (self.quantization == "bitsandbytes"
                       and quant_cfg is not None and loads_in_8bit
                       and not self.enforce_eager)
        if needs_eager:
            logger.warning(
                "CUDA graph is not supported on BitAndBytes 8bit yet, "
                "fallback to the eager mode.")
            self.enforce_eager = True

    def verify_async_output_proc(self, parallel_config, speculative_config,
                                 device_config) -> None:
        """Turn off async output processing for unsupported setups.

        ``self.use_async_output_proc`` is cleared when pipeline parallelism
        is used, the current platform reports no support, Ray SPMD workers
        are enabled, the runner type is "pooling", or a speculative config is
        given. The first three checks return immediately after disabling.
        ``device_config`` is not read by this method.
        """

        def _disable(*warning_args) -> None:
            # Log the reason (when one is given), then switch the feature off.
            if warning_args:
                logger.warning(*warning_args)
            self.use_async_output_proc = False

        if not self.use_async_output_proc:
            # Nothing to check
            return

        if parallel_config.pipeline_parallel_size > 1:
            _disable("Async output processing can not be enabled "
                     "with pipeline parallel")
            return

        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # If the feature combo become valid
        from vllm.platforms import current_platform
        if not current_platform.is_async_output_supported(self.enforce_eager):
            _disable(
                "Async output processing is not supported on the "
                "current platform type %s.", current_platform.device_type)
            return

        if envs.VLLM_USE_RAY_SPMD_WORKER:
            _disable(
                "Async output processing can not be enabled with ray spmd")
            return

        # Pooling models generate no tokens, so there is nothing to
        # post-process asynchronously (no warning is logged here).
        if self.runner_type == "pooling":
            _disable()

        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # If the feature combo become valid
        if speculative_config:
            _disable("Async output processing is not supported with"
                     " speculative decoding currently.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the model can be partitioned as ``parallel_config`` asks.

        Also disables async output processing when pipeline parallelism is
        enabled.

        Raises:
            ValueError: If the attention-head count is not divisible by the
                tensor-parallel size.
            NotImplementedError: If pipeline parallelism is requested for an
                architecture the model registry does not report as supporting
                it.
        """
        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
        tp_size = parallel_config.tensor_parallel_size
        if num_heads % tp_size:
            raise ValueError(
                f"Total number of attention heads ({num_heads})"
                " must be divisible by tensor parallel size "
                f"({tp_size}).")

        if parallel_config.pipeline_parallel_size <= 1:
            return

        archs = getattr(self.hf_config, "architectures", [])
        if not ModelRegistry.is_pp_supported_model(archs):
            raise NotImplementedError(
                "Pipeline parallelism is not supported for this model. "
                "Supported models implement the `SupportsPP` interface.")

        if self.use_async_output_proc:
            logger.warning("Async output processor is not supported with "
                           "pipeline parallelism currently. Disabling it.")
            self.use_async_output_proc = False

    def get_hf_config_sliding_window(
            self) -> Union[Optional[int], List[Optional[int]]]:
        """Get the sliding window size, or None if disabled."""
        text_config = self.hf_text_config
        # Some models, like Qwen2 and Qwen1.5, pair `sliding_window` with a
        # `use_sliding_window` switch; when that switch exists and is falsy
        # the window is treated as disabled.
        if not getattr(text_config, "use_sliding_window", True):
            return None
        return getattr(text_config, "sliding_window", None)

    def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
        """Get the sliding window size, or None if disabled.

        ``disable_sliding_window`` on this config overrides whatever the HF
        config specifies.
        """
        return (None if self.disable_sliding_window else
                self.get_hf_config_sliding_window())

    def get_vocab_size(self) -> int:
        """Return ``vocab_size`` from the HF text config."""
        text_config = self.hf_text_config
        return text_config.vocab_size

    def get_hidden_size(self) -> int:
        """Return ``hidden_size`` from the HF text config."""
        text_config = self.hf_text_config
        return text_config.hidden_size

    @property
    def is_deepseek_mla(self) -> bool:
        """Whether the text config is DeepSeek-V2/V3 with a non-None
        ``kv_lora_rank``.

        Both attributes are read with ``getattr`` so a config missing either
        one reports False instead of raising AttributeError.
        """
        text_config = self.hf_text_config
        model_type = getattr(text_config, "model_type", None)
        if model_type not in ('deepseek_v2', 'deepseek_v3'):
            return False
        return getattr(text_config, "kv_lora_rank", None) is not None

    def get_head_size(self) -> int:
        """Return the attention head size.

        - DeepSeek MLA models: ``kv_lora_rank + qk_rope_head_dim`` when MLA
          is used; otherwise ``qk_rope_head_dim + qk_nope_head_dim`` if both
          are non-zero (else fall through to the generic rules below).
        - Attention-free models: 0.
        - Otherwise ``head_dim`` when the config defines it, else
          ``hidden_size // num_attention_heads``.
        """
        text_config = self.hf_text_config
        # TODO remove hard code
        if self.is_deepseek_mla:
            rope_dim = getattr(text_config, "qk_rope_head_dim", 0)
            if self.use_mla:
                return text_config.kv_lora_rank + rope_dim
            nope_dim = getattr(text_config, "qk_nope_head_dim", 0)
            if rope_dim and nope_dim:
                return rope_dim + nope_dim

        if self.is_attention_free:
            return 0

        if hasattr(text_config, "head_dim"):
            return text_config.head_dim
        # FIXME(woosuk): This may not be true for all models.
        return text_config.hidden_size // text_config.num_attention_heads

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads.

        Checked in order: multi-query models (1), MPT and DBRX attention
        configs, attention-free models (0), the first non-None of several
        per-architecture KV-head attributes, and finally the number of
        attention heads.
        """
        hf_config = self.hf_config
        text_config = self.hf_text_config
        model_type = hf_config.model_type

        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        is_new_arch_falcon = (
            model_type in ("falcon", "RefinedWeb", "RefinedWebModel")
            and getattr(hf_config, "new_decoder_architecture", False))
        uses_multi_query = getattr(text_config, "multi_query", False)
        if not is_new_arch_falcon and uses_multi_query:
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if model_type == "mpt":
            attn_config = hf_config.attn_config
            if "kv_n_heads" in attn_config:
                return attn_config["kv_n_heads"]
            return hf_config.num_attention_heads
        if model_type == "dbrx":
            return getattr(hf_config.attn_config, "kv_n_heads",
                           hf_config.num_attention_heads)

        if self.is_attention_free:
            return 0

        kv_head_attrs = (
            "n_head_kv",  # For Falcon
            "num_kv_heads",  # For Falcon
            "num_key_value_heads",  # For LLaMA-2
            "multi_query_group_num",  # For ChatGLM
        )
        candidates = (getattr(text_config, attr, None)
                      for attr in kv_head_attrs)
        num_kv_heads = next((n for n in candidates if n is not None), None)
        if num_kv_heads is not None:
            return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return text_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        if self.use_mla:
            # When using MLA during decode it becomes MQA
            return 1
        # KV heads are divided across tensor-parallel ranks. When there are
        # fewer KV heads than ranks they are replicated, so every GPU keeps
        # at least one.
        tp_size = parallel_config.tensor_parallel_size
        heads_per_rank = self.get_total_num_kv_heads() // tp_size
        return max(1, heads_per_rank)

    def get_num_attention_heads(self,
                                parallel_config: "ParallelConfig") -> int:
        """Return the attention heads per tensor-parallel rank (floor
        division; 0 when the config has no ``num_attention_heads``)."""
        total_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
        tp_size = parallel_config.tensor_parallel_size
        return total_heads // tp_size

    def get_layers_start_end_indices(
            self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
        """Return the ``(start, end)`` hidden-layer indices owned by this
        rank's pipeline stage, as computed by ``get_pp_indices``."""
        from vllm.distributed.utils import get_pp_indices

        num_layers = getattr(self.hf_text_config, "num_hidden_layers", 0)
        # The pipeline rank is the global rank divided by the TP size
        # (assumes TP ranks are contiguous within a stage -- confirm).
        pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
        first, last = get_pp_indices(num_layers, pp_rank,
                                     parallel_config.pipeline_parallel_size)
        return first, last

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return how many hidden layers this rank's pipeline stage holds."""
        first, last = self.get_layers_start_end_indices(parallel_config)
        num_local_layers = last - first
        return num_local_layers

    def get_num_layers_by_block_type(
        self,
        parallel_config: "ParallelConfig",
        block_type: LayerBlockType = LayerBlockType.attention,
    ) -> int:
        """Count this pipeline stage's layers of the given ``block_type``.

        - Plain transformers: every layer is an attention layer.
        - Attention-free models: every layer is the single non-attention
          block type.
        - Hybrid models: counted from ``hf_config.layers_block_type`` over
          this stage's ``[start:end]`` slice.

        Raises:
            ValueError: If a hybrid model's HF config has no
                ``layers_block_type``.
        """
        # This function relies on 'layers_block_type' in hf_config,
        # for w/o this attribute, we will need to have workarounds like so
        attn_block_type = block_type == LayerBlockType.attention
        is_transformer = not self.is_hybrid and not self.is_attention_free
        start, end = self.get_layers_start_end_indices(parallel_config)

        if is_transformer:
            # Handle the basic case first
            return end - start if attn_block_type else 0
        if self.is_attention_free:
            # Note that this code assumes there
            # is only one type of attention-free block type.
            return 0 if attn_block_type else end - start

        # Hybrid model
        layers_block_type_value = getattr(self.hf_config,
                                          "layers_block_type", None)
        if layers_block_type_value is None:
            # Each literal ends with a space so the concatenated message
            # reads correctly.
            raise ValueError("The model is a hybrid without a "
                             "layers_block_type in the hf_config, "
                             "cannot determine the num of "
                             f"{block_type.value} layers")

        return sum(t == block_type.value
                   for t in layers_block_type_value[start:end])

    def get_multimodal_config(self) -> "MultiModalConfig":
        """Return the model's multimodal configuration.

        Raises:
            ValueError: If no multimodal config is set (the model is not
                multimodal).
        """
        mm_config = self.multimodal_config
        if mm_config is not None:
            return mm_config
        raise ValueError("The model is not multimodal.")

    def try_get_generation_config(self) -> Dict[str, Any]:
        """Load the generation config as a diff dict, or ``{}`` if none.

        When ``generation_config`` is None or "auto" it is looked up from the
        model (honouring ``revision``); otherwise ``generation_config`` is
        used as the source without a revision.
        """
        from_model = (self.generation_config is None
                      or self.generation_config == "auto")
        if from_model:
            config = try_get_generation_config(
                self.model,
                trust_remote_code=self.trust_remote_code,
                revision=self.revision,
            )
        else:
            config = try_get_generation_config(
                self.generation_config,
                trust_remote_code=self.trust_remote_code,
            )
        return {} if config is None else config.to_diff_dict()

    def get_diff_sampling_param(self) -> Dict[str, Any]:
        """Collect sampling parameters that differ from the defaults.

        The model's generation config is consulted only when
        ``generation_config`` is set; ``override_generation_config`` is then
        merged on top in every case. Of the merged values, only
        ``repetition_penalty``, ``temperature``, ``top_k``, ``top_p``,
        ``min_p`` and ``max_new_tokens`` are kept, and only when not None.

        Returns:
            Dict[str, Any]: The differing sampling parameters, with
            ``max_new_tokens`` renamed to ``max_tokens``; possibly empty.
        """
        if self.generation_config is None:
            # No generation config requested: only the overrides apply.
            config = {}
        else:
            config = self.try_get_generation_config()

        # Overriding with given generation config
        config.update(self.override_generation_config)

        sampling_keys = (
            "repetition_penalty",
            "temperature",
            "top_k",
            "top_p",
            "min_p",
            "max_new_tokens",
        )
        diff_sampling_param = {}
        for key in sampling_keys:
            value = config.get(key)
            if value is not None:
                diff_sampling_param[key] = value

        # Huggingface definition of max_new_tokens is equivalent
        # to vLLM's max_tokens
        if "max_new_tokens" in diff_sampling_param:
            diff_sampling_param["max_tokens"] = diff_sampling_param.pop(
                "max_new_tokens")
        return diff_sampling_param

    @property
    def is_encoder_decoder(self) -> bool:
        """Extract the HF encoder/decoder model flag."""
        hf_config = self.hf_config
        return is_encoder_decoder(hf_config)

    @property
    def uses_mrope(self) -> bool:
        """Whether the HF config uses M-RoPE, per the ``uses_mrope`` helper."""
        hf_config = self.hf_config
        return uses_mrope(hf_config)

    @property
    def is_multimodal_model(self) -> bool:
        """Whether a multimodal config is attached to this model config."""
        has_mm_config = self.multimodal_config is not None
        return has_mm_config

    @property
    def is_cross_encoder(self) -> bool:
        """Whether the model registry classifies the HF architectures as a
        cross-encoder model."""
        archs = getattr(self.hf_config, "architectures", [])
        return ModelRegistry.is_cross_encoder_model(archs)

    @property
    def use_mla(self) -> bool:
        """Whether MLA should be used for this model.

        True only for DeepSeek MLA models when ``VLLM_MLA_DISABLE`` is not
        set and quantization is absent, "fp8", or a "compressed-tensors"
        config in which every group uses fp8 weights and activations.
        Unsupported quantization setups log a warning and yield False.
        """
        if not self.is_deepseek_mla or envs.VLLM_MLA_DISABLE:
            return False

        supported_quant = ("fp8", "compressed-tensors")
        if (self.quantization is not None
                and self.quantization not in supported_quant):
            logger.warning(
                "MLA is not supported with %s quantization. "
                "Disabling MLA.", self.quantization)
            return False

        if self.quantization != "compressed-tensors":
            return True

        # For a "compressed-tensors" checkpoint, every group must use fp8
        # for both weights and activations. A config without
        # "config_groups" is checked as one empty group, which fails.
        quant_config = self._parse_quant_hf_config()
        groups = quant_config.get("config_groups", {"": {}})
        for group_name, group_cfg in groups.items():
            act_cfg = group_cfg.get("input_activations", {})
            act_type = None if act_cfg is None else act_cfg.get("type", "")
            w_cfg = group_cfg.get("weights", {})
            w_type = None if w_cfg is None else w_cfg.get("type", "")
            if act_type == "fp8" and w_type == "fp8":
                continue
            logger.warning(
                "compressed-tensors MLA support requires fp8 "
                "activations and weights in group '%s', but got "
                "activations type '%s' and weights type '%s'.\n "
                "Full config: %s", group_name, act_type, w_type,
                quant_config)
            return False
        return True

    @property
    def supported_runner_types(self) -> Set[RunnerType]:
        """Runner types required by the tasks this model supports."""
        return set(map(_TASK_RUNNER.__getitem__, self.supported_tasks))

    @property
    def runner_type(self) -> RunnerType:
        """Runner type mapped from the configured ``task``."""
        task = self.task
        return _TASK_RUNNER[task]


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution. Must be in the range (0, 1].
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for kv cache storage.
        is_attention_free: Whether the model is attention-free.
        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
            profiled num_gpu_blocks if specified. Does nothing if None.
        sliding_window: Sliding window size for the KV cache. Can not work with
            prefix caching enabled.
        enable_prefix_caching: Whether to enable prefix caching.
        cpu_offload_gb: Size of the CPU offload buffer in GiB.
        calculate_kv_scales: Whether KV cache scales should be calculated.
            ``None`` is normalized to ``False`` after validation.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: List[Any] = []
        factors.append(self.cache_dtype)
        # `cpu_offload_gb` does not use `torch.compile` yet.
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: float,
        cache_dtype: str,
        is_attention_free: bool = False,
        num_gpu_blocks_override: Optional[int] = None,
        sliding_window: Optional[int] = None,
        enable_prefix_caching: bool = False,
        cpu_offload_gb: float = 0,
        calculate_kv_scales: Optional[bool] = None,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * GiB_bytes
        self.num_gpu_blocks_override = num_gpu_blocks_override
        self.cache_dtype = cache_dtype
        self.is_attention_free = is_attention_free
        self.sliding_window = sliding_window
        self.enable_prefix_caching = enable_prefix_caching
        self.cpu_offload_gb = cpu_offload_gb
        self.calculate_kv_scales = calculate_kv_scales
        self._verify_args()
        self._verify_cache_dtype()
        self._verify_prefix_caching()

        # Will be set after profiling.
        self.num_gpu_blocks: Optional[int] = None
        self.num_cpu_blocks: Optional[int] = None

        # Set calculate_kv_scales to False if the value is unset.
        if self.calculate_kv_scales is None:
            self.calculate_kv_scales = False

    def metrics_info(self):
        # convert cache_config to dict(key: str, value: str) for prometheus
        # metrics info
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self) -> None:
        # The value is a fraction of GPU memory, so only (0, 1] is
        # meaningful. The chained comparison also rejects NaN.
        if not 0.0 < self.gpu_memory_utilization <= 1.0:
            raise ValueError(
                "GPU memory utilization must be in the range (0, 1]. Got "
                f"{self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        if self.cache_dtype == "auto":
            pass
        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "Meanwhile, it may cause accuracy drop without a proper "
                "scaling factor")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    def _verify_prefix_caching(self) -> None:
        if not self.enable_prefix_caching:
            return

        if self.sliding_window is not None:
            raise NotImplementedError(
                "Prefix caching is not supported with sliding window. "
                "Run with --disable-sliding-window to use prefix caching.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
               f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
               "is allocated for the swap space.")
        # Above 70% of host memory is rejected; above 40% only warns.
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. %s", msg)


@dataclass
class TokenizerPoolConfig:
    """Settings for the tokenizer worker pool.

    Args:
        pool_size: How many tokenizer workers the pool runs.
        pool_type: Pool implementation: the string "ray" or a
            tokenizer-group class.
        extra_config: Additional pool options; how they are interpreted
            depends on the pool type. Must be a dict.
    """
    pool_size: int
    pool_type: Union[str, Type["BaseTokenizerGroup"]]
    extra_config: dict

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Hash of the settings that shape the computation graph. The tokenizer
        pool has none, so this hashes an empty factor list.
        """
        no_factors: List[Any] = []
        return hashlib.md5(str(no_factors).encode()).hexdigest()

    def __post_init__(self):
        # Only "ray" or an explicit class are accepted as the pool type.
        known_pool_type = (self.pool_type == "ray"
                           or isinstance(self.pool_type, type))
        if not known_pool_type:
            raise ValueError(f"Unknown pool type: {self.pool_type}")
        if not isinstance(self.extra_config, dict):
            raise ValueError("extra_config must be a dictionary.")

    @classmethod
    def create_config(
        cls, tokenizer_pool_size: int,
        tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]],
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Build a TokenizerPoolConfig, or return None when
        ``tokenizer_pool_size`` is 0 (or otherwise falsy).

        Args:
            tokenizer_pool_size: Number of tokenizer workers in the pool.
            tokenizer_pool_type: Type of the pool.
            tokenizer_pool_extra_config: Additional pool options, either a
                dict or a JSON string that is parsed here. ``None`` (or an
                empty dict) becomes ``{}``.
        """
        if not tokenizer_pool_size:
            return None
        extra_config = tokenizer_pool_extra_config
        if isinstance(extra_config, str):
            extra_config = json.loads(extra_config)
        else:
            extra_config = extra_config or {}
        return cls(tokenizer_pool_size, tokenizer_pool_type, extra_config)


class LoadFormat(str, enum.Enum):
    """Formats accepted for loading model weights.

    Members subclass ``str``, so ``LoadFormat("safetensors")`` works and a
    member compares equal to its plain string value.
    """
    # Try safetensors, falling back to the pytorch bin format.
    AUTO = "auto"
    # Pytorch bin format.
    PT = "pt"
    # Safetensors format.
    SAFETENSORS = "safetensors"
    # Pytorch format plus a numpy cache to speed up loading.
    NPCACHE = "npcache"
    # Random weights, mainly for profiling.
    DUMMY = "dummy"
    # CoreWeave's tensorizer library for fast weight loading.
    TENSORIZER = "tensorizer"
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    # nf4 type weights.
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
    RUNAI_STREAMER = "runai_streamer"


@dataclass
class LoadConfig:
    """Configuration for loading model weights.

    Args:
        load_format: The format of the model weights to load, given as a
            ``LoadFormat`` member, its string value in any letter case, or a
            ``BaseModelLoader``:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
            "tensorizer" will use CoreWeave's tensorizer library for
                fast weight loading.
            "bitsandbytes" will load nf4 type weights.
            See ``LoadFormat`` for the complete list of accepted values.
        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        model_loader_extra_config: The extra config for the model loader.
            A JSON string is parsed; ``None`` or an empty value becomes
            ``{}``.
        ignore_patterns: The list of patterns to ignore when loading the model.
            Default to "original/**/*" to avoid repeated loading of llama's
            checkpoints.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(
        default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __post_init__(self):
        # Normalize the loader extra config and always write it back, so the
        # attribute is never left as None or an empty string.
        extra_config = self.model_loader_extra_config or {}
        if isinstance(extra_config, str):
            extra_config = json.loads(extra_config)
        self.model_loader_extra_config = extra_config

        if isinstance(self.load_format, str):
            load_format = self.load_format.lower()
            self.load_format = LoadFormat(load_format)

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns)
        else:
            self.ignore_patterns = ["original/**/*"]


@dataclass
class ParallelConfig:
    """Configuration for the distributed execution.

    ``world_size`` is derived in ``__post_init__``. When
    ``distributed_executor_backend`` is left as ``None`` a backend is chosen
    automatically there, and the final settings are checked by
    ``_verify_args``.
    """

    pipeline_parallel_size: int = 1  # Number of pipeline parallel groups.
    tensor_parallel_size: int = 1  # Number of tensor parallel groups.

    # Maximum number of multiple batches
    # when load model sequentially. To avoid RAM OOM when using tensor
    # parallel and large models.
    max_parallel_loading_workers: Optional[int] = None

    # Disable the custom all-reduce kernel and fall back to NCCL.
    disable_custom_all_reduce: bool = False

    # Config for the tokenizer pool. If None, will use synchronous tokenization.
    tokenizer_pool_config: Optional[TokenizerPoolConfig] = None

    # Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
    ray_workers_use_nsight: bool = False

    # ray distributed model workers placement group.
    placement_group: Optional["PlacementGroup"] = None

    # Backend to use for distributed model
    # workers, either "ray" or "mp" (multiprocessing). If the product
    # of pipeline_parallel_size and tensor_parallel_size is less than
    # or equal to the number of GPUs available, "mp" will be used to
    # keep processing on a single host. Otherwise, this will default
    # to "ray" if Ray is installed and fail otherwise. Note that tpu
    # and hpu only support Ray for distributed inference.
    distributed_executor_backend: Optional[Union[str,
                                                 Type["ExecutorBase"]]] = None

    # the full name of the worker class to use. If "auto", the worker class
    # will be determined based on the platform.
    worker_cls: str = "auto"
    # Presumably the worker class for speculative decoding ("sd"), resolved
    # like worker_cls when "auto" — TODO confirm against its consumers.
    sd_worker_cls: str = "auto"

    # Total number of workers; computed in __post_init__ as
    # pipeline_parallel_size * tensor_parallel_size.
    world_size: int = field(init=False)

    # Rank of the current process (not read by this class's own methods).
    rank: int = 0

    def compute_hash(self) -> str:
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Only the pipeline- and tensor-parallel sizes contribute.
        """
        factors: List[Any] = []
        factors.append(self.pipeline_parallel_size)
        factors.append(self.tensor_parallel_size)
        return hashlib.sha256(str(factors).encode()).hexdigest()

    def __post_init__(self) -> None:
        """Derive ``world_size`` and resolve the executor backend.

        When ``distributed_executor_backend`` is ``None``:
          * on a Ray-only platform (TPU) with world_size > 1 it becomes "ray";
          * otherwise, with world_size > 1: "uni" on Neuron; "ray" on CUDA
            when fewer GPUs are visible than world_size (Ray must be
            importable); "ray" when a placement group is configured, or Ray
            is initialized and a current placement group exists; else "mp";
          * with world_size == 1 it becomes "uni".

        Raises:
            ValueError: If a Ray-only platform is given a non-Ray backend, if
                Ray is needed for multi-node inference but unavailable, or
                if ``_verify_args`` rejects the resulting configuration.
        """
        self.world_size = self.pipeline_parallel_size * \
            self.tensor_parallel_size

        ray_only_devices = ["tpu"]
        from vllm.platforms import current_platform
        if (current_platform.device_type in ray_only_devices
                and self.world_size > 1):
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "ray"
            if self.distributed_executor_backend != "ray":
                raise ValueError(
                    f"{current_platform.device_type.upper()} backend only "
                    "supports Ray for distributed inference.")

        if self.distributed_executor_backend is None and self.world_size > 1:
            # We use multiprocessing by default if world_size fits on the
            # current node and we aren't in a ray placement group.

            from vllm.executor import ray_utils
            backend = "mp"
            ray_found = ray_utils.ray_is_available()
            if current_platform.is_neuron():
                # neuron uses single process to control multiple devices
                backend = "uni"
            elif (current_platform.is_cuda()
                  and cuda_device_count_stateless() < self.world_size):
                # Not enough local GPUs for world_size: multi-node needs Ray.
                if not ray_found:
                    raise ValueError("Unable to load Ray which is "
                                     "required for multi-node inference, "
                                     "please install Ray with `pip install "
                                     "ray`.") from ray_utils.ray_import_err
                backend = "ray"
            elif ray_found:
                if self.placement_group:
                    backend = "ray"
                else:
                    from ray import is_initialized as ray_is_initialized
                    if ray_is_initialized():
                        from ray.util import get_current_placement_group
                        if get_current_placement_group():
                            backend = "ray"
            self.distributed_executor_backend = backend
            logger.info("Defaulting to use %s for distributed inference",
                        backend)

        if self.distributed_executor_backend is None and self.world_size == 1:
            self.distributed_executor_backend = "uni"

        self._verify_args()

    @property
    def use_ray(self) -> bool:
        """True if the backend is "ray" or an executor class whose
        ``uses_ray`` attribute is truthy."""
        return self.distributed_executor_backend == "ray" or (
            isinstance(self.distributed_executor_backend, type)
            and self.distributed_executor_backend.uses_ray)

    def _verify_args(self) -> None:
        """Validate the resolved backend and platform constraints.

        Also disables the custom all-reduce kernel on ROCm.

        Raises:
            ValueError: If the backend is neither a supported string nor an
                ``ExecutorBase`` subclass, or if nsight profiling is
                requested without Ray.
        """
        # Lazy import to avoid circular import
        from vllm.executor.executor_base import ExecutorBase
        from vllm.platforms import current_platform
        if self.distributed_executor_backend not in (
                "ray", "mp", "uni",
                "external_launcher", None) and not (isinstance(
                    self.distributed_executor_backend, type) and issubclass(
                        self.distributed_executor_backend, ExecutorBase)):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
                "values are 'ray', 'mp', 'uni', 'external_launcher' or a"
                " custom ExecutorBase subclass.")
        if self.use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()
        if current_platform.is_rocm():
            self.disable_custom_all_reduce = True
            logger.info(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on AMD GPUs.")
        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")


@dataclass
class SchedulerConfig:
    """Scheduler configuration.

    An unset ``max_num_batched_tokens`` is filled in by ``__post_init__``;
    the final values are validated by ``_verify_args``.
    """

    # The runner type to launch for the model.
    runner_type: str = "generate"

    # Maximum number of tokens to be processed in a single iteration.
    # None means a default is derived in __post_init__.
    max_num_batched_tokens: int = field(default=None)  # type: ignore

    # Maximum number of sequences to be processed in a single iteration.
    max_num_seqs: int = 128

    # Maximum length of a sequence (including prompt and generated text).
    max_model_len: int = 8192

    # Number of extra slots allocated per sequence per step beyond the known
    # token ids; speculative decoding stores KV activations of tokens that
    # may or may not be accepted here.
    num_lookahead_slots: int = 0

    # Delay (delay factor multiplied by the previous prompt latency) applied
    # before scheduling the next prompt.
    delay_factor: float = 0.0

    # If True, prefill requests can be chunked based on the remaining
    # max_num_batched_tokens.
    enable_chunked_prefill: bool = False

    is_multimodal_model: bool = False

    # NOTE: The following multimodal encoder budget will be initialized to
    # max_num_batched_tokens and overridden in case max multimodal embedding
    # size is larger.
    # TODO (ywang96): Make these configurable.
    # Multimodal encoder compute budget, only used in V1
    max_num_encoder_input_tokens: int = field(default=None)  # type: ignore

    # Multimodal encoder cache size, only used in V1
    encoder_cache_size: int = field(default=None)  # type: ignore

    # Preemption by swapping or by recomputation. When unset, recomputation
    # is the default because it is cheaper than swapping; sequence groups
    # with several sequences (e.g. beam search) cannot be recomputed and
    # use swapping instead.
    preemption_mode: Optional[str] = None

    num_scheduler_steps: int = 1

    multi_step_stream_outputs: bool = False

    # Private API. When set, the scheduler sends delta data to workers
    # instead of the full data. Only enable it together with the SPMD worker
    # architecture, i.e. VLLM_USE_RAY_SPMD_WORKER=1.
    send_delta_data: bool = False

    # Scheduling policy: "fcfs" (default) or "priority".
    policy: str = "fcfs"

    # Mirror of enable_chunked_prefill, set in __post_init__.
    chunked_prefill_enabled: bool = field(init=False)

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Return an MD5 hex digest of the settings that shape the computation
        graph from input ids/embeddings to the final hidden states.
        Scheduling settings do not shape that graph, so the digest is taken
        over an empty factor list.
        """
        return hashlib.md5(str([]).encode()).hexdigest()

    def __post_init__(self) -> None:
        """Fill derived budgets, mirror the chunked-prefill flag, validate."""
        if self.max_num_batched_tokens is None:
            self.max_num_batched_tokens = (
                self._default_max_num_batched_tokens())

        # Encoder budget and cache start equal to the token budget.
        self.max_num_encoder_input_tokens = self.max_num_batched_tokens
        self.encoder_cache_size = self.max_num_batched_tokens

        if self.enable_chunked_prefill:
            logger.info(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                self.max_num_batched_tokens)

        self.chunked_prefill_enabled = self.enable_chunked_prefill
        self._verify_args()

    def _default_max_num_batched_tokens(self) -> int:
        """Choose a token budget when the user did not provide one."""
        if self.enable_chunked_prefill and not self.is_multi_step:
            # Balances ITL and TTFT; not tuned for throughput.
            budget = 2048
        else:
            # Without chunking, or with multi-step chunked prefill (which
            # cannot chunk prompts yet), a whole sequence must fit; 2048 is
            # kept as a floor for throughput.
            budget = max(self.max_model_len, 2048)

        if self.runner_type == "pooling":
            # Pooling models get a larger budget for throughput.
            budget = max(budget, _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)
        if self.is_multimodal_model:
            # Must fit at least the multimodal tokens.
            budget = max(budget, _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS)
        return budget

    def _verify_args(self) -> None:
        """Reject inconsistent settings.

        Raises:
            ValueError: If the token budget is below max_model_len without
                chunked prefill, below max_num_seqs, or if the lookahead
                slot / scheduler step counts are out of range.
        """
        budget = self.max_num_batched_tokens

        if budget < self.max_model_len and not self.chunked_prefill_enabled:
            msg = (f"max_num_batched_tokens ({budget}) is "
                   f"smaller than max_model_len ({self.max_model_len}). "
                   "This effectively limits the maximum sequence length to "
                   "max_num_batched_tokens and makes vLLM reject longer "
                   "sequences. Please increase max_num_batched_tokens or "
                   "decrease max_model_len.")
            raise ValueError(msg)

        if budget < self.max_num_seqs:
            msg = (f"max_num_batched_tokens ({budget}) must "
                   "be greater than or equal to max_num_seqs "
                   f"({self.max_num_seqs}).")
            raise ValueError(msg)

        if self.num_lookahead_slots < 0:
            msg = ("num_lookahead_slots "
                   f"({self.num_lookahead_slots}) must be greater than or "
                   "equal to 0.")
            raise ValueError(msg)

        if self.num_scheduler_steps < 1:
            msg = ("num_scheduler_steps "
                   f"({self.num_scheduler_steps}) must be greater than or "
                   "equal to 1.")
            raise ValueError(msg)

    @property
    def is_multi_step(self) -> bool:
        """True when more than one scheduler step runs per iteration."""
        return self.num_scheduler_steps > 1

class DeviceConfig:
    """Resolve the device type and the torch device used for inputs."""

    device: Optional[torch.device]
    device_type: str

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Return an MD5 hex digest of the settings that shape the computation
        graph from input ids/embeddings to the final hidden states. Device
        and platform details are summarized by torch/vllm themselves, so the
        digest is taken over an empty factor list.
        """
        empty_factors: List[Any] = []
        return hashlib.md5(str(empty_factors).encode()).hexdigest()

    def __init__(self, device: str = "auto") -> None:
        """
        Args:
            device: "auto" to query the current platform, or an explicit
                device type string.

        Raises:
            RuntimeError: If auto-detection yields an empty device type.
        """
        if device != "auto":
            # Device type is assigned explicitly.
            self.device_type = device
        else:
            # Automated device type detection.
            from vllm.platforms import current_platform
            self.device_type = current_platform.device_type
            if not self.device_type:
                raise RuntimeError("Failed to infer device type")

        if self.device_type in ("neuron", "openvino"):
            # These device types process inputs on CPU.
            self.device = torch.device("cpu")
        elif self.device_type == "tpu":
            self.device = None
        else:
            self.device = torch.device(self.device_type)

class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Return an MD5 hex digest of the settings that shape the computation
        graph from input ids/embeddings to the final hidden states.
        Speculative decoding does not use `torch.compile` yet, so nothing
        here contributes and the digest is taken over an empty factor list.
        """
        contributing: List[Any] = []
        digest = hashlib.md5(str(contributing).encode())
        return digest.hexdigest()

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        speculative_model_quantization: Optional[str],
        speculative_draft_tensor_parallel_size: Optional[int],
        num_speculative_tokens: Optional[int],
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_max_model_len: Optional[int],
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: Optional[float],
        typical_acceptance_sampler_posterior_alpha: Optional[float],
        disable_logprobs: Optional[bool],
    ) -> Optional["SpeculativeConfig"]:
        """Build a SpeculativeConfig, or return None if speculation is off.

        Speculation is off when ``speculative_model`` is None. The value
        ``"[ngram]"`` selects prompt-lookup (n-gram) speculation, which
        reuses the target model and parallel configs. Any other value is
        treated as a draft model, for which a new ``ModelConfig`` and
        ``ParallelConfig`` are built.

        Args:
            target_model_config (ModelConfig): Configuration of the target
                model. Tokenizer, dtype, seed and several runtime limits are
                copied from it into the draft model config.
            target_parallel_config (ParallelConfig): Parallel configuration
                of the target model.
            target_dtype (str): Data type of the target model. Not read by
                this method.
            speculative_model (Optional[str]): Draft model name/path,
                ``"[ngram]"``, or None.
            speculative_model_quantization (Optional[str]): Quantization
                method of the draft model weights. None means the weights
                are assumed to be unquantized.
            speculative_draft_tensor_parallel_size (Optional[int]): Tensor
                parallel degree of the draft model. None picks a default.
                Unused for ``"[ngram]"``.
            num_speculative_tokens (Optional[int]): Number of speculative
                tokens. When None, it defaults to the draft config's
                ``n_predict`` if present; otherwise it is required.
            speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
                scorer and fall back to batch expansion for scoring.
            speculative_max_model_len (Optional[int]): Explicit max model len
                for the draft model. Mainly used to test skipping speculation
                for some sequences.
            enable_chunked_prefill (bool): Whether chunked prefill is on.
                Used to reject the EAGLE combination.
            disable_log_stats (bool): Forwarded unchanged to the
                SpeculativeConfig constructor.
            speculative_disable_by_batch_size (Optional[int]): If set,
                speculation is disabled for new incoming requests once the
                number of enqueued requests is larger than this value. Must
                be greater than 1.
            ngram_prompt_lookup_max (Optional[int]): Max n-gram window size.
                Required and > 0 for ``"[ngram]"``; forced to 0 otherwise.
            ngram_prompt_lookup_min (Optional[int]): Min n-gram window size.
                Defaults to 1 for ``"[ngram]"``; forced to 0 otherwise.
            draft_token_acceptance_method (str): Either
                'rejection_sampler' (RejectionSampler) or
                'typical_acceptance_sampler' (TypicalAcceptanceSampler).
            typical_acceptance_sampler_posterior_threshold (Optional[float]):
                Lower bound on the target-model posterior probability for a
                token to be accepted. Used only by TypicalAcceptanceSampler.
                Defaults to 0.09.
            typical_acceptance_sampler_posterior_alpha (Optional[float]):
                Scaling factor for TypicalAcceptanceSampler's entropy-based
                threshold. Defaults to 0.3.
            disable_logprobs (Optional[bool]): If True, token log
                probabilities are not returned during speculative decoding.
                If False, they follow the log probability settings in
                SamplingParams. Defaults to True.

        Returns:
            Optional["SpeculativeConfig"]: The new config, or None when
            ``speculative_model`` is None.

        Raises:
            ValueError: If tokens are given without a model; if the
                disable-by-batch-size threshold is below 2; if chunked
                prefill is combined with EAGLE; if the n-gram window sizes
                are invalid; if num_speculative_tokens exceeds the draft
                model's ``n_predict``; or if num_speculative_tokens cannot
                be determined.
        """
        if speculative_model is None:
            if num_speculative_tokens is not None:
                raise ValueError("num_speculative_tokens was provided without "
                                 "speculative_model.")
            return None

        if (speculative_disable_by_batch_size is not None
                and speculative_disable_by_batch_size < 2):
            raise ValueError("Expect the batch size threshold of disabling "
                             "speculative decoding is > 1, but got "
                             f"{speculative_disable_by_batch_size=}")

        # NOTE(review): this compares the model *name* against the literal
        # "eagle"; a path to an EAGLE checkpoint will not match. Confirm
        # whether the draft model type was meant to be checked instead.
        if enable_chunked_prefill and speculative_model == "eagle":
            raise ValueError("Chunked prefill and EAGLE are not compatible.")

        if speculative_model != "[ngram]":
            # A real draft model. The n-gram window is not used.
            ngram_prompt_lookup_max = 0
            ngram_prompt_lookup_min = 0
            draft_model_config = ModelConfig(
                model=speculative_model,
                task="draft",
                tokenizer=target_model_config.tokenizer,
                tokenizer_mode=target_model_config.tokenizer_mode,
                trust_remote_code=target_model_config.trust_remote_code,
                allowed_local_media_path=target_model_config.
                allowed_local_media_path,
                dtype=target_model_config.dtype,
                seed=target_model_config.seed,
                # TODO: The user should be able to specify revision/max model
                # len for the draft model. It is not currently supported.
                revision=None,
                code_revision=None,
                tokenizer_revision=target_model_config.tokenizer_revision,
                max_model_len=None,
                spec_target_max_model_len=target_model_config.max_model_len,
                quantization=speculative_model_quantization,
                enforce_eager=target_model_config.enforce_eager,
                max_seq_len_to_capture=target_model_config.
                max_seq_len_to_capture,
                max_logprobs=target_model_config.max_logprobs,
            )
            draft_hf_config = draft_model_config.hf_config

            if (num_speculative_tokens is not None
                    and hasattr(draft_hf_config, "num_lookahead_tokens")):
                draft_hf_config.num_lookahead_tokens = num_speculative_tokens

            max_predict = getattr(draft_hf_config, "n_predict", None)
            if max_predict is not None and num_speculative_tokens is None:
                # Default to the maximum advertised by the draft model.
                num_speculative_tokens = max_predict
            elif (max_predict is not None
                  and num_speculative_tokens > max_predict):
                raise ValueError(
                    "This speculative model supports a maximum of "
                    f"num_speculative_tokens={max_predict}, but "
                    f"{num_speculative_tokens=} was provided.")

            draft_tp = (SpeculativeConfig.
                        _verify_and_get_draft_model_tensor_parallel_size(
                            target_parallel_config,
                            speculative_draft_tensor_parallel_size,
                            draft_hf_config))

            draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    speculative_max_model_len,
                    draft_model_config.max_model_len,
                    target_model_config.max_model_len,
                ))

            draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    target_parallel_config, draft_tp, draft_hf_config))
        else:
            # Prompt-lookup speculation: validate the n-gram window.
            if ngram_prompt_lookup_min is None:
                ngram_prompt_lookup_min = 1
            if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
                raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
            if ngram_prompt_lookup_min < 1:
                raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
            if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
                raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
                                 f"larger than {ngram_prompt_lookup_max=}")
            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            draft_model_config = target_model_config
            draft_parallel_config = target_parallel_config

        if num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative_model unless the draft model config contains an "
                "n_predict parameter.")

        return SpeculativeConfig(
            draft_model_config=draft_model_config,
            draft_parallel_config=draft_parallel_config,
            num_speculative_tokens=num_speculative_tokens,
            speculative_disable_mqa_scorer=speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size=speculative_disable_by_batch_size,
            ngram_prompt_lookup_max=ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=ngram_prompt_lookup_min,
            draft_token_acceptance_method=draft_token_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=(
                0.09 if typical_acceptance_sampler_posterior_threshold is None
                else typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                0.3 if typical_acceptance_sampler_posterior_alpha is None else
                typical_acceptance_sampler_posterior_alpha),
            disable_logprobs=(True
                              if disable_logprobs is None else disable_logprobs),
            disable_log_stats=disable_log_stats,
        )

    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: Optional[int],
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.
        """

        if speculative_max_model_len is not None:

            if speculative_max_model_len > draft_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {draft_max_model_len=}")

            if speculative_max_model_len > target_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {target_max_model_len=}")

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def _verify_and_get_draft_model_tensor_parallel_size(
            target_parallel_config: ParallelConfig,
            speculative_draft_tensor_parallel_size: Optional[int],
            draft_hf_config: PretrainedConfig) -> int:
        """
        Verifies and adjusts the tensor parallel size for a draft model
        specified using speculative_draft_tensor_parallel_size.

        When the size is unset, MLPSpeculator drafts get 1 (with a warning
        if the target TP is > 1) and every other draft model inherits the
        target model's tensor_parallel_size.

        Returns:
            The tensor parallel size to use for the draft model.

        Raises:
            ValueError: If an explicit size is neither 1 nor the target
                model's tensor_parallel_size.
        """
        target_tp = target_parallel_config.tensor_parallel_size

        if speculative_draft_tensor_parallel_size is None:
            if draft_hf_config.model_type == "mlp_speculator":
                if target_tp > 1:
                    logger.warning(
                        "MLPSpeculator cannot currently be run with tp>1; "
                        "setting speculative_draft_tensor_parallel_size=1")
                return 1
            return target_tp

        if speculative_draft_tensor_parallel_size not in (1, target_tp):
            # Report the concrete allowed target value; the old second
            # fragment was an f-string with no placeholders.
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
                "other value than 1 or target model tensor_parallel_size "
                f"({target_tp})")
        return speculative_draft_tensor_parallel_size
1943

1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int,
        draft_hf_config: PretrainedConfig,
    ) -> ParallelConfig:
        """Build the ParallelConfig used by the draft worker.

        Every setting is copied from ``target_parallel_config`` except the
        tensor parallel size, which is replaced by
        ``speculative_draft_tensor_parallel_size``. ``draft_hf_config`` is
        part of the signature but is not read here.
        """
        src = target_parallel_config
        return ParallelConfig(
            pipeline_parallel_size=src.pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=src.distributed_executor_backend,
            max_parallel_loading_workers=src.max_parallel_loading_workers,
            disable_custom_all_reduce=src.disable_custom_all_reduce,
            tokenizer_pool_config=src.tokenizer_pool_config,
            ray_workers_use_nsight=src.ray_workers_use_nsight,
            placement_group=src.placement_group,
        )

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ):
        """Store the speculative decoding settings and validate them.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: Number of tokens the draft model samples
                before the target model scores them. Must be positive.
            speculative_disable_mqa_scorer: Stored unchanged on the config;
                presumably toggles the MQA scorer for speculative decoding
                (consumed elsewhere — confirm at the use site).
            speculative_disable_by_batch_size: Disable speculative decoding
                for new incoming requests once the number of enqueued
                requests exceeds this value.
            ngram_prompt_lookup_max: Max size of the ngram token window. A
                falsy value (e.g. None) is stored as 0.
            ngram_prompt_lookup_min: Min size of the ngram token window. A
                falsy value (e.g. None) is stored as 0.
            draft_token_acceptance_method: Either 'rejection_sampler'
                (RejectionSampler) or 'typical_acceptance_sampler'
                (TypicalAcceptanceSampler).
            typical_acceptance_sampler_posterior_threshold: Lower bound on
                the target model's posterior probability for a token to be
                accepted. Only used by the TypicalAcceptanceSampler.
            typical_acceptance_sampler_posterior_alpha: Scaling factor for
                the entropy-based threshold of the TypicalAcceptanceSampler.
            disable_logprobs: When True, token log probabilities are not
                returned even if the sampling parameters request them. This
                skips logprob computation in proposal sampling, target
                sampling and after acceptance, reducing latency. When False,
                log probabilities are returned.
            disable_log_stats: Whether to disable periodic printing of stage
                times in speculative decoding.

        Raises:
            ValueError: Raised by ``_verify_args`` for invalid settings.
        """
        # Draft model and the parallelism it runs with.
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config

        # Core speculation knobs.
        self.num_speculative_tokens = num_speculative_tokens
        self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
        self.speculative_disable_by_batch_size = (
            speculative_disable_by_batch_size)

        # Ngram lookup window bounds; unset (falsy) bounds become 0.
        self.ngram_prompt_lookup_max = (ngram_prompt_lookup_max
                                        if ngram_prompt_lookup_max else 0)
        self.ngram_prompt_lookup_min = (ngram_prompt_lookup_min
                                        if ngram_prompt_lookup_min else 0)

        # Token acceptance strategy and its tuning parameters.
        self.draft_token_acceptance_method = draft_token_acceptance_method
        self.typical_acceptance_sampler_posterior_threshold = (
            typical_acceptance_sampler_posterior_threshold)
        self.typical_acceptance_sampler_posterior_alpha = (
            typical_acceptance_sampler_posterior_alpha)

        # Output / logging switches.
        self.disable_logprobs = disable_logprobs
        self.disable_log_stats = disable_log_stats

        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the speculative decoding settings stored on ``self``.

        Raises:
            ValueError: If ``num_speculative_tokens`` is not positive, if
                ``draft_token_acceptance_method`` is unset or is neither
                'rejection_sampler' nor 'typical_acceptance_sampler', or if
                either typical-acceptance parameter is negative.
        """
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        # Only cross-check the draft model against its parallel config when
        # a draft model config is present (truthy).
        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)

        # Validate draft token acceptance related settings.
        if self.draft_token_acceptance_method is None:
            raise ValueError("draft_token_acceptance_method is not set. "
                             "Expected values are rejection_sampler or "
                             "typical_acceptance_sampler.")

        if self.draft_token_acceptance_method not in (
                'rejection_sampler', 'typical_acceptance_sampler'):
            raise ValueError(
                "Expected draft_token_acceptance_method to be either "
                "rejection_sampler or typical_acceptance_sampler. Instead it "
                f"is {self.draft_token_acceptance_method}")

        # Zero is a valid value; only negative values are rejected.
        if (self.typical_acceptance_sampler_posterior_threshold < 0
                or self.typical_acceptance_sampler_posterior_alpha < 0):
            raise ValueError(
                "Expected typical_acceptance_sampler_posterior_threshold "
                "and typical_acceptance_sampler_posterior_alpha to be >= 0. "
                "Instead found "
                f"typical_acceptance_sampler_posterior_threshold = "
                f"{self.typical_acceptance_sampler_posterior_threshold} and "
                f"typical_acceptance_sampler_posterior_alpha = "
                f"{self.typical_acceptance_sampler_posterior_alpha}")
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083

    @property
    def num_lookahead_slots(self) -> int:
        """Extra slots the scheduler should allocate per step.

        These come on top of the slots for each already-known token. Every
        speculative token has to be scored, so the count is simply
        ``num_speculative_tokens``.
        """
        slots = self.num_speculative_tokens
        return slots

    def __repr__(self) -> str:
        # Ngram speculation has no draft model, so show a placeholder name.
        draft_model = ("[ngram]" if self.ngram_prompt_lookup_max > 0 else
                       self.draft_model_config.model)
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"


2092
2093
2094
2095
@dataclass
class LoRAConfig:
    """Settings for serving LoRA adapters.

    Field values are validated in ``__post_init__``; ``lora_dtype`` is
    resolved against the model config in ``verify_with_model_config``.
    """
    # Maximum LoRA rank; must be one of the ranks listed in __post_init__.
    max_lora_rank: int
    # Maximum number of LoRAs; must be >= 1.
    max_loras: int
    fully_sharded_loras: bool = False
    # Defaults to max_loras when None; otherwise must be >= max_loras.
    max_cpu_loras: Optional[int] = None
    # None / "auto" -> model dtype; other strings are looked up on `torch`.
    lora_dtype: Optional[Union[torch.dtype, str]] = None
    # Must be one of 0, 256 or 512.
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256
    # Variable-length tuple of scaling factors (`Tuple[float]` would denote
    # a 1-tuple).
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
    bias_enabled: bool = False

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # LoRA is not compatible with `torch.compile` .
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __post_init__(self):
        """Validate field values and default ``max_cpu_loras``.

        Raises:
            ValueError: If any field is outside its allowed values.
        """
        # Setting the maximum rank to 256 should be able to satisfy the vast
        # majority of applications.
        possible_max_ranks = (8, 16, 32, 64, 128, 256)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_cache_config(self, cache_config: CacheConfig):
        """Reject cache settings that LoRA does not support yet.

        Raises:
            ValueError: If CPU offload is enabled (``cpu_offload_gb > 0``).
        """
        # TODO LoRA supports CPU offload.
        if cache_config.cpu_offload_gb > 0:
            raise ValueError("CPU offload is not supported with LoRA yet.")

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``lora_dtype`` and warn about untested quantization.

        ``None`` or ``"auto"`` adopts ``model_config.dtype``; any other
        string must name a ``torch`` dtype attribute (e.g. ``"bfloat16"``).

        Raises:
            ValueError: If ``lora_dtype`` is a string that does not name a
                ``torch.dtype``.
        """
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            # getattr alone would accept any torch attribute (e.g. "nn"), so
            # make sure the result really is a dtype.
            resolved = getattr(torch, self.lora_dtype, None)
            if not isinstance(resolved, torch.dtype):
                raise ValueError(
                    f"Unknown lora_dtype: {self.lora_dtype!r}")
            self.lora_dtype = resolved

        if model_config.quantization and model_config.quantization not in (
                "awq", "gptq"):
            # TODO support marlin
            logger.warning("%s quantization is not tested with LoRA yet.",
                           model_config.quantization)

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        """Warn when LoRA is combined with chunked prefill."""
        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # If the feature combo become valid
        if scheduler_config.chunked_prefill_enabled:
            logger.warning("LoRA with chunked prefill is still experimental "
                           "and may be unstable.")
2168
2169


2170
2171
2172
2173
2174
2175
2176
@dataclass
class PromptAdapterConfig:
    """Settings for serving prompt adapters.

    Field values are validated in ``__post_init__``; the dtype is resolved
    against the model config in ``verify_with_model_config``.
    """
    # Maximum number of prompt adapters; must be >= 1.
    max_prompt_adapters: int
    # Maximum number of prompt adapter tokens; 0 means "unset" and is
    # rejected, as are negative values.
    max_prompt_adapter_token: int
    # Defaults to max_prompt_adapters when None.
    max_cpu_prompt_adapters: Optional[int] = None
    # None / "auto" -> model dtype; other strings are looked up on `torch`.
    prompt_adapter_dtype: Optional[Union[torch.dtype, str]] = None

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __post_init__(self):
        """Validate field values and default ``max_cpu_prompt_adapters``.

        Raises:
            ValueError: If ``max_prompt_adapters`` < 1 or
                ``max_prompt_adapter_token`` is not positive.
        """
        if self.max_prompt_adapters < 1:
            raise ValueError(f"max_prompt_adapters "
                             f"({self.max_prompt_adapters}) must be >= 1.")
        if self.max_prompt_adapter_token == 0:
            raise ValueError("max_prompt_adapter_token must be set.")
        if self.max_prompt_adapter_token < 0:
            raise ValueError(f"max_prompt_adapter_token "
                             f"({self.max_prompt_adapter_token}) must be > 0.")
        if self.max_cpu_prompt_adapters is None:
            self.max_cpu_prompt_adapters = self.max_prompt_adapters

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``prompt_adapter_dtype`` against the model config.

        ``None`` or ``"auto"`` adopts ``model_config.dtype``; any other
        string must name a ``torch`` dtype attribute.

        Raises:
            ValueError: If the string does not name a ``torch.dtype``.
        """
        if self.prompt_adapter_dtype in (None, "auto"):
            self.prompt_adapter_dtype = model_config.dtype
        elif isinstance(self.prompt_adapter_dtype, str):
            # getattr alone would accept any torch attribute, so check that
            # the result really is a dtype.
            resolved = getattr(torch, self.prompt_adapter_dtype, None)
            if not isinstance(resolved, torch.dtype):
                raise ValueError("Unknown prompt_adapter_dtype: "
                                 f"{self.prompt_adapter_dtype!r}")
            self.prompt_adapter_dtype = resolved


2213
@dataclass
class MultiModalConfig:
    """Controls the behavior of multimodal models."""

    limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
    """
    The maximum number of input items allowed per prompt for each modality.
    """

    def compute_hash(self) -> str:
        """Return a hash of the settings that shape the computation graph.

        WARNING: when adding a field that affects the computation graph
        (input ids/embeddings through final hidden states), add it to the
        factors below.
        """
        # Nothing here influences the compiled graph, so hash an empty list.
        no_factors: List[Any] = []
        digest = hashlib.md5(str(no_factors).encode())
        return digest.hexdigest()

    # TODO: Add configs to init vision tower or not.
2241

2242

2243
2244
@dataclass
class PoolerConfig:
    """Controls the behavior of output pooling in pooling models."""

    pooling_type: Optional[str] = None
    """
    Pooling method of the pooling model; should be a key in
    :class:`vllm.model_executor.layers.pooler.PoolingType`.
    """

    normalize: Optional[bool] = None
    """
    Whether the pooled outputs are normalized. Typically ``True`` for
    embedding outputs.
    """

    softmax: Optional[bool] = None
    """
    Whether softmax is applied to the pooled outputs. Typically ``True`` for
    classification outputs.
    """

    step_tag_id: Optional[int] = None
    """
    When set, only the score at the ``step_tag_id`` positions of the
    generated sentence is returned; otherwise scores for all tokens are
    returned.
    """

    returned_token_ids: Optional[List[int]] = None
    """
    Vocabulary indices to extract, e.g. the token IDs of ``good_token`` and
    ``bad_token`` in the ``math-shepherd-mistral-7b-prm`` model.
    """

    def compute_hash(self) -> str:
        """Return a hash of the settings that shape the computation graph.

        WARNING: when adding a field that affects the computation graph
        (input ids/embeddings through final hidden states), add it to the
        factors below.
        """
        # None of the fields above change the compiled graph.
        no_factors: List[Any] = []
        digest = hashlib.md5(str(no_factors).encode())
        return digest.hexdigest()

    @staticmethod
    def from_json(json_str: str) -> "PoolerConfig":
        """Build a PoolerConfig from a JSON object keyed by field name."""
        field_values = json.loads(json_str)
        return PoolerConfig(**field_values)
2300
2301


2302
2303
2304
2305
2306
2307
2308
2309
# Lookup table used to resolve explicit (lower-case) dtype strings.
_STR_DTYPE_TO_TORCH_DTYPE = dict(
    half=torch.float16,
    float16=torch.float16,
    float=torch.float32,
    float32=torch.float32,
    bfloat16=torch.bfloat16,
)

# Dtype names not supported on ROCm; currently none.
_ROCM_NOT_SUPPORTED_DTYPE: List[str] = []
2311

2312
2313
2314

def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the torch dtype to run the model in and log any cast.

    Args:
        config: HF model config. Its ``torch_dtype`` is the checkpoint dtype
            (float32 when missing or None); ``model_type`` selects the
            Gemma 2 special case.
        dtype: A ``torch.dtype`` or a case-insensitive string. ``"auto"``
            derives the dtype from ``config`` and the current platform;
            any other string must be a key of ``_STR_DTYPE_TO_TORCH_DTYPE``.

    Returns:
        The resolved ``torch.dtype``.

    Raises:
        ValueError: If ``dtype`` is an unknown string or neither a string
            nor a ``torch.dtype``.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type == "gemma2":
                    logger.info(
                        "For Gemma 2, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16.")
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype

            # Platform-specific overrides of the "auto" choice.
            from vllm.platforms import current_platform
            if (current_platform.is_cpu()
                    and current_platform.get_cpu_architecture()
                    == CpuArchEnum.POWERPC
                    and config_dtype in (torch.float16, torch.float32)):
                logger.info(
                    "For POWERPC, we cast models to bfloat16 instead of "
                    "using float16 by default. Float16 is not currently "
                    "supported for POWERPC.")
                torch_dtype = torch.bfloat16

            # TODO: change this condition to check if the platform support bf16
            # instead of checking the OS. For instance M2 shall supports bf16
            # already. But we need to modify `cpu_extension.cmake` to activate
            # the feature in the build.
            if (current_platform.is_cpu() and sys.platform.startswith("darwin")
                    and current_platform.get_cpu_architecture()
                    == CpuArchEnum.ARM and config_dtype == torch.bfloat16):
                logger.info("For macOS with Apple Silicon, currently bfloat16 "
                            "is not supported. Setting dtype to float16.")
                torch_dtype = torch.float16

            if current_platform.is_hpu() and config_dtype == torch.float16:
                logger.info(
                    "For HPU, we cast models to bfloat16 instead of "
                    "using float16 by default. Please specify `dtype` if you "
                    "want to use float16.")
                torch_dtype = torch.bfloat16
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
2393
2394
2395
2396
2397


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
    disable_sliding_window: bool,
    sliding_window_len: Optional[Union[int, List[Optional[int]]]],
    spec_target_max_model_len: Optional[int] = None,
    encoder_config: Optional[Any] = None,
) -> int:
    """Get and verify the model's maximum length.

    The length is derived from the smallest of several well-known context
    length keys in ``hf_config``. It is capped by the sliding window when
    the window is disabled, then adjusted for RoPE scaling and for
    ``encoder_config["max_seq_length"]``. A user-provided ``max_model_len``
    is checked against the derived value.

    Args:
        hf_config: HF model config to read length-related keys from.
        max_model_len: User-requested maximum length, or None to use the
            derived value.
        disable_sliding_window: Whether sliding window attention is
            disabled; if so the length is capped at the smallest window.
        sliding_window_len: Sliding window size, or per-layer sizes.
        spec_target_max_model_len: For a speculative draft model, the
            target model's max length, used when no key is found.
        encoder_config: Optional mapping whose ``max_seq_length`` overrides
            the derived length.

    Returns:
        The maximum model length. If no length key is found, the given
        ``max_model_len`` (or ``spec_target_max_model_len``) is returned
        as-is; otherwise 2048 is assumed with a warning.

    Raises:
        NotImplementedError: If the sliding window is disabled together with
            a rope scaling type other than su/longrope/llama3, or together
            with a ``model_max_length`` override.
        ValueError: If ``max_model_len`` exceeds the derived length and the
            override is not allowed by ``VLLM_ALLOW_LONG_MAX_MODEL_LEN``.
    """
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Command-R
        "model_max_length",
        # Whisper
        "max_target_positions",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    # Choose the smallest "max_length" from the possible keys.
    max_len_key = None
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None:
            max_len_key = (key
                           if max_len < derived_max_model_len else max_len_key)
            derived_max_model_len = min(derived_max_model_len, max_len)

    # If sliding window is manually disabled, max_length should be less
    # than the sliding window length in the model config.
    if disable_sliding_window and sliding_window_len is not None:
        sliding_window_len_min = get_min_sliding_window(sliding_window_len)
        max_len_key = ("sliding_window"
                       if sliding_window_len_min < derived_max_model_len else
                       max_len_key)
        derived_max_model_len = min(derived_max_model_len,
                                    sliding_window_len_min)

    # If none of the keys were found in the config, use a default and
    # log a warning.
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        if spec_target_max_model_len is not None:
            # If this is a speculative draft model, we use the max model len
            # from the target model.
            return spec_target_max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            "%s. Assuming the model's maximum length is %d.", possible_keys,
            default_max_len)
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        # No need to consider "type" key because of patch_rope_scaling when
        # loading HF config
        rope_type = rope_scaling["rope_type"]

        if rope_type not in ("su", "longrope", "llama3"):
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that supports rope_scaling
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "with rope_scaling. Please raise an issue so we can "
                    "investigate.")

            # NOTE: rope_type == "default" does not define factor
            # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
            scaling_factor = rope_scaling.get("factor", 1.0)

            if rope_type == "yarn":
                derived_max_model_len = rope_scaling[
                    "original_max_position_embeddings"]
            derived_max_model_len *= scaling_factor

    if encoder_config and "max_seq_length" in encoder_config:
        derived_max_model_len = encoder_config["max_seq_length"]

    # If the user specified a max length, make sure it is smaller than the
    # derived length from the HF model config.
    if max_model_len is None:
        max_model_len = int(derived_max_model_len)
    elif max_model_len > derived_max_model_len:
        # Some models might have a separate key for specifying model_max_length
        # that will be bigger than derived_max_model_len. We compare user input
        # with model_max_length and allow this override when it's smaller.
        model_max_length = getattr(hf_config, "model_max_length", None)
        if model_max_length is not None and max_model_len <= model_max_length:
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that has model_max_length
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "with model_max_length in the config. Please raise an "
                    "issue so we can investigate.")
        else:
            msg = (
                f"User-specified max_model_len ({max_model_len}) is greater "
                f"than the derived max_model_len ({max_len_key}="
                f"{derived_max_model_len} or model_max_length="
                f"{model_max_length} in model's config.json). This may lead "
                "to incorrect model outputs or CUDA errors.")
            if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
                logger.warning(
                    "%s Make sure the value is correct and within the "
                    "model context size.", msg)
            else:
                raise ValueError(
                    f"{msg} To allow overriding this maximum, set "
                    "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
    return int(max_model_len)
2522
2523


2524
2525
2526
2527
2528
2529
2530
2531
def get_min_sliding_window(
        sliding_window: Union[int, List[Optional[int]]]) -> int:
    """Return the smallest sliding window size.

    Args:
        sliding_window: A single window size, or a per-layer list/tuple of
            sizes in which ``None`` entries are ignored.

    Returns:
        ``sliding_window`` unchanged when it is a scalar, otherwise the
        minimum of its non-None entries.

    Raises:
        ValueError: If a list/tuple is given but every entry is None.
    """
    if isinstance(sliding_window, (list, tuple)):
        windows = [s for s in sliding_window if s is not None]
        if not windows:
            # Previously surfaced as an opaque "min() arg is an empty
            # sequence" error.
            raise ValueError(
                "sliding_window list contains no non-None entries: "
                f"{sliding_window}")
        return min(windows)

    return sliding_window


2532
2533
2534
def get_served_model_name(
        model: str,
        served_model_name: Optional[Union[str, List[str]]]) -> str:
    """Return the public name under which ``model`` is served.

    If ``served_model_name`` is a non-empty list, its first entry is used.
    If it is a non-empty string, it is used directly. If it is None, an
    empty string, or an empty list, ``model`` is used as the fallback.
    """
    if not served_model_name:
        return model
    if isinstance(served_model_name, list):
        return served_model_name[0]
    return served_model_name


2548
2549
2550
2551
@dataclass
class DecodingConfig:
    """Dataclass which contains the decoding strategy of the engine"""

    # Which guided decoding algo to use.
    # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
    guided_decoding_backend: str = 'xgrammar'

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    def __post_init__(self):
        """Validate ``guided_decoding_backend``.

        Raises:
            ValueError: If the backend is not one of the supported names.
        """
        valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']

        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
            # Close the quote around the backend name and separate the
            # clauses (the previous message ran them together).
            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")


2582
2583
2584
2585
2586
@dataclass
class ObservabilityConfig:
    """Configuration for observability."""
    # OTLP endpoint to export traces to; requires OpenTelemetry when set.
    otlp_traces_endpoint: Optional[str] = None

    # Collecting detailed timing information for each request can be
    # expensive, so both collectors below are off by default.

    # When True, the model forward time is collected for each request.
    collect_model_forward_time: bool = False

    # When True, the model execute time is collected for each request.
    collect_model_execute_time: bool = False

    def compute_hash(self) -> str:
        """Return a hash of the settings that shape the computation graph.

        WARNING: when adding a field that affects the computation graph
        (input ids/embeddings through final hidden states), add it to the
        factors below.
        """
        # Observability settings never change the compiled graph.
        no_factors: List[Any] = []
        digest = hashlib.md5(str(no_factors).encode())
        return digest.hexdigest()

    def __post_init__(self):
        """Ensure OpenTelemetry is importable when a trace endpoint is set.

        Raises:
            ValueError: If ``otlp_traces_endpoint`` is set but OpenTelemetry
                is unavailable.
        """
        otel_missing = not is_otel_available()
        if otel_missing and self.otlp_traces_endpoint is not None:
            raise ValueError(
                "OpenTelemetry is not available. Unable to configure "
                "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
                f"installed. Original error:\n{otel_import_error_traceback}")
2619
2620


2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
2631
2632
2633
2634
2635
2636
2637
2638
2639
2640
2641
2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
2652
2653
class KVTransferConfig(BaseModel):
    """Configuration for distributed KV cache transfer."""

    # The KV connector for vLLM to transmit KV caches between vLLM instances.
    kv_connector: Optional[str] = None

    # The device used by kv connector to buffer the KV cache.
    # Currently only support 'cuda'.
    kv_buffer_device: Optional[str] = "cuda"

    # The buffer size for TorchDistributedConnector. Measured in number of
    # bytes. Recommended value: 1e9 (about 1GB).
    kv_buffer_size: float = 1e9

    # Whether this vLLM instance produces, consumes KV cache, or both. Choices
    # are 'kv_producer', 'kv_consumer', and 'kv_both' (validated in
    # model_post_init).
    kv_role: Optional[str] = None

    # The rank of this vLLM instance in the KV cache transfer. Typical value:
    # 0 for prefill instance, 1 for decode instance.
    # Currently only 1P1D is supported.
    kv_rank: Optional[int] = None

    # The number of parallel instances for KV cache transfer. For
    # PyNcclConnector, this should be 2.
    kv_parallel_size: int = 1

    # The KV connector ip, used to build distributed connection
    kv_ip: str = "127.0.0.1"

    # The KV connector port, used to build distributed connection
    kv_port: int = 14579

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: List[Any] = []
        hash_str = hashlib.md5(str(factors).encode()).hexdigest()
        return hash_str

    @classmethod
    def from_cli(cls, cli_value: str) -> "KVTransferConfig":
        """Parse the CLI value for the kv cache transfer config.

        ``cli_value`` is a JSON document validated against this model's
        fields by pydantic (``model_validate_json``).
        """
        return cls.model_validate_json(cli_value)

    def model_post_init(self, __context: Any) -> None:
        """Validate ``kv_role`` after pydantic has populated the fields.

        Raises:
            ValueError: if ``kv_role`` is set to an unsupported value, or if
                ``kv_connector`` is set without a ``kv_role``.
        """
        if self.kv_role is not None and self.kv_role not in [
                "kv_producer", "kv_consumer", "kv_both"
        ]:
            raise ValueError(
                f"Unsupported kv_role: {self.kv_role}. "
                f"Supported roles are `kv_producer`, `kv_consumer`, "
                f"and `kv_both`")

        if self.kv_connector is not None and self.kv_role is None:
            raise ValueError("Please specify kv_role when kv_connector "
                             "is set, supported roles are `kv_producer`, "
                             "`kv_consumer`, and `kv_both`")

    @property
    def is_kv_transfer_instance(self) -> bool:
        """True when a connector is configured with any supported role."""
        return self.kv_connector is not None and \
            self.kv_role in ["kv_producer", "kv_consumer", "kv_both"]

    @property
    def need_kv_parallel_group(self) -> bool:
        """True when a connector is set and more than one instance is used."""
        # for those database-based connector, vLLM does not need to create
        # parallel group, and in that case the kv parallel size will be 1.
        return self.kv_connector is not None and self.kv_parallel_size > 1

    @property
    def is_kv_producer(self) -> bool:
        """True when a connector is set and this instance produces KV."""
        return self.kv_connector is not None and \
            self.kv_role in ["kv_producer", "kv_both"]

    @property
    def is_kv_consumer(self) -> bool:
        """True when a connector is set and this instance consumes KV."""
        return self.kv_connector is not None and \
            self.kv_role in ["kv_consumer", "kv_both"]


2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
class CompilationLevel:
    """Named integer constants for ``CompilationConfig.level``."""
    NO_COMPILATION: Final = 0  # torch.compile is not used
    DYNAMO_AS_IS: Final = 1  # dynamo as is
    DYNAMO_ONCE: Final = 2  # dynamo once
    PIECEWISE: Final = 3  # piecewise compilation


class CompilationConfig(BaseModel):
    """
    Configuration for compilation.
    It has three parts:
    - Top-level Compilation control:
        - level: the level of compilation.
            - 0: no compilation.
            - 1: dynamo as is.
            - 2: dynamo once.
            - 3: piecewise compilation.
        - debug_dump_path: the path to dump the debug information.
        - cache_dir: the directory to store the compiled graph, to
            accelerate Inductor compilation. By default, it will use
            model-related information to generate a cache directory.
        - backend: the backend for compilation. It needs to be a string.
            - "" (empty string): use the default backend.
            - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
            - "full.module.name": a qualified name which can be used to import the backend function.
            We use string to avoid serialization issues when using compilation in a distributed setting.
            When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
            When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
        - custom_ops: fine-grained control over which custom ops to enable/disable.
            Use 'all' to enable all, 'none' to disable all.
            Also specify a list of custom op names to enable (prefixed with a '+'),
            or disable (prefixed with a '-').
            Examples:
                - 'all,-op1' to enable all except op1
                - 'none,+op1,+op2' to enable only op1 and op2
            By default, all custom ops are enabled when running without Inductor
                and disabled when running with Inductor (compile_level >= Inductor).
        - splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation.
            - None (default): the attention ops under V1, no splitting under V0.
    - CudaGraph capture:
        - use_cudagraph: whether to use cudagraph inside compilation.
            - False: cudagraph inside compilation is not used.
            - True: cudagraph inside compilation is used. It requires
                that all input buffers have fixed addresses, and all
                splitting ops write their outputs to input buffers.
            Note that this is orthogonal to the cudagraph capture logic
            outside of compilation.
            TODO: move outside cudagraph logic into compilation.
            torch.compile will handle cudagraph capture logic in the future.
        - cudagraph_capture_sizes: sizes to capture cudagraph.
            - None (default): capture sizes are inferred from vllm config.
            - List[int]: capture sizes are specified as given.
        - cudagraph_num_of_warmups: number of warmup runs for cudagraph.
            It means the first several runs will be treated as warmup runs.
            Only after that, the execution will be recorded, and the recorded
            cudagraph will be used for subsequent runs.
        - cudagraph_copy_inputs: whether to copy input tensors for
            cudagraph. If the caller can guarantee that the same input buffers
            are always used, it can set this to False. Otherwise, it should
            set this to True, and the compiler will copy the input to an
            internally managed buffer. Default is False.
    - Inductor compilation:
        - use_inductor: whether to use inductor compilation.
            - False: inductor compilation is not used. graph runs in eager.
            - True: inductor compilation is used. one graph for symbolic shape
                is compiled. In addition, compile for compile_sizes,
                using configurations in inductor_compile_config.
        - compile_sizes: sizes to compile for inductor. In addition
            to integers, it also supports "cudagraph_capture_sizes" to
            specify the sizes for cudagraph capture.
        - inductor_compile_config: additional configurations for inductor.
            - None: use default configurations.
        - inductor_passes: additional passes for inductor. It is a dictionary
            from pass name to pass function qualified name. We use function
            name because the config uses json format. If we pass the config
            from Python, functions can also be passed directly via Python object
            constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
        - custom inductor passes: see PassConfig for more details

    Why we have different sizes for cudagraph and inductor:
    - cudagraph: a cudagraph captured for a specific size can only be used
        for the same size. We need to capture all the sizes we want to use.
    - inductor: a graph compiled by inductor for a general shape can be used
        for different sizes. Inductor can also compile for specific sizes,
        where it can have more information to optimize the graph with fully
        static shapes. However, we find the general shape compilation is
        sufficient for most cases. It might be beneficial to compile for
        certain small batchsizes, where inductor is good at optimizing.
    """ # noqa
    level: int = 0
    debug_dump_path: str = ""
    cache_dir: str = ""
    backend: str = ""
    custom_ops: List[str] = Field(default_factory=list)
    splitting_ops: List[str] = Field(default=None)  # type: ignore

    use_inductor: bool = True
    compile_sizes: Optional[List[Union[int, str]]] = Field(default=None)
    inductor_compile_config: Dict = Field(default_factory=dict)
    inductor_passes: Dict[str, str] = Field(default_factory=dict)

    use_cudagraph: bool = False
    cudagraph_num_of_warmups: int = 0
    cudagraph_capture_sizes: Optional[List[int]] = None
    cudagraph_copy_inputs: bool = False

    class PassConfig(BaseModel):
        """
        Configuration for custom Inductor passes.
        This is separate from general CompilationConfig so that inductor passes
        don't all have access to full configuration - that would create a cycle
        as the PassManager is set as a property of config.
        - dump_graph_stages: list of stages for which we want to dump the graph.
            Each pass defines its own stages (before, after, maybe in-between).
        - dump_graph_dir: directory to dump the graphs. Default is .
        - enable_fusion: whether to enable the custom fusion pass.
        - enable_reshape: whether to enable the custom reshape elimination pass.
            TODO better pass enabling system.
        """
        dump_graph_stages: List[str] = Field(default_factory=list)
        dump_graph_dir: Path = Field(default=Path("."))
        enable_fusion: bool = True
        enable_reshape: bool = True

        def uuid(self):
            """
            Produces a hash unique to the pass configuration.
            Any new fields that affect compilation should be added to the hash.
            Do not include dump_graph_* in the hash - they don't affect
            compilation.
            """
            dict_ = self.model_dump(
                include={"enable_fusion", "enable_reshape"})
            encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
            return hashlib.sha256(encoded).digest()

        def model_post_init(self, __context: Any) -> None:
            """Warn when fusion is on but reshape elimination is off."""
            if not self.enable_reshape and self.enable_fusion:
                logger.warning_once(
                    "Fusion enabled but reshape elimination disabled. "
                    "RMSNorm + quant (fp8) fusion might not work")

    pass_config: PassConfig = Field(default_factory=PassConfig)

    # not configurable, computed after init
    max_capture_size: int = PrivateAttr
    local_cache_dir: str = PrivateAttr  # local cache dir for each rank
    # optimization:
    # Intuitively, bs_to_padded_graph_size should be Dict[int, int].
    # since we know all keys are in a range [0, max_capture_size],
    # we can optimize it to List[int] for better lookup performance.
    bs_to_padded_graph_size: List[int] = PrivateAttr

    # keep track of enabled and disabled custom ops
    enabled_custom_ops: Counter[str] = PrivateAttr
    disabled_custom_ops: Counter[str] = PrivateAttr
    traced_files: Set[str] = PrivateAttr
    compilation_time: float = PrivateAttr

    # Per-model forward context
    # Map from layer name to the attention cls
    static_forward_context: Dict[str, Any] = PrivateAttr

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: List[Any] = []
        factors.append(self.level)
        factors.append(self.backend)
        factors.append(self.custom_ops)
        factors.append(self.splitting_ops)
        factors.append(self.use_inductor)
        factors.append(self.inductor_compile_config)
        factors.append(self.inductor_passes)
        factors.append(self.pass_config.uuid())
        return hashlib.sha256(str(factors).encode()).hexdigest()

    def __repr__(self) -> str:
        """JSON dump of the explicitly-set fields, minus runtime state."""
        exclude = {
            "static_forward_context",
            "enabled_custom_ops",
            "disabled_custom_ops",
            "compilation_time",
            "bs_to_padded_graph_size",
            "pass_config",
            "traced_files",
        }
        return self.model_dump_json(exclude=exclude, exclude_unset=True)

    __str__ = __repr__

    @classmethod
    def from_cli(cls, cli_value: str) -> "CompilationConfig":
        """Parse the CLI value for the compilation config.

        A bare level ("0".."3") selects that level; anything else must be a
        Python dict literal of field values.
        """
        if cli_value in ["0", "1", "2", "3"]:
            return cls(level=int(cli_value))
        # do not use `eval`, it is dangerous and can execute arbitrary code
        dict_value = ast.literal_eval(cli_value)
        return cls.model_validate(dict_value)

    def model_post_init(self, __context: Any) -> None:
        """Finish initialization after pydantic has populated the fields.

        Checks ``custom_ops``, picks the default ``splitting_ops``, resolves
        ``inductor_passes`` into ``inductor_compile_config`` entries, and
        resets the per-run bookkeeping attributes.
        """
        count_none = self.custom_ops.count("none")
        count_all = self.custom_ops.count("all")
        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"

        if self.splitting_ops is None:
            if envs.VLLM_USE_V1:
                # v1 must split the graph on attention ops
                # for piecewise cudagraph
                self.splitting_ops = [
                    "vllm.unified_attention",
                    "vllm.unified_attention_with_output",
                ]
            else:
                # v0 uses full graph compilation
                self.splitting_ops = []

        for pass_name, pass_spec in self.inductor_passes.items():
            if isinstance(pass_spec, str):
                # resolve function from qualified name. This imports the full
                # module path; `__import__("a.b")` would only return package
                # `a`, so nested module paths could not be resolved.
                pass_obj = resolve_obj_by_qualname(pass_spec)
            else:
                assert callable(pass_spec), (
                    f"pass {pass_name} should be callable or a qualified name")
                pass_obj = pass_spec
            self.inductor_compile_config[pass_name] = (
                pass_obj if isinstance(pass_obj, InductorPass) else
                CallableInductorPass(pass_obj))

        self.enabled_custom_ops = Counter()
        self.disabled_custom_ops = Counter()
        self.traced_files = set()
        self.static_forward_context = {}
        self.compilation_time = 0.0

    def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
        """Resolve the compilation backend for the configured level.

        Levels 1-2 return a registered torch backend name ("eager" when
        ``backend`` is empty) or the object named by ``backend``; level 3
        returns a ``VllmBackend`` for ``vllm_config``.

        Raises:
            ValueError: if ``level`` is ``CompilationLevel.NO_COMPILATION``.
        """
        if self.level == CompilationLevel.NO_COMPILATION:
            raise ValueError("No compilation level is set.")

        from torch._dynamo.backends.registry import list_backends
        torch_backends = list_backends(exclude_tags=tuple())
        if self.level in [
                CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE
        ]:
            if self.backend == "":
                return "eager"
            if self.backend in torch_backends:
                return self.backend
            return resolve_obj_by_qualname(self.backend)

        # TODO: pass user-specified backend to piecewise compilation
        # merge with the config use_inductor
        assert self.level == CompilationLevel.PIECEWISE

        from vllm.compilation.backends import VllmBackend
        return VllmBackend(vllm_config)

    def init_with_cudagraph_sizes(self,
                                  cudagraph_capture_sizes: List[int]) -> None:
        """To complete the initialization of config,
        we need to know the cudagraph sizes.

        ``cudagraph_capture_sizes`` is used only if the config did not
        specify its own sizes; the caller's list is copied, never mutated.
        Also expands ``compile_sizes`` and precomputes ``max_capture_size``
        and ``bs_to_padded_graph_size``.
        """
        if self.cudagraph_capture_sizes is None:
            # copy: the list is sorted in place below
            self.cudagraph_capture_sizes = list(cudagraph_capture_sizes)
        else:
            # de-duplicate the sizes provided by the config
            self.cudagraph_capture_sizes = list(
                set(self.cudagraph_capture_sizes))
            logger.info(("cudagraph sizes specified by model runner"
                         " %s is overridden by config %s"),
                        cudagraph_capture_sizes, self.cudagraph_capture_sizes)

        computed_compile_sizes = []
        if self.compile_sizes is not None:
            # de-duplicate the sizes provided by the config
            self.compile_sizes = list(set(self.compile_sizes))
            for x in self.compile_sizes:
                if isinstance(x, str):
                    assert x == "cudagraph_capture_sizes", \
                    "Unrecognized size type in compile_sizes, " \
                    f"expect 'cudagraph_capture_sizes', got {x}"
                    computed_compile_sizes.extend(self.cudagraph_capture_sizes)
                else:
                    assert isinstance(x, int)
                    computed_compile_sizes.append(x)
        self.compile_sizes = computed_compile_sizes  # type: ignore

        # sort to make sure cudagraph capture sizes are in descending order
        self.cudagraph_capture_sizes.sort(reverse=True)
        self.max_capture_size = self.cudagraph_capture_sizes[
            0] if self.cudagraph_capture_sizes else 0

        # pre-compute the mapping from batch size to padded graph size:
        # a capture size maps to itself, and any batch size strictly between
        # two consecutive capture sizes is padded up to the larger one.
        self.bs_to_padded_graph_size = [0] * (self.max_capture_size + 1)
        for end, start in zip(self.cudagraph_capture_sizes,
                              self.cudagraph_capture_sizes[1:] + [0]):
            for bs in range(start, end):
                if bs == start:
                    self.bs_to_padded_graph_size[bs] = start
                else:
                    self.bs_to_padded_graph_size[bs] = end
        self.bs_to_padded_graph_size[
            self.max_capture_size] = self.max_capture_size
3033

3034

3035
3036
3037
@dataclass
class VllmConfig:
    """Dataclass which contains all vllm-related configuration. This
3038
3039
3040
    simplifies passing around the distinct configurations in the codebase.
    """

3041
3042
    model_config: ModelConfig = field(default=None, init=True)  # type: ignore
    cache_config: CacheConfig = field(default=None, init=True)  # type: ignore
3043
3044
3045
3046
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig,
                                            init=True)
    scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig,
                                              init=True)
3047
3048
3049
    device_config: DeviceConfig = field(default=None,
                                        init=True)  # type: ignore
    load_config: LoadConfig = field(default=None, init=True)  # type: ignore
3050
3051
3052
3053
3054
    lora_config: Optional[LoRAConfig] = None
    speculative_config: Optional[SpeculativeConfig] = None
    decoding_config: Optional[DecodingConfig] = None
    observability_config: Optional[ObservabilityConfig] = None
    prompt_adapter_config: Optional[PromptAdapterConfig] = None
3055
    quant_config: Optional[QuantizationConfig] = None
3056
3057
    compilation_config: CompilationConfig = field(default=None,
                                                  init=True)  # type: ignore
3058
3059
    kv_transfer_config: KVTransferConfig = field(default=None,
                                                 init=True)  # type: ignore
3060
    # some opaque config, only used to provide additional information
3061
3062
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
3063
3064
    additional_config: SupportsHash = field(default=None,
                                            init=True)  # type: ignore
3065
    instance_id: str = ""
3066

3067
3068
3069
3070
3071
3072
3073
3074
3075
3076
3077
3078
3079
3080
3081
3082
3083
3084
3085
3086
    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Returns the first 10 hex characters of an MD5 digest over the vLLM
        version and each sub-config's own ``compute_hash()``.
        """
        factors: List[Any] = []

        # summarize vllm config
        vllm_factors: List[Any] = []
        from vllm import __version__
        vllm_factors.append(__version__)

        # Order matters: it defines the hash. quant_config is deliberately
        # absent — it should be captured by model_config.quantization.
        sub_configs = (
            self.model_config,
            self.cache_config,
            self.parallel_config,
            self.scheduler_config,
            self.device_config,
            self.load_config,
            self.lora_config,
            self.speculative_config,
            self.decoding_config,
            self.observability_config,
            self.prompt_adapter_config,
            self.compilation_config,
            self.kv_transfer_config,
            self.additional_config,
        )
        for sub_config in sub_configs:
            # an unset (falsy) sub-config contributes the placeholder "None"
            vllm_factors.append(
                sub_config.compute_hash() if sub_config else "None")
        factors.append(vllm_factors)

        hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
        return hash_str

3148
3149
3150
3151
3152
3153
    def pad_for_cudagraph(self, batch_size: int) -> int:
        """Return the cudagraph capture size that ``batch_size`` pads up to.

        The lookup uses the table precomputed by the compilation config.
        Callers must ensure ``batch_size`` does not exceed
        ``compilation_config.max_capture_size``; a larger value indexes past
        the end of the table and raises ``IndexError``.
        """
        padded_sizes = self.compilation_config.bs_to_padded_graph_size
        return padded_sizes[batch_size]
3154

3155
3156
3157
3158
3159
    @staticmethod
    def _get_quantization_config(
            model_config: ModelConfig,
            load_config: LoadConfig) -> Optional[QuantizationConfig]:
        """Get the quantization config.

        Returns None when ``model_config.quantization`` is None; otherwise
        returns the config built by ``get_quant_config``.

        Raises:
            ValueError: if the current device capability is below the
                method's minimum capability, or if ``model_config.dtype`` is
                not among the method's supported activation dtypes.
        """
        if model_config.quantization is None:
            return None

        from vllm.model_executor.model_loader.weight_utils import (
            get_quant_config)
        from vllm.platforms import current_platform

        quant_config = get_quant_config(model_config, load_config)

        capability_tuple = current_platform.get_device_capability()
        # No reported capability -> the minimum-capability check is skipped.
        if capability_tuple is not None:
            capability = capability_tuple.to_int()
            min_capability = quant_config.get_min_capability()
            if capability < min_capability:
                raise ValueError(
                    f"The quantization method {model_config.quantization} "
                    "is not supported for the current GPU. Minimum "
                    f"capability: {min_capability}. "
                    f"Current capability: {capability}.")

        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")
        return quant_config
3183

3184
3185
3186
3187
3188
3189
3190
3191
3192
    def with_hf_config(
        self,
        hf_config: PretrainedConfig,
        architectures: Optional[list[str]] = None,
    ) -> "VllmConfig":
        """Return a new VllmConfig whose model config uses ``hf_config``.

        Neither ``self`` nor the passed ``hf_config`` is modified:
        ``architectures``, when given, is set on a deep copy of ``hf_config``,
        and ``model_config`` is deep-copied before ``hf_config`` is attached.
        The other sub-configs are passed through unchanged by
        ``dataclasses.replace``, which constructs the new instance (so
        ``__post_init__`` runs again on it).
        """
        if architectures is not None:
            hf_config = copy.deepcopy(hf_config)
            hf_config.architectures = architectures

        model_config = copy.deepcopy(self.model_config)
        model_config.hf_config = hf_config

        return replace(self, model_config=model_config)

3198
3199
3200
    def __post_init__(self):
        """Verify configs are valid & consistent with each other.

        Besides cross-validating the sub-configs, this fills in derived
        state: ``quant_config`` (when unset), a default
        ``compilation_config``, V1 compilation defaults, cudagraph sizes and
        ``instance_id``. It may also switch off ``torch.compile`` (CPU offload
        or LoRA), and chunked prefill / prefix caching (MLA).
        """
        if self.model_config is not None:
            self.model_config.verify_async_output_proc(self.parallel_config,
                                                       self.speculative_config,
                                                       self.device_config)
            self.model_config.verify_with_parallel_config(self.parallel_config)

        if self.cache_config is not None:
            self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config:
            self.lora_config.verify_with_cache_config(self.cache_config)
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)

        if self.quant_config is None and \
            self.model_config is not None and self.load_config is not None:
            self.quant_config = VllmConfig._get_quantization_config(
                self.model_config, self.load_config)

        from vllm.platforms import current_platform
        if self.scheduler_config is not None and \
            self.model_config is not None and \
            self.scheduler_config.chunked_prefill_enabled and \
            self.model_config.dtype == torch.float32 and \
            current_platform.get_device_capability() == (7, 5):
            logger.warning_once(
                "Turing devices tensor cores do not support float32 matmul. "
                "To workaround this limitation, vLLM will set 'ieee' input "
                "precision for chunked prefill triton kernels.")

        if self.compilation_config is None:
            self.compilation_config = CompilationConfig()
        if envs.VLLM_USE_V1 and self.model_config is not None and \
            not self.model_config.enforce_eager:
            # NOTE(woosuk): Currently, we use inductor because the piecewise
            # CUDA graphs do not work properly with the custom CUDA kernels.
            # FIXME(woosuk): Disable inductor to reduce the compilation time
            # and avoid any potential issues with the inductor.
            self.compilation_config.custom_ops = ["none"]
            self.compilation_config.use_cudagraph = True
            self.compilation_config.use_inductor = True
            self.compilation_config.cudagraph_num_of_warmups = 1
            self.compilation_config.pass_config.enable_fusion = False
            self.compilation_config.pass_config.enable_reshape = False
            self.compilation_config.level = CompilationLevel.PIECEWISE

        self._set_cudagraph_sizes()

        # The two checks below run after the V1 defaults above, so they can
        # override a PIECEWISE level chosen there.
        if self.cache_config is not None and \
            self.cache_config.cpu_offload_gb > 0 and \
            self.compilation_config.level != CompilationLevel.NO_COMPILATION:
            logger.warning(
                "CPU offload is not supported with `torch.compile` yet."
                " Disabling `torch.compile`.")
            self.compilation_config.level = CompilationLevel.NO_COMPILATION

        if self.lora_config is not None and self.compilation_config.level !=\
             CompilationLevel.NO_COMPILATION:
            logger.warning("LoRA is not supported with `torch.compile` yet. "
                           "Disabling `torch.compile`.")
            self.compilation_config.level = CompilationLevel.NO_COMPILATION

        current_platform.check_and_update_config(self)

        # If MLA is enabled, force disable chunked prefill and prefix caching
        if self.model_config and self.model_config.use_mla:
            logger.info("MLA is enabled; forcing chunked prefill and prefix "
                        "caching to be disabled.")
            if self.scheduler_config is not None:
                self.scheduler_config.enable_chunked_prefill = False
                self.scheduler_config.chunked_prefill_enabled = False

            if self.cache_config is not None:
                self.cache_config.enable_prefix_caching = False

        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

3282
3283
3284
3285
3286
3287
3288
3289
3290
3291
3292
3293
3294
3295
3296
3297
    def _set_cudagraph_sizes(self):
        """
        Compute the default cudagraph capture sizes and pass them to
        `self.compilation_config.init_with_cudagraph_sizes`.

        cudagraph batchsize padding logic:

        V0 (`envs.VLLM_USE_V1` is false):
        `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is the list of all
        possible batch sizes that cudagraph will capture. Depending on the
        engine's configuration of `max_num_seqs`, the candidate batch sizes
        shrink to the subset which just covers the range
        `[1, max_num_seqs]`. In the common case, `max_num_seqs` is 256, and
        the cudagraph batch sizes will be `[1, 2, 4, 8, 16, 24, 32, ..., 256]`.
        If `max_num_seqs` exceeds every candidate, all candidates are used.

        V1: the default list is `[1, 2, 4, 8, 16, ..., 512]`, independent of
        `max_num_seqs`.

        The default list is empty when `enforce_eager` is set, when there is
        no model config, or (V0 only) when there is no scheduler config.

        This method only computes the default list. Sizes specified by the
        user through the compilation config are expected to take precedence
        inside `init_with_cudagraph_sizes`.

        In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
        will be the final sizes to capture cudagraph (in descending order).

        During runtime, if the batch size is larger than the largest size in
        `vllm_config.compilation_config.cudagraph_capture_sizes`, no
        cudagraph will be used. If the batch size is no larger than that,
        we can quickly find the padded graph size for a given batch size by
        looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
        """
        batch_size_capture_list: List[int] = []

        if not envs.VLLM_USE_V1:
            if self.scheduler_config is not None and \
                    self.model_config is not None and \
                    not self.model_config.enforce_eager:
                possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
                # The smallest candidate that covers max_num_seqs becomes the
                # largest size to capture; if none covers it, capture all.
                max_batchsize_to_capture = next(
                    (size for size in possible_sizes
                     if size >= self.scheduler_config.max_num_seqs),
                    possible_sizes[-1])
                batch_size_capture_list = [
                    size for size in possible_sizes
                    if size <= max_batchsize_to_capture
                ]
        elif self.model_config is not None and \
                not self.model_config.enforce_eager:
            # V1 uses a fixed default list, regardless of max_num_seqs.
            batch_size_capture_list = [1, 2, 4] + list(range(8, 513, 8))

        self.compilation_config.init_with_cudagraph_sizes(
            batch_size_capture_list)

3346
    def __str__(self):
3347
3348
3349
3350
3351
3352
3353
3354
3355
3356
3357
3358
3359
3360
3361
3362
3363
3364
3365
3366
3367
3368
3369
3370
3371
3372
3373
3374
3375
3376
        return (
            f"model={self.model_config.model!r},"
            f" speculative_config={self.speculative_config!r},"
            f" tokenizer={self.model_config.tokenizer!r}, "
            f"skip_tokenizer_init={self.model_config.skip_tokenizer_init},"
            f" tokenizer_mode={self.model_config.tokenizer_mode}, "
            f"revision={self.model_config.revision}, "
            f"override_neuron_config={self.model_config.override_neuron_config},"
            f" tokenizer_revision={self.model_config.tokenizer_revision}, "
            f"trust_remote_code={self.model_config.trust_remote_code}, "
            f"dtype={self.model_config.dtype}, "
            f"max_seq_len={self.model_config.max_model_len},"
            f" download_dir={self.load_config.download_dir!r}, "
            f"load_format={self.load_config.load_format}, "
            f"tensor_parallel_size={self.parallel_config.tensor_parallel_size},"
            f" pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, "  # noqa
            f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, "  # noqa
            f"quantization={self.model_config.quantization}, "
            f"enforce_eager={self.model_config.enforce_eager}, "
            f"kv_cache_dtype={self.cache_config.cache_dtype}, "
            f" device_config={self.device_config.device}, "
            f"decoding_config={self.decoding_config!r}, "
            f"observability_config={self.observability_config!r}, "
            f"seed={self.model_config.seed}, "
            f"served_model_name={self.model_config.served_model_name}, "
            f"num_scheduler_steps={self.scheduler_config.num_scheduler_steps}, "
            f"multi_step_stream_outputs={self.scheduler_config.multi_step_stream_outputs}, "  # noqa
            f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
            f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, "  # noqa
            f"use_async_output_proc={self.model_config.use_async_output_proc}, "
3377
            f"disable_mm_preprocessor_cache={self.model_config.disable_mm_preprocessor_cache!r}, "  # noqa
3378
            f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
3379
3380
            f"pooler_config={self.model_config.pooler_config!r}, "
            f"compilation_config={self.compilation_config!r}")
3381
3382
3383
3384
3385
3386


# Module-level slot for the config installed by `set_current_vllm_config`
# while its `with` block is active; `None` when no config has been set.
# Read it through `get_current_vllm_config`.
_current_vllm_config: Optional[VllmConfig] = None


@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
    """
    Temporarily set the current VLLM config.

    Used during model initialization. We save the current VLLM config in a
    global variable, so that all modules can access it, e.g. custom ops
    can access the VLLM config to determine how to dispatch.

    Args:
        vllm_config: config to install as the current config for the
            duration of the ``with`` block.
        check_compile: when True and the compilation level is PIECEWISE,
            warn if `compilation_counter.num_models_seen` did not increase
            while the block ran. That suggests the model does not support
            ``torch.compile``.

    The previously installed config is always restored on exit, even if the
    body or the exit-time logging raises. The enabled/disabled custom-op
    debug logs and the compile check run only when the body completes
    without raising. A failed model initialization therefore does not also
    emit a misleading "does not support it" warning.
    """
    global _current_vllm_config
    old_vllm_config = _current_vllm_config
    from vllm.compilation.counter import compilation_counter
    num_models_seen = compilation_counter.num_models_seen
    try:
        _current_vllm_config = vllm_config
        yield
        # Only reached when the body of the `with` block did not raise.
        logger.debug("enabled custom ops: %s",
                     vllm_config.compilation_config.enabled_custom_ops)
        logger.debug("disabled custom ops: %s",
                     vllm_config.compilation_config.disabled_custom_ops)
        if check_compile and \
            vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
            and compilation_counter.num_models_seen == num_models_seen:
            # If the model supports compilation,
            # compilation_counter.num_models_seen should be increased
            # by at least 1.
            # If it is not increased, it means the model does not support
            # compilation (does not have @support_torch_compile decorator).
            logger.warning(
                "`torch.compile` is turned on, but the model %s"
                " does not support it. Please open an issue on GitHub"
                " if you want it to be supported.",
                vllm_config.model_config.model)
    finally:
        # Restore unconditionally so a failure above cannot leak this
        # config into later callers of get_current_vllm_config().
        _current_vllm_config = old_vllm_config


def get_current_vllm_config() -> VllmConfig:
    """
    Return the config installed by `set_current_vllm_config`.

    When no config is installed, log a warning and return a freshly
    constructed default `VllmConfig()`. The fallback instance is not stored
    globally, so each such call builds a new one.
    """
    active_config = _current_vllm_config
    if active_config is not None:
        return active_config

    # In CI, custom ops/modules are often tested directly without setting
    # a vllm config; fall back to a default one in that case.
    logger.warning("Current VLLM config is not set.")
    from vllm.config import VllmConfig
    return VllmConfig()
    return _current_vllm_config