arg_utils.py 101 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import copy
6
import dataclasses
7
import functools
8
import json
9
import sys
10
from collections.abc import Callable
11
from dataclasses import MISSING, asdict, dataclass, fields, is_dataclass
12
from itertools import permutations
13
from types import UnionType
14
15
16
17
18
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Literal,
19
    TypeAlias,
20
21
22
23
24
25
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
26

27
import huggingface_hub
28
import regex as re
29
import torch
30
from pydantic import TypeAdapter, ValidationError
31
from pydantic.fields import FieldInfo
32
from typing_extensions import TypeIs
33

34
import vllm.envs as envs
35
from vllm.config import (
36
    AttentionConfig,
37
38
39
40
    CacheConfig,
    CompilationConfig,
    ConfigType,
    DeviceConfig,
41
    ECTransferConfig,
42
    EPLBConfig,
43
    KernelConfig,
44
45
46
47
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LoRAConfig,
48
    MambaConfig,
49
    ModelConfig,
50
    MultiModalConfig,
51
    ObservabilityConfig,
52
    OffloadConfig,
53
54
    ParallelConfig,
    PoolerConfig,
55
    PrefetchOffloadConfig,
56
    ProfilerConfig,
57
    ReasoningConfig,
58
59
60
    SchedulerConfig,
    SpeculativeConfig,
    StructuredOutputsConfig,
61
    UVAOffloadConfig,
62
    VllmConfig,
63
    WeightTransferConfig,
64
65
    get_attr_docs,
)
66
67
68
from vllm.config.cache import (
    CacheDType,
    KVOffloadingBackend,
69
    MambaCacheMode,
70
71
72
    MambaDType,
    PrefixCachingHashAlgo,
)
73
from vllm.config.device import Device
74
from vllm.config.kernel import IrOpPriorityConfig, MoEBackend
75
from vllm.config.lora import MaxLoRARanks
76
from vllm.config.mamba import MambaBackendEnum
77
78
79
80
81
82
from vllm.config.model import (
    ConvertOption,
    HfOverrides,
    LogprobsMode,
    ModelDType,
    RunnerOption,
83
    TokenizerMode,
84
)
85
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MMTensorIPC
86
from vllm.config.observability import DetailedTraceModules
87
88
89
from vllm.config.parallel import (
    All2AllBackend,
    DataParallelBackend,
90
    DCPCommBackend,
91
92
93
    DistributedExecutorBackend,
    ExpertPlacementStrategy,
)
94
from vllm.config.scheduler import SchedulerPolicy
95
from vllm.config.utils import get_field
96
from vllm.config.vllm import OptimizationLevel, PerformanceMode
97
from vllm.logger import init_logger, suppress_logging
98
from vllm.platforms import CpuArchEnum, current_platform
99
from vllm.plugins import load_general_plugins
100
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
101
102
103
104
from vllm.transformers_utils.config import (
    is_interleaved,
    maybe_override_with_speculators,
)
105
from vllm.transformers_utils.gguf_utils import is_gguf
106
from vllm.transformers_utils.repo_utils import get_model_path
107
from vllm.transformers_utils.utils import is_cloud_storage
108
109
110
111
112
from vllm.utils.argparse_utils import (
    FlexibleArgumentParser,
    human_readable_int,
    human_readable_int_or_auto,
)
113
from vllm.utils.mem_constants import GiB_bytes
114
from vllm.utils.network_utils import get_ip
115
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
116
from vllm.v1.attention.backends.registry import AttentionBackendEnum
117
from vllm.v1.sample.logits_processor import LogitsProcessor
118
from vllm.version import __version__ as VLLM_VERSION
119

120
if TYPE_CHECKING:
121
    from vllm.config.quantization import OnlineQuantizationConfigArgs
122
    from vllm.model_executor.layers.quantization import QuantizationMethods
123
    from vllm.model_executor.model_loader import LoadFormats
124
    from vllm.usage.usage_lib import UsageContext
125
    from vllm.v1.executor import Executor
126
else:
127
    Executor = Any
128
    QuantizationMethods = Any
129
    LoadFormats = Any
130
131
    UsageContext = Any

132

133
134
logger = init_logger(__name__)

135
136
# object is used to allow for special typing forms
T = TypeVar("T")
137
138
TypeHint: TypeAlias = type[Any] | object
TypeHintT: TypeAlias = type[T] | object
139

140

141
142
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap a string converter so failures are reported as argparse errors.

    The returned callable applies ``return_type`` to its string argument. A
    ``ValueError`` raised by the converter is re-raised as
    ``argparse.ArgumentTypeError`` (chained to the original exception).
    """

    def _parse_type(val: str) -> T:
        try:
            converted = return_type(val)
        except ValueError as err:
            message = f"Value {val} cannot be converted to {return_type}."
            raise argparse.ArgumentTypeError(message) from err
        return converted

    return _parse_type


153
154
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], T | None]:
    """Like `parse_type`, but the strings ``""`` and ``"None"`` become `None`.

    Any other value is converted with ``parse_type(return_type)``.
    """

    def _optional_type(val: str) -> T | None:
        if val in ("", "None"):
            return None
        return parse_type(return_type)(val)

    return _optional_type
160
161


162
def union_dict_and_str(val: str) -> str | dict[str, str] | None:
    """Decode ``val`` as JSON if it looks like an object, else keep it as text.

    A value whose non-whitespace content is a single ``{...}`` span is parsed
    with ``optional_type(json.loads)``, so malformed JSON surfaces as an
    ``argparse.ArgumentTypeError``. Anything else is returned as a string.
    """
    looks_like_json_object = re.match(r"(?s)^\s*{.*}\s*$", val) is not None
    if looks_like_json_object:
        return optional_type(json.loads)(val)
    return str(val)
166
167


168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return True if ``type_hint`` is ``type`` itself or parameterises it.

    For example, both ``list`` and ``list[int]`` match ``list``.
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return True if any hint in ``type_hints`` matches ``type`` per `is_type`."""
    for type_hint in type_hints:
        if is_type(type_hint, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT | None:
    """Return a hint from ``type_hints`` that matches ``type`` per `is_type`.

    Returns None when no hint matches. Because ``type_hints`` is a set, the
    hint returned when several match is not specified.
    """
    return next((th for th in type_hints if is_type(th, type)), None)


183
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Build argparse kwargs for the `Literal` hint found in ``type_hints``.

    Returns ``{"type": <type of the options>, <kwarg>: sorted(options)}``,
    where ``<kwarg>`` is ``choices``, or ``metavar`` when ``type_hints`` also
    contains ``str`` (so arbitrary strings remain accepted).

    Raises:
        ValueError: if the Literal's options are not all instances of the
            type of its first option.
    """
    literal_hint = get_type(type_hints, Literal)
    options = get_args(literal_hint)
    first_type = type(options[0])
    if any(not isinstance(opt, first_type) for opt in options):
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}"
        )

    if contains_type(type_hints, str):
        # Free-form strings are accepted too, so only advertise the options
        # in the help text instead of restricting input with `choices`.
        value_kwarg = "metavar"
    else:
        value_kwarg = "choices"
    return {"type": first_type, value_kwarg: sorted(options)}
198
199


200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
def collection_to_kwargs(type_hints: set[TypeHint], type: TypeHint) -> dict[str, Any]:
    """Build argparse ``type``/``nargs`` kwargs for a list, set or tuple hint.

    Args:
        type_hints: The flattened set of type hints for a single field.
        type: The collection origin to look for (e.g. ``list``, ``set`` or
            ``tuple``).

    Returns:
        ``{"type": elem_type, "nargs": ...}``. ``nargs`` is ``"+"`` for
        non-tuple collections and variadic tuples (``tuple[X, ...]``), and the
        fixed length for other tuples. A Union element type that includes
        ``str`` is parsed as ``str``.

    Raises:
        ValueError: if the collection has no element type, mixes element
            types, or has a Union element type without ``str``. (Explicit
            exceptions are used instead of ``assert`` so the checks still run
            under ``python -O``.)
    """
    type_hint = get_type(type_hints, type)
    types = get_args(type_hint)
    if not types:
        raise ValueError(
            f"Collection type hint {type_hint} must specify its element type "
            "(e.g. 'list[int]')."
        )
    elem_type = types[0]

    # Handle Ellipsis: tuple[X, ...] is allowed, every other element must match
    if not all(t is elem_type for t in types if t is not Ellipsis):
        raise ValueError(
            f"All non-Ellipsis elements must be of the same type. Got {types}."
        )

    # Handle Union types
    if get_origin(elem_type) in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
        if str not in get_args(elem_type):
            raise ValueError(
                "If element can have multiple types, one must be 'str' "
                f"(i.e. 'list[int | str]'). Got {elem_type}."
            )
        elem_type = str

    nargs = "+" if type is not tuple or Ellipsis in types else len(types)
    return {"type": elem_type, "nargs": nargs}


225
226
227
228
229
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True unless ``type_hint`` reports ``builtins`` as its module."""
    module_name = type_hint.__module__
    return module_name != "builtins"


230
231
232
233
234
235
236
237
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Flatten ``type_hint`` into the set of its member type hints.

    ``Annotated[X, ...]`` is unwrapped to ``X``, and ``Union[X, Y]`` /
    ``X | Y`` are expanded into their members, recursively. Any other hint
    (including generics such as ``list[int]``) is returned unchanged in a
    one-element set.
    """
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        # Only the first argument is the actual type; the rest is metadata.
        return get_type_hints(args[0])

    if origin in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
        flattened: set[TypeHint] = set()
        for member in args:
            flattened |= get_type_hints(member)
        return flattened

    return {type_hint}


248
NEEDS_HELP = (
249
250
    any("--help" in arg for arg in sys.argv)  # vllm SUBCOMMAND --help
    or (argv0 := sys.argv[0]).endswith("mkdocs")  # mkdocs SUBCOMMAND
251
252
253
254
    or argv0.endswith("mkdocs/__main__.py")  # python -m mkdocs SUBCOMMAND
)


255
256
257
258
259
260
261
262
def _maybe_add_docs_url(cls: Any) -> str:
    """Return a help-text suffix linking to the API docs of a config class.

    Classes whose module does not start with ``vllm.config`` get an empty
    string. Version strings containing ``dev`` link to the ``latest`` docs;
    any other version links to ``v<version>``.
    """
    if cls.__module__.startswith("vllm.config"):
        docs_version = "latest" if "dev" in VLLM_VERSION else f"v{VLLM_VERSION}"
        return (
            "\n\nAPI docs: https://docs.vllm.ai/en/"
            f"{docs_version}/api/vllm/config/#vllm.config.{cls.__name__}"
        )
    return ""


263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
def _expand_json_human_readable_numbers(val: str) -> str:
    """Expand human-readable number suffixes in a JSON string.

    Each bare token is converted with :func:`human_readable_int`, so the
    ``k/m/g/t`` (decimal) and ``K/M/G/T`` (binary) conventions work out of
    the box, including inside JSON config arguments such as
    ``--kv-transfer-config '{"cpu_bytes_to_use": 80m}'``.

    Only bare (unquoted) tokens are replaced, so JSON string values such as
    ``"model_name"`` are never modified.
    """
    # The capturing group makes re.split keep the quoted strings, so the
    # pieces alternate: outside-string, string, outside-string, ...
    segments = re.split(r'("(?:[^"\\]|\\.)*")', val)

    def _expand(match) -> str:
        return str(human_readable_int(match.group()))

    expanded = [
        re.sub(r"\b\d+(?:\.\d+)?[kKmMgGtT]\b", _expand, segment)
        if index % 2 == 0
        else segment
        for index, segment in enumerate(segments)
    ]
    return "".join(expanded)


285
def _get_field_default(field: dataclasses.Field) -> Any:
    """Return the default value declared for a config dataclass field.

    Handles plain ``default`` / ``default_factory`` as well as a pydantic
    ``Field(...)`` stored as the dataclass default. Returns None when the
    field declares no default at all (None is also argparse's own default).
    """
    if field.default is not MISSING:
        default = field.default
        # Handle pydantic.Field defaults
        if isinstance(default, FieldInfo):
            if default.default_factory is None:
                return default.default
            # VllmConfig's Fields have default_factory set to config classes.
            # These could emit logs on init, which would be confusing.
            with suppress_logging():
                return default.default_factory()  # type: ignore[call-arg]
        return default
    if field.default_factory is not MISSING:
        return field.default_factory()
    # No default declared: without this, the caller's loop would reuse the
    # previous field's default (or hit an unbound local on the first field).
    return None


@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Compute ``ArgumentParser.add_argument`` kwargs for every field of ``cls``.

    Results are cached per class, so the returned dicts must not be mutated;
    `get_kwargs` returns a deep copy for that reason.

    Raises:
        ValueError: if a field's type hints cannot be mapped to an argparse
            argument.
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}
    json_tip = (
        "Should either be a valid JSON string or JSON keys passed individually."
    )

    kwargs: dict[str, dict[str, Any]] = {}
    for field in fields(cls):
        name = field.name
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, its CLI value is parsed from JSON below
        dataclass_cls = next((th for th in type_hints if is_dataclass(th)), None)

        # Escape % for argparse
        help_text = cls_docs.get(name, "").strip().replace("%", "%%")

        field_kwargs: dict[str, Any] = {
            "default": _get_field_default(field),
            "help": help_text,
        }
        kwargs[name] = field_kwargs

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:

            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    val = _expand_json_human_readable_numbers(val)
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            field_kwargs["type"] = parse_dataclass
            field_kwargs["help"] += _maybe_add_docs_url(dataclass_cls)
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            field_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            field_kwargs.update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            field_kwargs.update(collection_to_kwargs(type_hints, tuple))
        elif contains_type(type_hints, list):
            field_kwargs.update(collection_to_kwargs(type_hints, list))
        elif contains_type(type_hints, set):
            field_kwargs.update(collection_to_kwargs(type_hints, set))
        elif contains_type(type_hints, int):
            if name == "max_model_len":
                field_kwargs["type"] = human_readable_int_or_auto
                field_kwargs["help"] += f"\n\n{human_readable_int_or_auto.__doc__}"
            elif name in ("max_num_batched_tokens", "kv_cache_memory_bytes"):
                field_kwargs["type"] = human_readable_int
                field_kwargs["help"] += f"\n\n{human_readable_int.__doc__}"
            else:
                field_kwargs["type"] = int
        elif contains_type(type_hints, float):
            field_kwargs["type"] = float
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
            field_kwargs["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            field_kwargs["type"] = parse_type(json.loads)
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
            field_kwargs["type"] = str
        else:
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")

        # A collection of Literals (e.g. list[Literal[...]]) leaves a Literal
        # as the element type; expand it into a concrete type and choices.
        if get_origin(field_kwargs.get("type")) is Literal:
            field_kwargs.update(literal_to_kwargs({field_kwargs["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            field_kwargs["type"] = optional_type(field_kwargs["type"])
            if field_kwargs.get("choices"):
                # NOTE(review): argparse validates `choices` against the value
                # returned by `type`; optional_type maps "None" to None, which
                # is not in this list — confirm `--flag None` is accepted.
                field_kwargs["choices"].append("None")
    return kwargs
387
388


389
def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Return argparse kwargs for the given Config dataclass.

    Attribute documentation is only included in the help output when the
    command line contains `--help` or is an `mkdocs` invocation.

    The expensive work is memoised by `_compute_kwargs`; this wrapper returns
    an independent deep copy so callers may mutate the result without
    corrupting the cached version.
    """
    cached = _compute_kwargs(cls)
    return copy.deepcopy(cached)


402
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
403
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
404
    """Arguments for vLLM engine."""
405

406
    model: str = ModelConfig.model
407
    enable_return_routed_experts: bool = ModelConfig.enable_return_routed_experts
408
    model_weights: str = ModelConfig.model_weights
409
    served_model_name: str | list[str] | None = ModelConfig.served_model_name
410
    tokenizer: str | None = ModelConfig.tokenizer
411
    hf_config_path: str | None = ModelConfig.hf_config_path
412
413
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
414
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
415
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
416
    tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
417
    trust_remote_code: bool = ModelConfig.trust_remote_code
418
419
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
    allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains
420
    download_dir: str | None = LoadConfig.download_dir
421
    safetensors_load_strategy: str | None = LoadConfig.safetensors_load_strategy
422
    load_format: str | LoadFormats = LoadConfig.load_format
423
424
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
425
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
426
    seed: int = ModelConfig.seed
427
    max_model_len: int = ModelConfig.max_model_len
428
429
430
431
432
433
    cudagraph_capture_sizes: list[int] | None = (
        CompilationConfig.cudagraph_capture_sizes
    )
    max_cudagraph_capture_size: int | None = get_field(
        CompilationConfig, "max_cudagraph_capture_size"
    )
434
    ir_op_priority: IrOpPriorityConfig = get_field(KernelConfig, "ir_op_priority")
435
436
437
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
438
    distributed_executor_backend: (
439
        str | DistributedExecutorBackend | type[Executor] | None
440
    ) = ParallelConfig.distributed_executor_backend
441
    # number of P/D disaggregation (or other disaggregation) workers
442
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
443
444
445
446
    master_addr: str = ParallelConfig.master_addr
    master_port: int = ParallelConfig.master_port
    nnodes: int = ParallelConfig.nnodes
    node_rank: int = ParallelConfig.node_rank
447
    distributed_timeout_seconds: int | None = ParallelConfig.distributed_timeout_seconds
448
449
450
    numa_bind: bool = ParallelConfig.numa_bind
    numa_bind_nodes: list[int] | None = ParallelConfig.numa_bind_nodes
    numa_bind_cpus: list[str] | None = ParallelConfig.numa_bind_cpus
451
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
452
    prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
453
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
454
    dcp_comm_backend: DCPCommBackend = ParallelConfig.dcp_comm_backend
455
    dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
456
    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
457
    data_parallel_size: int = ParallelConfig.data_parallel_size
458
459
460
461
462
    data_parallel_rank: int | None = None
    data_parallel_start_rank: int | None = None
    data_parallel_size_local: int | None = None
    data_parallel_address: str | None = None
    data_parallel_rpc_port: int | None = None
463
    data_parallel_hybrid_lb: bool = False
464
    data_parallel_external_lb: bool = False
465
    data_parallel_backend: DataParallelBackend = ParallelConfig.data_parallel_backend
466
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
467
    enable_ep_weight_filter: bool = ParallelConfig.enable_ep_weight_filter
468
    moe_backend: MoEBackend = KernelConfig.moe_backend
469
    all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
470
    enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep
471
    enable_dbo: bool = ParallelConfig.enable_dbo
472
    ubatch_size: int = ParallelConfig.ubatch_size
473
474
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
475
    disable_nccl_for_dp_synchronization: bool | None = (
476
477
        ParallelConfig.disable_nccl_for_dp_synchronization
    )
478
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
479
    enable_eplb: bool = ParallelConfig.enable_eplb
480
    expert_placement_strategy: ExpertPlacementStrategy = (
481
        ParallelConfig.expert_placement_strategy
482
    )
483
484
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
485
    max_parallel_loading_workers: int | None = (
486
487
        ParallelConfig.max_parallel_loading_workers
    )
488
    block_size: int | None = None
489
    enable_prefix_caching: bool | None = None
490
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
491
        CacheConfig.prefix_caching_hash_algo
492
    )
493
494
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
495
496
497
498
499
500
501
    offload_backend: str = OffloadConfig.offload_backend
    cpu_offload_gb: float = UVAOffloadConfig.cpu_offload_gb
    cpu_offload_params: set[str] = get_field(UVAOffloadConfig, "cpu_offload_params")
    offload_group_size: int = PrefetchOffloadConfig.offload_group_size
    offload_num_in_group: int = PrefetchOffloadConfig.offload_num_in_group
    offload_prefetch_step: int = PrefetchOffloadConfig.offload_prefetch_step
    offload_params: set[str] = get_field(PrefetchOffloadConfig, "offload_params")
502
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
503
    kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
504
    max_num_batched_tokens: int | None = None
505
506
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
507
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
508
    max_num_seqs: int | None = None
509
    max_logprobs: int = ModelConfig.max_logprobs
510
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
511
    disable_log_stats: bool = False
512
    aggregate_engine_logging: bool = False
513
514
515
    revision: str | None = ModelConfig.revision
    code_revision: str | None = ModelConfig.code_revision
    hf_token: bool | str | None = ModelConfig.hf_token
516
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
517
    tokenizer_revision: str | None = ModelConfig.tokenizer_revision
518
    quantization: QuantizationMethods | str | None = ModelConfig.quantization
519
    quantization_config: "dict[str, Any] | OnlineQuantizationConfigArgs | None" = None
520
    allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
521
    enforce_eager: bool = ModelConfig.enforce_eager
522
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
523
    language_model_only: bool = MultiModalConfig.language_model_only
524
    limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
525
526
        MultiModalConfig, "limit_per_prompt"
    )
527
    enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
528
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
529
530
531
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
532
    mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
533
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
534
    mm_processor_cache_type: MMCacheType | None = (
535
        MultiModalConfig.mm_processor_cache_type
536
537
    )
    mm_shm_cache_max_object_size_mb: int = (
538
        MultiModalConfig.mm_shm_cache_max_object_size_mb
539
    )
540
    mm_encoder_only: bool = MultiModalConfig.mm_encoder_only
541
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
542
    mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
543
544
        MultiModalConfig.mm_encoder_attn_backend
    )
545
    io_processor_plugin: str | None = None
546
    renderer_num_workers: int = 1
547
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
548
    video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate
549
    mm_tensor_ipc: MMTensorIPC = MultiModalConfig.mm_tensor_ipc
550
    # LoRA fields
551
    enable_lora: bool = False
552
    max_loras: int = LoRAConfig.max_loras
553
    max_lora_rank: MaxLoRARanks = LoRAConfig.max_lora_rank
554
    default_mm_loras: dict[str, str] | None = LoRAConfig.default_mm_loras
555
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
556
557
    max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
    lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
558
    lora_target_modules: list[str] | None = LoRAConfig.target_modules
559
    enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
560
    specialize_active_lora: bool = LoRAConfig.specialize_active_lora
561

562
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
563
    num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
564
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
565
    ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
566

567
    enable_chunked_prefill: bool | None = None
568
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
569

570
571
    scheduler_reserve_full_isl: bool = SchedulerConfig.scheduler_reserve_full_isl

572
    disable_hybrid_kv_cache_manager: bool | None = (
573
574
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
575

576
    structured_outputs_config: StructuredOutputsConfig = get_field(
577
578
        VllmConfig, "structured_outputs_config"
    )
579
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
580
    reasoning_parser_plugin: str | None = None
581

582
    speculative_config: dict[str, Any] | None = None
583

584
    show_hidden_metrics_for_version: str | None = (
585
        ObservabilityConfig.show_hidden_metrics_for_version
586
    )
587
588
    otlp_traces_endpoint: str | None = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: list[DetailedTraceModules] | None = (
589
        ObservabilityConfig.collect_detailed_traces
590
    )
591
592
593
594
    kv_cache_metrics: bool = ObservabilityConfig.kv_cache_metrics
    kv_cache_metrics_sample: float = get_field(
        ObservabilityConfig, "kv_cache_metrics_sample"
    )
595
    cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
596
597
598
    enable_layerwise_nvtx_tracing: bool = (
        ObservabilityConfig.enable_layerwise_nvtx_tracing
    )
599
    enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
600
601
602
    enable_logging_iteration_details: bool = (
        ObservabilityConfig.enable_logging_iteration_details
    )
603
    enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
604
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
605
    scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
606

607
    pooler_config: PoolerConfig | None = ModelConfig.pooler_config
608
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
609
    attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
610
    mamba_config: MambaConfig = get_field(VllmConfig, "mamba_config")
611
612
613
614
    kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config")
    enable_flashinfer_autotune: bool = get_field(
        KernelConfig, "enable_flashinfer_autotune"
    )
615
616
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
617

618
619
    profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")

620
621
    kv_transfer_config: KVTransferConfig | None = None
    kv_events_config: KVEventsConfig | None = None
622

623
    ec_transfer_config: ECTransferConfig | None = None
624
    reasoning_config: ReasoningConfig = get_field(VllmConfig, "reasoning_config")
625

626
627
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
628
629
630
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
631
    model_impl: str = ModelConfig.model_impl
632
    override_attention_dtype: str | None = ModelConfig.override_attention_dtype
633
    attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
634

635
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
636
637
638
    kv_cache_dtype_skip_layers: list[str] = get_field(
        CacheConfig, "kv_cache_dtype_skip_layers"
    )
639
640
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
641
    mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
642
    mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
643
644

    mamba_backend: MambaBackendEnum = MambaBackendEnum.TRITON
645
    enable_mamba_cache_stochastic_rounding: bool = (
646
        MambaConfig.enable_stochastic_rounding
647
    )
648
    mamba_cache_philox_rounds: int = MambaConfig.stochastic_rounding_philox_rounds
649

650
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
651

652
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
653
    pt_load_map_location: str | dict[str, str] = LoadConfig.pt_load_map_location
654

655
    logits_processors: list[str | type[LogitsProcessor]] | None = (
656
657
        ModelConfig.logits_processors
    )
658
659
    """Custom logitproc types"""

660
    async_scheduling: bool | None = SchedulerConfig.async_scheduling
661

662
663
    stream_interval: int = SchedulerConfig.stream_interval

664
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
665
    optimization_level: OptimizationLevel = VllmConfig.optimization_level
666
    performance_mode: PerformanceMode = VllmConfig.performance_mode
667

668
    kv_offloading_size: float | None = CacheConfig.kv_offloading_size
669
    kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
670
    tokens_only: bool = False
671

672
673
    shutdown_timeout: int = 0

674
675
676
677
    weight_transfer_config: WeightTransferConfig | None = get_field(
        VllmConfig,
        "weight_transfer_config",
    )
678

679
    fail_on_environ_validation: bool = False
680
    gdn_prefill_backend: Literal["flashinfer", "triton"] | None = None
681

682
    def __post_init__(self):
        """Normalise dict-valued config fields and prepare the environment.

        - Converts dict values of the nested config fields (compilation,
          attention, mamba, kernel, EPLB, weight transfer, IR op priority)
          into their config classes.
        - Resolves ``quantization_config`` via ``resolve_online_quant_config``.
        - Loads general plugins.
        - When ``HF_HUB_OFFLINE`` is set, replaces ``model`` (and
          ``tokenizer``, if given) with the result of ``get_model_path``.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(**self.compilation_config)
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)
        if isinstance(self.mamba_config, dict):
            self.mamba_config = MambaConfig(**self.mamba_config)
        if isinstance(self.kernel_config, dict):
            self.kernel_config = KernelConfig(**self.kernel_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)
        if isinstance(self.weight_transfer_config, dict):
            self.weight_transfer_config = WeightTransferConfig(
                **self.weight_transfer_config
            )
        if isinstance(self.ir_op_priority, dict):
            self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority)

        from vllm.config.quantization import resolve_online_quant_config

        self.quantization_config = resolve_online_quant_config(
            self.quantization, self.quantization_config
        )

        # Setup plugins
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        # When the HF hub is offline, replace the model and tokenizer ids
        # with local paths. Compare with `!=`: `is not` tests object identity,
        # which is unreliable for strings.
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
            if model_id != self.model:
                logger.info(
                    "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
                    model_id,
                    self.model,
                )
            if self.tokenizer is not None:
                tokenizer_id = self.tokenizer
                self.tokenizer = get_model_path(self.tokenizer, self.tokenizer_revision)
                if tokenizer_id != self.tokenizer:
                    logger.info(
                        "HF_HUB_OFFLINE is True, replace tokenizer_id [%s] "
                        "to tokenizer_path [%s]",
                        tokenizer_id,
                        self.tokenizer,
                    )
733
734

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one argument group per config class, then a handful of
        ungrouped engine-level flags, and returns the same ``parser`` so the
        call can be chained.

        Each ``_add_*_cli_args`` helper registers exactly one group; the call
        order below determines the group order shown by ``--help``.
        """
        EngineArgs._add_model_cli_args(parser)
        EngineArgs._add_load_cli_args(parser)
        EngineArgs._add_attention_cli_args(parser)
        EngineArgs._add_mamba_cli_args(parser)
        EngineArgs._add_structured_outputs_cli_args(parser)
        EngineArgs._add_parallel_cli_args(parser)
        EngineArgs._add_cache_cli_args(parser)
        EngineArgs._add_offload_cli_args(parser)
        EngineArgs._add_multimodal_cli_args(parser)
        EngineArgs._add_lora_cli_args(parser)
        EngineArgs._add_observability_cli_args(parser)
        EngineArgs._add_scheduler_cli_args(parser)
        EngineArgs._add_compilation_cli_args(parser)
        EngineArgs._add_kernel_cli_args(parser)
        EngineArgs._add_vllm_cli_args(parser)
        EngineArgs._add_misc_cli_args(parser)
        return parser

    @staticmethod
    def _add_model_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``ModelConfig`` argument group on ``parser``."""
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
        # `--model` is not registered when argv contains both "serve" and
        # "--help".
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
            model_group.add_argument("--model", **model_kwargs["model"])
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
        model_group.add_argument("--revision", **model_kwargs["revision"])
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
        model_group.add_argument(
            "--allow-deprecated-quantization",
            **model_kwargs["allow_deprecated_quantization"],
        )
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
        model_group.add_argument(
            "--enable-return-routed-experts",
            **model_kwargs["enable_return_routed_experts"],
        )
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
        # This one is a special case because it can be bool or str: a bare
        # `--hf-token` yields True (nargs="?" + const=True), while
        # `--hf-token VALUE` yields the string. TODO: Handle this in get_kwargs
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
        )
        model_group.add_argument(
            "--renderer-num-workers",
            **model_kwargs["renderer_num_workers"],
        )

    @staticmethod
    def _add_load_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``LoadConfig`` (model loading) argument group."""
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )

    @staticmethod
    def _add_attention_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``AttentionConfig`` argument group."""
        attention_kwargs = get_kwargs(AttentionConfig)
        attention_group = parser.add_argument_group(
            title="AttentionConfig",
            description=AttentionConfig.__doc__,
        )
        attention_group.add_argument(
            "--attention-backend", **attention_kwargs["backend"]
        )

    @staticmethod
    def _add_mamba_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``MambaConfig`` argument group."""
        mamba_kwargs = get_kwargs(MambaConfig)
        mamba_group = parser.add_argument_group(
            title="MambaConfig",
            description=MambaConfig.__doc__,
        )
        mamba_group.add_argument("--mamba-backend", **mamba_kwargs["backend"])
        mamba_group.add_argument(
            "--enable-mamba-cache-stochastic-rounding",
            **mamba_kwargs["enable_stochastic_rounding"],
        )
        mamba_group.add_argument(
            "--mamba-cache-philox-rounds",
            **mamba_kwargs["stochastic_rounding_philox_rounds"],
        )

    @staticmethod
    def _add_structured_outputs_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``StructuredOutputsConfig`` argument group."""
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
        )
        structured_outputs_group.add_argument(
            "--reasoning-parser",
            # Choices need to be validated after parsing to include plugins
            **structured_outputs_kwargs["reasoning_parser"],
        )
        structured_outputs_group.add_argument(
            "--reasoning-parser-plugin",
            **structured_outputs_kwargs["reasoning_parser_plugin"],
        )

    @staticmethod
    def _add_parallel_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``ParallelConfig`` argument group."""
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
            "--distributed-executor-backend",
            **parallel_kwargs["distributed_executor_backend"],
        )
        parallel_group.add_argument(
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
        parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"])
        parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
        parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
        parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
        parallel_group.add_argument(
            "--distributed-timeout-seconds",
            **parallel_kwargs["distributed_timeout_seconds"],
        )
        parallel_group.add_argument("--numa-bind", **parallel_kwargs["numa_bind"])
        parallel_group.add_argument(
            "--numa-bind-nodes", **parallel_kwargs["numa_bind_nodes"]
        )
        parallel_group.add_argument(
            "--numa-bind-cpus", **parallel_kwargs["numa_bind_cpus"]
        )
        parallel_group.add_argument(
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
        parallel_group.add_argument(
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
        parallel_group.add_argument(
            "--dcp-comm-backend",
            **parallel_kwargs["dcp_comm_backend"],
        )
        parallel_group.add_argument(
            "--dcp-kv-cache-interleave-size",
            **parallel_kwargs["dcp_kv_cache_interleave_size"],
        )
        parallel_group.add_argument(
            "--cp-kv-cache-interleave-size",
            **parallel_kwargs["cp_kv_cache_interleave_size"],
        )
        parallel_group.add_argument(
            "--prefill-context-parallel-size",
            "-pcp",
            **parallel_kwargs["prefill_context_parallel_size"],
        )
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
        # The following data-parallel flags are defined by hand rather than
        # from ParallelConfig field metadata.
        parallel_group.add_argument(
            "--data-parallel-rank",
            "-dpn",
            type=int,
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
        parallel_group.add_argument(
            "--data-parallel-start-rank",
            "-dpr",
            type=int,
            help="Starting data parallel rank for secondary nodes.",
        )
        parallel_group.add_argument(
            "--data-parallel-size-local",
            "-dpl",
            type=int,
            help="Number of data parallel replicas to run on this node.",
        )
        parallel_group.add_argument(
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
            type=int,
            help="Port for data parallel RPC communication.",
        )
        parallel_group.add_argument(
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
        parallel_group.add_argument(
            "--data-parallel-hybrid-lb",
            "-dph",
            **parallel_kwargs["data_parallel_hybrid_lb"],
        )
        parallel_group.add_argument(
            "--data-parallel-external-lb",
            "-dpe",
            **parallel_kwargs["data_parallel_external_lb"],
        )
        parallel_group.add_argument(
            "--enable-expert-parallel",
            "-ep",
            **parallel_kwargs["enable_expert_parallel"],
        )
        parallel_group.add_argument(
            "--enable-ep-weight-filter",
            **parallel_kwargs["enable_ep_weight_filter"],
        )
        parallel_group.add_argument(
            "--all2all-backend", **parallel_kwargs["all2all_backend"]
        )
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
        parallel_group.add_argument(
            "--ubatch-size",
            **parallel_kwargs["ubatch_size"],
        )
        parallel_group.add_argument(
            "--enable-elastic-ep", **parallel_kwargs["enable_elastic_ep"]
        )
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
        parallel_group.add_argument(
            "--disable-nccl-for-dp-synchronization",
            **parallel_kwargs["disable_nccl_for_dp_synchronization"],
        )
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
        parallel_group.add_argument(
            "--expert-placement-strategy",
            **parallel_kwargs["expert_placement_strategy"],
        )
        parallel_group.add_argument(
            "--max-parallel-loading-workers",
            **parallel_kwargs["max_parallel_loading_workers"],
        )
        parallel_group.add_argument(
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
        parallel_group.add_argument(
            "--disable-custom-all-reduce",
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )

    @staticmethod
    def _add_cache_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``CacheConfig`` (KV cache) argument group."""
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
        )
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        # Default overridden to None so "not passed on the CLI" can be told
        # apart from an explicit value.
        cache_group.add_argument(
            "--enable-prefix-caching",
            **{
                **cache_kwargs["enable_prefix_caching"],
                "default": None,
            },
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
        cache_group.add_argument(
            "--kv-cache-dtype-skip-layers", **cache_kwargs["kv_cache_dtype_skip_layers"]
        )
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-block-size", **cache_kwargs["mamba_block_size"]
        )
        cache_group.add_argument(
            "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
        )
        cache_group.add_argument(
            "--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
        )
        cache_group.add_argument(
            "--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
        )

    @staticmethod
    def _add_offload_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the model-weight ``OffloadConfig`` argument group.

        Flags come from OffloadConfig, UVAOffloadConfig and
        PrefetchOffloadConfig but are shown under a single group.
        """
        offload_kwargs = get_kwargs(OffloadConfig)
        uva_kwargs = get_kwargs(UVAOffloadConfig)
        prefetch_kwargs = get_kwargs(PrefetchOffloadConfig)
        offload_group = parser.add_argument_group(
            title="OffloadConfig",
            description=OffloadConfig.__doc__,
        )
        offload_group.add_argument(
            "--offload-backend", **offload_kwargs["offload_backend"]
        )
        offload_group.add_argument("--cpu-offload-gb", **uva_kwargs["cpu_offload_gb"])
        offload_group.add_argument(
            "--cpu-offload-params", **uva_kwargs["cpu_offload_params"]
        )
        offload_group.add_argument(
            "--offload-group-size",
            **prefetch_kwargs["offload_group_size"],
        )
        offload_group.add_argument(
            "--offload-num-in-group",
            **prefetch_kwargs["offload_num_in_group"],
        )
        offload_group.add_argument(
            "--offload-prefetch-step",
            **prefetch_kwargs["offload_prefetch_step"],
        )
        offload_group.add_argument(
            "--offload-params", **prefetch_kwargs["offload_params"]
        )

    @staticmethod
    def _add_multimodal_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``MultiModalConfig`` argument group."""
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
        multimodal_group.add_argument(
            "--language-model-only", **multimodal_kwargs["language_model_only"]
        )
        multimodal_group.add_argument(
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
        multimodal_group.add_argument(
            "--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
        )
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
        multimodal_group.add_argument(
            "--mm-encoder-only", **multimodal_kwargs["mm_encoder_only"]
        )
        multimodal_group.add_argument(
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
        multimodal_group.add_argument(
            "--mm-encoder-attn-backend",
            **multimodal_kwargs["mm_encoder_attn_backend"],
        )
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
        multimodal_group.add_argument(
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
        multimodal_group.add_argument(
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )
        multimodal_group.add_argument(
            "--mm-tensor-ipc", **multimodal_kwargs["mm_tensor_ipc"]
        )

    @staticmethod
    def _add_lora_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``LoRAConfig`` argument group."""
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
            "--enable-lora",
            action=argparse.BooleanOptionalAction,
            help="If True, enable handling of LoRA adapters.",
        )
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
        lora_group.add_argument(
            "--lora-dtype",
            **lora_kwargs["lora_dtype"],
        )
        lora_group.add_argument(
            "--enable-tower-connector-lora",
            **lora_kwargs["enable_tower_connector_lora"],
        )
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument(
            "--lora-target-modules", **lora_kwargs["target_modules"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
        lora_group.add_argument(
            "--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
        )

    @staticmethod
    def _add_observability_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``ObservabilityConfig`` argument group."""
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
        observability_group.add_argument(
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
        # TODO: generalise this special case.
        # Valid values are any single module or a comma-joined pair of two
        # distinct modules. The metavar lists only the single-module choices so
        # --help does not enumerate every pair. A fresh kwargs dict is built
        # instead of mutating the one returned by get_kwargs.
        traces_kwargs = observability_kwargs["collect_detailed_traces"]
        base_choices = traces_kwargs["choices"]
        observability_group.add_argument(
            "--collect-detailed-traces",
            **{
                **traces_kwargs,
                "metavar": f"{{{','.join(base_choices)}}}",
                "choices": [
                    *base_choices,
                    *(
                        ",".join(pair)
                        for pair in permutations(get_args(DetailedTraceModules), r=2)
                    ),
                ],
            },
        )
        observability_group.add_argument(
            "--kv-cache-metrics", **observability_kwargs["kv_cache_metrics"]
        )
        observability_group.add_argument(
            "--kv-cache-metrics-sample",
            **observability_kwargs["kv_cache_metrics_sample"],
        )
        observability_group.add_argument(
            "--cudagraph-metrics",
            **observability_kwargs["cudagraph_metrics"],
        )
        observability_group.add_argument(
            "--enable-layerwise-nvtx-tracing",
            **observability_kwargs["enable_layerwise_nvtx_tracing"],
        )
        observability_group.add_argument(
            "--enable-mfu-metrics",
            **observability_kwargs["enable_mfu_metrics"],
        )
        observability_group.add_argument(
            "--enable-logging-iteration-details",
            **observability_kwargs["enable_logging_iteration_details"],
        )

    @staticmethod
    def _add_scheduler_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``SchedulerConfig`` argument group."""
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        # Defaults for these three flags are overridden to None so that
        # "not passed on the CLI" can be distinguished from an explicit value.
        scheduler_group.add_argument(
            "--max-num-batched-tokens",
            **{
                **scheduler_kwargs["max_num_batched_tokens"],
                "default": None,
            },
        )
        scheduler_group.add_argument(
            "--max-num-seqs",
            **{
                **scheduler_kwargs["max_num_seqs"],
                "default": None,
            },
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
            **scheduler_kwargs["max_long_partial_prefills"],
        )
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
        scheduler_group.add_argument(
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
        scheduler_group.add_argument(
            "--enable-chunked-prefill",
            **{
                **scheduler_kwargs["enable_chunked_prefill"],
                "default": None,
            },
        )
        scheduler_group.add_argument(
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
        scheduler_group.add_argument(
            "--scheduler-reserve-full-isl",
            **scheduler_kwargs["scheduler_reserve_full_isl"],
        )
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
        scheduler_group.add_argument(
            "--stream-interval", **scheduler_kwargs["stream_interval"]
        )

    @staticmethod
    def _add_compilation_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``CompilationConfig`` argument group."""
        compilation_kwargs = get_kwargs(CompilationConfig)
        compilation_group = parser.add_argument_group(
            title="CompilationConfig",
            description=CompilationConfig.__doc__,
        )
        compilation_group.add_argument(
            "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
        )
        compilation_group.add_argument(
            "--max-cudagraph-capture-size",
            **compilation_kwargs["max_cudagraph_capture_size"],
        )

    @staticmethod
    def _add_kernel_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the ``KernelConfig`` argument group."""
        kernel_kwargs = get_kwargs(KernelConfig)
        kernel_group = parser.add_argument_group(
            title="KernelConfig",
            description=KernelConfig.__doc__,
        )
        kernel_group.add_argument("--ir-op-priority", **kernel_kwargs["ir_op_priority"])
        kernel_group.add_argument(
            "--enable-flashinfer-autotune",
            **kernel_kwargs["enable_flashinfer_autotune"],
        )
        # Normalize the value to lower case with "-" replaced by "_" before
        # choices validation; a fresh kwargs dict avoids mutating get_kwargs'
        # result.
        kernel_group.add_argument(
            "--moe-backend",
            **{
                **kernel_kwargs["moe_backend"],
                "type": lambda s: s.lower().replace("-", "_"),
            },
        )

    @staticmethod
    def _add_vllm_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register the top-level ``VllmConfig`` argument group."""
        vllm_kwargs = get_kwargs(VllmConfig)
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_group.add_argument(
            "--speculative-config",
            "-sc",
            **{
                **vllm_kwargs["speculative_config"],
                "type": optional_type(json.loads),
            },
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument(
            "--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
        )
        vllm_group.add_argument(
            "--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
        )
        vllm_group.add_argument(
            "--attention-config", "-ac", **vllm_kwargs["attention_config"]
        )
        vllm_group.add_argument("--reasoning-config", **vllm_kwargs["reasoning_config"])
        vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"])
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
        vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
        vllm_group.add_argument(
            "--optimization-level", **vllm_kwargs["optimization_level"]
        )
        vllm_group.add_argument("--performance-mode", **vllm_kwargs["performance_mode"])
        vllm_group.add_argument(
            "--weight-transfer-config", **vllm_kwargs["weight_transfer_config"]
        )

    @staticmethod
    def _add_misc_cli_args(parser: FlexibleArgumentParser) -> None:
        """Register engine-level flags that do not belong to a config group."""
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
        parser.add_argument(
            "--aggregate-engine-logging",
            action="store_true",
            help="Log aggregate rather than per-engine statistics "
            "when using data parallelism.",
        )
        parser.add_argument(
            "--fail-on-environ-validation",
            help="If set, the engine will raise an error if "
            "environment validation fails.",
            default=False,
            action=argparse.BooleanOptionalAction,
        )
        parser.add_argument(
            "--shutdown-timeout",
            type=int,
            default=0,
            help="Shutdown timeout in seconds. 0 = abort, >0 = wait.",
        )
        parser.add_argument(
            "--gdn-prefill-backend",
            dest="gdn_prefill_backend",
            choices=["flashinfer", "triton"],
            default=None,
            help="Select GDN prefill backend.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from a parsed argparse namespace.

        Every dataclass field of ``cls`` that also exists as an attribute on
        ``args`` is forwarded to the constructor. Fields missing from
        ``args`` keep their dataclass defaults. Extra attributes on ``args``
        are ignored.
        """
        init_kwargs = {}
        for field in dataclasses.fields(cls):
            name = field.name
            if not hasattr(args, name):
                continue
            init_kwargs[name] = getattr(args, name)
        return cls(**init_kwargs)

    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` described by these engine arguments.

        Side effect: when ``self.model`` is a GGUF file, both
        ``self.quantization`` and ``self.load_format`` are overwritten with
        ``"gguf"`` before the config is built. A warning about the global
        random seed is logged when ``VLLM_ENABLE_V1_MULTIPROCESSING`` is off.
        """
        # GGUF checkpoints need a dedicated model loader.
        if is_gguf(self.model):
            self.quantization = "gguf"
            self.load_format = "gguf"

        if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            # Without V1 multiprocessing the seed is applied in the launching
            # process, so warn the user about the shared random state.
            logger.warning(
                "The global random seed is set to %d. Since "
                "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
                "affect the random state of the Python process that "
                "launched vLLM.",
                self.seed,
            )

        # Every field below is forwarded unchanged from the engine args.
        model_config = ModelConfig(
            model=self.model,
            model_weights=self.model_weights,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            tokenizer=self.tokenizer,  # type: ignore[arg-type]
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_config=self.quantization_config,
            allow_deprecated_quantization=self.allow_deprecated_quantization,
            enforce_eager=self.enforce_eager,
            enable_return_routed_experts=self.enable_return_routed_experts,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            language_model_only=self.language_model_only,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            enable_mm_embeds=self.enable_mm_embeds,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_only=self.mm_encoder_only,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
            pooler_config=self.pooler_config,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            video_pruning_rate=self.video_pruning_rate,
            mm_tensor_ipc=self.mm_tensor_ipc,
            io_processor_plugin=self.io_processor_plugin,
            renderer_num_workers=self.renderer_num_workers,
        )
        return model_config

    def validate_tensorizer_args(self):
        """Mirror recognised tensorizer options into ``tensorizer_config``.

        Each top-level key of ``model_loader_extra_config`` that names a
        field in ``TensorizerConfig._fields`` is copied into the nested
        ``model_loader_extra_config["tensorizer_config"]`` dict. The
        top-level entries are left in place. The nested dict is only looked
        up when at least one key matches.
        """
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

        extra_config = self.model_loader_extra_config
        matching_keys = [
            name for name in extra_config if name in TensorizerConfig._fields
        ]
        for name in matching_keys:
            extra_config["tensorizer_config"][name] = extra_config[name]

    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` for these engine arguments.

        Side effects on ``self``:

        * ``load_format`` becomes ``"bitsandbytes"`` whenever
          ``quantization == "bitsandbytes"``.
        * When ``load_format == "tensorizer"``:
          - ``model_loader_extra_config`` is first replaced by its
            ``to_serializable()`` result if it has that method;
          - it then gets a fresh ``"tensorizer_config"`` dict whose
            ``tensorizer_dir`` is ``self.model``;
          - ``validate_tensorizer_args`` is called to copy matching top-level
            keys into that dict.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            extra_config = self.model_loader_extra_config
            if hasattr(extra_config, "to_serializable"):
                extra_config = extra_config.to_serializable()
                self.model_loader_extra_config = extra_config
            extra_config["tensorizer_config"] = {"tensorizer_dir": self.model}
            self.validate_tensorizer_args()

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )

    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
    ) -> SpeculativeConfig | None:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        This function utilizes `speculative_config` to create a
        SpeculativeConfig object. The `speculative_config` can either be
        provided as a JSON string input via CLI arguments or directly as a
        dictionary from the engine.

        Args:
            target_model_config: Model config of the target model. It is
                forwarded to ``SpeculativeConfig`` as ``target_model_config``.
            target_parallel_config: Parallel config of the target model. It is
                forwarded as ``target_parallel_config``.

        Returns:
            ``None`` when ``self.speculative_config`` is ``None``. Otherwise a
            ``SpeculativeConfig`` built from the user options plus the two
            target configs above. The target configs override any same-named
            user keys.

        The user-supplied ``self.speculative_config`` mapping is not modified.
        It may be owned by the caller (for example, passed directly to the
        engine), so the target configs are merged into a fresh kwargs dict.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        speculative_kwargs = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
        }
        return SpeculativeConfig(**speculative_kwargs)

    def create_engine_config(
        self,
1596
        usage_context: UsageContext | None = None,
1597
        headless: bool = False,
1598
1599
1600
1601
    ) -> VllmConfig:
        """
        Create the VllmConfig.

1602
        NOTE: If VllmConfig is incompatible, we raise an error.
1603
        """
1604
        current_platform.pre_register_and_update()
1605

1606
        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
1607

1608
1609
        envs.validate_environ(self.fail_on_environ_validation)

1610
1611
        # Check if the model is a speculator and override model/tokenizer/config
        # BEFORE creating ModelConfig, so the config is created with the target model
1612
1613
1614
1615
        # Skip speculator detection for cloud storage models (eg: S3, GCS) since
        # HuggingFace cannot load configs directly from S3 URLs. S3 models can still
        # use speculators with explicit --speculative-config.
        if not is_cloud_storage(self.model):
1616
1617
1618
1619
1620
1621
1622
            (self.model, self.tokenizer, self.speculative_config) = (
                maybe_override_with_speculators(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    revision=self.revision,
                    trust_remote_code=self.trust_remote_code,
                    vllm_speculative_config=self.speculative_config,
1623
                    hf_token=self.hf_token,
1624
1625
1626
                )
            )

1627
        model_config = self.create_model_config()
1628
        self.model = model_config.model
1629
        self.model_weights = model_config.model_weights
1630
1631
        self.tokenizer = model_config.tokenizer

1632
        self._check_feature_supported()
1633
        self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
1634
        self._set_default_reasoning_config_args()
1635
        sliding_window: int | None = None
1636
1637
1638
1639
1640
1641
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1642
1643
1644
1645
1646
        # Resolve "auto" kv_cache_dtype to actual value from model config
        resolved_cache_dtype = resolve_kv_cache_dtype_string(
            self.kv_cache_dtype, model_config
        )

1647
1648
1649
1650
        assert self.enable_prefix_caching is not None, (
            "enable_prefix_caching must be set by this point"
        )

1651
        cache_config = CacheConfig(
1652
            block_size=self.block_size,  # type: ignore[arg-type]
1653
            gpu_memory_utilization=self.gpu_memory_utilization,
1654
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1655
            cache_dtype=resolved_cache_dtype,  # type: ignore[arg-type]
1656
            is_attention_free=model_config.is_attention_free,
1657
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1658
            sliding_window=sliding_window,
1659
            enable_prefix_caching=self.enable_prefix_caching,
1660
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1661
            calculate_kv_scales=self.calculate_kv_scales,
1662
            kv_cache_dtype_skip_layers=self.kv_cache_dtype_skip_layers,
1663
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1664
1665
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1666
            mamba_block_size=self.mamba_block_size,
1667
            mamba_cache_mode=self.mamba_cache_mode,
1668
1669
            kv_offloading_size=self.kv_offloading_size,
            kv_offloading_backend=self.kv_offloading_backend,
1670
        )
1671

1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
        # TurboQuant: auto-skip first/last 2 layers (boundary protection).
        # These layers are most sensitive to quantization error.
        # Users can add extra layers via --kv-cache-dtype-skip-layers.
        if resolved_cache_dtype.startswith("turboquant_"):
            if model_config.is_hybrid:
                raise NotImplementedError(
                    "TurboQuant KV cache is not supported for hybrid "
                    "(attention + Mamba) models. Boundary layer protection "
                    "requires uniform attention layers."
                )
            from vllm.model_executor.layers.quantization.turboquant.config import (
                TurboQuantConfig,
            )

            num_layers = model_config.hf_text_config.num_hidden_layers
            boundary = TurboQuantConfig.get_boundary_skip_layers(num_layers)
            existing = set(cache_config.kv_cache_dtype_skip_layers)
            merged = sorted(existing | set(boundary), key=lambda x: int(x))
            cache_config.kv_cache_dtype_skip_layers = merged
            logger.info(
                "TQ: skipping layers %s for boundary protection (num_layers=%d)",
                merged,
                num_layers,
            )

1697
1698
1699
1700
1701
1702
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
1703

1704
            ray_runtime_env = ray.get_runtime_context().runtime_env
1705
1706
1707
1708
1709
1710
1711
            # Avoid logging sensitive environment variables
            sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {}
            if "env_vars" in sanitized_env:
                sanitized_env["env_vars"] = {
                    k: "***" for k in sanitized_env["env_vars"]
                }
            logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env)
1712

1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1724
        assert not headless or not self.data_parallel_hybrid_lb, (
1725
1726
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
        assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
            "data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
        )
        assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
            "nnodes > 1 is only supported with data_parallel_backend=mp"
        )
        inferred_data_parallel_rank = 0
        if self.nnodes > 1:
            world_size = (
                self.data_parallel_size
                * self.pipeline_parallel_size
                * self.tensor_parallel_size
            )
            world_size_within_dp = (
                self.pipeline_parallel_size * self.tensor_parallel_size
            )
            local_world_size = world_size // self.nnodes
            assert world_size % self.nnodes == 0, (
                f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
            )
            assert self.node_rank < self.nnodes, (
                f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
            )
            inferred_data_parallel_rank = (
                self.node_rank * local_world_size
            ) // world_size_within_dp
            if self.data_parallel_size > 1 and self.data_parallel_external_lb:
                self.data_parallel_rank = inferred_data_parallel_rank
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d for external lb",
                    self.data_parallel_rank,
                    self.node_rank,
                )
            elif self.data_parallel_size_local is None:
                # Infer data parallel size local for internal dplb:
                self.data_parallel_size_local = max(
                    local_world_size // world_size_within_dp, 1
                )
        data_parallel_external_lb = (
            self.data_parallel_external_lb or self.data_parallel_rank is not None
        )
1768
        # Local DP rank = 1, use pure-external LB.
1769
        if data_parallel_external_lb:
1770
            assert self.data_parallel_rank is not None, (
1771
                "data_parallel_rank or node_rank must be specified if "
1772
1773
                "data_parallel_external_lb is enable."
            )
1774
            assert self.data_parallel_size_local in (1, None), (
1775
1776
                "data_parallel_size_local must be 1 or None when data_parallel_rank "
                "is set"
1777
            )
1778
            data_parallel_size_local = 1
1779
1780
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1781
1782
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1783
1784
1785
1786
1787
1788
1789

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
1790
1791
1792
1793
1794
                logger.warning(
                    "data_parallel_hybrid_lb is not eligible when "
                    "data_parallel_size_local = 1, autoswitch to "
                    "data_parallel_external_lb."
                )
1795
1796
1797
1798
1799
1800
1801
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

1802
1803
1804
1805
1806
1807
1808
1809
1810
            self.data_parallel_rank = (
                self.data_parallel_start_rank or inferred_data_parallel_rank
            )
            if self.nnodes > 1:
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d",
                    self.data_parallel_rank,
                    self.node_rank,
                )
1811
        else:
1812
            assert not self.data_parallel_hybrid_lb, (
1813
1814
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )
1815

1816
1817
1818
1819
1820
1821
1822
1823
1824
            if self.data_parallel_backend == "ray" and (
                envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
            ):
                # Data parallel size defaults to 1 if DP ranks are spanning
                # multiple nodes
                data_parallel_size_local = 1
            else:
                # Otherwise local DP size defaults to global DP size if not set
                data_parallel_size_local = self.data_parallel_size
1825
1826
1827

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1828
1829
1830
1831
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
1832
1833
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1834
1835
1836
1837
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
1838
1839
                    self.data_parallel_backend,
                )
1840
1841
1842
                data_parallel_address = (
                    self.master_addr or ParallelConfig.data_parallel_master_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1843
1844
        else:
            data_parallel_address = self.data_parallel_address
1845
1846
1847

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
1848
        data_parallel_rpc_port = (
1849
            self.data_parallel_rpc_port
1850
1851
1852
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )
1853

1854
1855
1856
1857
        if self.tokens_only and not model_config.skip_tokenizer_init:
            model_config.skip_tokenizer_init = True
            logger.info("Skipping tokenizer initialization for tokens-only mode.")

1858
        parallel_config = ParallelConfig(
1859
1860
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1861
            prefill_context_parallel_size=self.prefill_context_parallel_size,
1862
            data_parallel_size=self.data_parallel_size,
1863
1864
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1865
            data_parallel_size_local=data_parallel_size_local,
1866
1867
1868
1869
            master_addr=self.master_addr,
            master_port=self.master_port,
            nnodes=self.nnodes,
            node_rank=self.node_rank,
1870
            distributed_timeout_seconds=self.distributed_timeout_seconds,
1871
1872
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1873
            data_parallel_backend=self.data_parallel_backend,
1874
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1875
            is_moe_model=model_config.is_moe,
1876
            enable_expert_parallel=self.enable_expert_parallel,
1877
            enable_ep_weight_filter=self.enable_ep_weight_filter,
1878
            all2all_backend=self.all2all_backend,
1879
            enable_elastic_ep=self.enable_elastic_ep,
1880
            enable_dbo=self.enable_dbo,
1881
            ubatch_size=self.ubatch_size,
1882
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
1883
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
1884
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
1885
            enable_eplb=self.enable_eplb,
1886
            eplb_config=self.eplb_config,
1887
            expert_placement_strategy=self.expert_placement_strategy,
1888
1889
1890
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1891
            ray_runtime_env=ray_runtime_env,
1892
            placement_group=placement_group,
1893
1894
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1895
            worker_extension_cls=self.worker_extension_cls,
1896
            decode_context_parallel_size=self.decode_context_parallel_size,
1897
            dcp_comm_backend=self.dcp_comm_backend,
1898
            dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
1899
            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
1900
1901
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
1902
1903
1904
            numa_bind=self.numa_bind,
            numa_bind_nodes=self.numa_bind_nodes,
            numa_bind_cpus=self.numa_bind_cpus,
1905
        )
1906

1907
        speculative_config = self.create_speculative_config(
1908
1909
1910
1911
            target_model_config=model_config,
            target_parallel_config=parallel_config,
        )

1912
1913
1914
1915
1916
1917
        self._set_default_max_num_seqs_and_batched_tokens_args(
            usage_context,
            model_config,
            parallel_config,
        )

1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
        assert self.max_num_batched_tokens is not None, (
            "max_num_batched_tokens must be set by this point"
        )
        assert self.max_num_seqs is not None, "max_num_seqs must be set by this point"
        assert self.enable_chunked_prefill is not None, (
            "enable_chunked_prefill must be set by this point"
        )
        assert model_config.max_model_len is not None, (
            "max_model_len must be set by this point"
        )
1928
        scheduler_config = SchedulerConfig(
1929
            runner_type=model_config.runner_type,
1930
1931
1932
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1933
            enable_chunked_prefill=self.enable_chunked_prefill,
1934
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1935
            is_multimodal_model=model_config.is_multimodal_model,
1936
            is_encoder_decoder=model_config.is_encoder_decoder,
1937
            policy=self.scheduling_policy,
1938
            scheduler_cls=self.scheduler_cls,
1939
1940
1941
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1942
            scheduler_reserve_full_isl=self.scheduler_reserve_full_isl,
1943
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
1944
            async_scheduling=self.async_scheduling,
1945
            stream_interval=self.stream_interval,
1946
        )
1947

1948
1949
1950
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_dtype=self.lora_dtype,
1961
                target_modules=self.lora_target_modules,
1962
                enable_tower_connector_lora=self.enable_tower_connector_lora,
1963
                specialize_active_lora=self.specialize_active_lora,
1964
1965
1966
1967
1968
1969
1970
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )
1971

1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
        if (
            lora_config is not None
            and speculative_config is not None
            and scheduler_config.max_num_batched_tokens
            < (
                scheduler_config.max_num_seqs
                * (speculative_config.num_speculative_tokens + 1)
            )
        ):
            raise ValueError(
                "Consider increasing max_num_batched_tokens or "
                "decreasing num_speculative_tokens"
            )

1986
1987
1988
1989
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1990
1991
1992
1993
1994
1995
1996
1997
        # Attention config overrides
        attention_config = copy.deepcopy(self.attention_config)
        if self.attention_backend is not None:
            if attention_config.backend is not None:
                raise ValueError(
                    "attention_backend and attention_config.backend "
                    "are mutually exclusive"
                )
1998
1999
2000
2001
            # Reuse the validator to handle "auto" and string-to-enum conversion
            attention_config.backend = AttentionConfig.validate_backend_before(
                self.attention_backend
            )
2002

2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
        # TurboQuant requires FlashAttention 2 — FA3 boundary layers assert
        # FlashAttentionImpl which fails with TurboQuantAttentionImpl.
        if resolved_cache_dtype.startswith("turboquant_") and (
            attention_config.flash_attn_version is None
            or attention_config.flash_attn_version >= 3
        ):
            logger.warning(
                "TurboQuant is not yet compatible with FlashAttention >= 3. "
                "Overriding flash_attn_version to 2. To silence this "
                "warning, pass --attention-config.flash_attn_version=2"
            )
            attention_config.flash_attn_version = 2

2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
        # Mamba config overrides
        mamba_config = copy.deepcopy(self.mamba_config)
        # Convert string to enum if needed (CLI parsing returns a string)
        if isinstance(self.mamba_backend, str):
            mamba_config.backend = MambaBackendEnum[self.mamba_backend.upper()]
        else:
            mamba_config.backend = self.mamba_backend
        if self.enable_mamba_cache_stochastic_rounding:
            mamba_config.enable_stochastic_rounding = (
                self.enable_mamba_cache_stochastic_rounding
            )
        if self.mamba_cache_philox_rounds:
            mamba_config.stochastic_rounding_philox_rounds = (
                self.mamba_cache_philox_rounds
            )

2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
        # Kernel config overrides
        kernel_config = copy.deepcopy(self.kernel_config)
        if self.enable_flashinfer_autotune is not None:
            if kernel_config.enable_flashinfer_autotune is not None:
                raise ValueError(
                    "enable_flashinfer_autotune and "
                    "kernel_config.enable_flashinfer_autotune "
                    "are mutually exclusive"
                )
            kernel_config.enable_flashinfer_autotune = self.enable_flashinfer_autotune
2042
2043
        if self.moe_backend != "auto":
            kernel_config.moe_backend = self.moe_backend
2044

2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
        # Transfer top-level ir_op_priority into KernelConfig.ir_op_priority
        for op_name, op_priority in asdict(self.ir_op_priority).items():
            # Empty means unset
            if not op_priority:
                continue

            # Priority cannot be set 2x for the same op
            if getattr(kernel_config.ir_op_priority, op_name):
                raise ValueError(
                    f"Op priority for {op_name} specified via both ir_op_priority "
                    f"and KernelConfig.ir_op_priority, only one allowed at a time."
                )

            # Set the attribute
            setattr(kernel_config.ir_op_priority, op_name, op_priority)

2061
        load_config = self.create_load_config()
2062

2063
2064
        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
2065
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser
2066

2067
2068
2069
2070
2071
        if self.reasoning_parser_plugin:
            self.structured_outputs_config.reasoning_parser_plugin = (
                self.reasoning_parser_plugin
            )

2072
        observability_config = ObservabilityConfig(
2073
            show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
2074
            otlp_traces_endpoint=self.otlp_traces_endpoint,
2075
            collect_detailed_traces=self.collect_detailed_traces,
2076
2077
            kv_cache_metrics=self.kv_cache_metrics,
            kv_cache_metrics_sample=self.kv_cache_metrics_sample,
2078
            cudagraph_metrics=self.cudagraph_metrics,
2079
            enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
2080
            enable_mfu_metrics=self.enable_mfu_metrics,
2081
            enable_mm_processor_stats=self.enable_mm_processor_stats,
2082
            enable_logging_iteration_details=self.enable_logging_iteration_details,
2083
        )
2084

2085
        # Compilation config overrides
2086
        compilation_config = copy.deepcopy(self.compilation_config)
2087
        if self.cudagraph_capture_sizes is not None:
2088
            if compilation_config.cudagraph_capture_sizes is not None:
2089
2090
2091
2092
                raise ValueError(
                    "cudagraph_capture_sizes and compilation_config."
                    "cudagraph_capture_sizes are mutually exclusive"
                )
2093
            compilation_config.cudagraph_capture_sizes = self.cudagraph_capture_sizes
2094
        if self.max_cudagraph_capture_size is not None:
2095
            if compilation_config.max_cudagraph_capture_size is not None:
2096
2097
2098
2099
                raise ValueError(
                    "max_cudagraph_capture_size and compilation_config."
                    "max_cudagraph_capture_size are mutually exclusive"
                )
2100
            compilation_config.max_cudagraph_capture_size = (
2101
2102
                self.max_cudagraph_capture_size
            )
2103
2104

        offload_config = OffloadConfig(
2105
            offload_backend=self.offload_backend,
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
            uva=UVAOffloadConfig(
                cpu_offload_gb=self.cpu_offload_gb,
                cpu_offload_params=self.cpu_offload_params,
            ),
            prefetch=PrefetchOffloadConfig(
                offload_group_size=self.offload_group_size,
                offload_num_in_group=self.offload_num_in_group,
                offload_prefetch_step=self.offload_prefetch_step,
                offload_params=self.offload_params,
            ),
        )

2118
2119
2120
        if self.gdn_prefill_backend is not None:
            self.additional_config["gdn_prefill_backend"] = self.gdn_prefill_backend

2121
        config = VllmConfig(
2122
2123
2124
2125
2126
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
2127
            load_config=load_config,
2128
            offload_config=offload_config,
2129
            attention_config=attention_config,
2130
            mamba_config=mamba_config,
2131
            kernel_config=kernel_config,
2132
2133
            lora_config=lora_config,
            speculative_config=speculative_config,
2134
            structured_outputs_config=self.structured_outputs_config,
2135
            observability_config=observability_config,
2136
            compilation_config=compilation_config,
2137
            kv_transfer_config=self.kv_transfer_config,
2138
            kv_events_config=self.kv_events_config,
2139
            ec_transfer_config=self.ec_transfer_config,
2140
            reasoning_config=self.reasoning_config,
2141
            profiler_config=self.profiler_config,
2142
            additional_config=self.additional_config,
2143
            optimization_level=self.optimization_level,
2144
            performance_mode=self.performance_mode,
2145
            weight_transfer_config=self.weight_transfer_config,
2146
            shutdown_timeout=self.shutdown_timeout,
2147
        )
2148

2149
2150
        return config

2151
    def _check_feature_supported(self):
        """Reject engine-argument combinations this engine does not support.

        Raises:
            NotImplementedError: Raised through ``_raise_unsupported_error``
                when either partial-prefill limit differs from its
                ``SchedulerConfig`` default (concurrent partial prefill), or
                when ``pipeline_parallel_size > 1`` is combined with an
                executor backend that neither exposes a truthy
                ``supports_pp`` attribute nor is one of the accepted
                backends listed below.
        """
        # Any deviation from the scheduler defaults counts as a request for
        # concurrent partial prefills, which is not supported.
        wants_partial_prefill = (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        )
        if wants_partial_prefill:
            _raise_unsupported_error(feature_name="Concurrent Partial Prefill")

        # Everything below only concerns pipeline parallelism.
        if self.pipeline_parallel_size <= 1:
            return

        backend = self.distributed_executor_backend
        # A backend object can declare pipeline-parallel support itself via a
        # truthy ``supports_pp`` attribute (a plain string name has none).
        if getattr(backend, "supports_pp", False):
            return

        # Otherwise only these backends are accepted; the first entry is the
        # class-level default of ``ParallelConfig.distributed_executor_backend``.
        accepted_backends = (
            ParallelConfig.distributed_executor_backend,
            "ray",
            "mp",
            "external_launcher",
        )
        if backend not in accepted_backends:
            _raise_unsupported_error(
                feature_name=(
                    "Pipeline Parallelism without Ray distributed "
                    "executor or multiprocessing executor or external "
                    "launcher"
                )
            )
2177

2178
2179
2180
2181
2182
2183
    @classmethod
    def get_batch_defaults(
        cls,
        world_size: int,
    ) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]:
        """Return platform-dependent batching defaults keyed by usage context.

        Args:
            world_size: Number of workers; only used to scale the CPU
                defaults.

        Returns:
            A ``(max_num_batched_tokens, max_num_seqs)`` pair of dicts, each
            with entries for ``UsageContext.LLM_CLASS`` and
            ``UsageContext.OPENAI_API_SERVER``.
        """
        from vllm.usage.usage_lib import UsageContext

        llm_ctx = UsageContext.LLM_CLASS
        api_ctx = UsageContext.OPENAI_API_SERVER

        batched_tokens: dict[UsageContext | None, int]
        num_seqs: dict[UsageContext | None, int]

        # Querying the device can fail when the platform importing vLLM is not
        # the one running it (e.g. scaling vLLM with Ray from a GPU-less node).
        # The fallback values select the smaller, non-H100/H200 defaults below.
        try:
            total_memory = current_platform.get_device_total_memory()
            name = current_platform.get_device_name().lower()
        except Exception:
            total_memory, name = 0, ""

        # Large-memory devices (e.g. H100, MI300x) get bigger budgets. A100 is
        # excluded on purpose: a large max_num_batched_tokens reduced its
        # throughput (see PR #17885).
        if total_memory >= 70 * GiB_bytes and "a100" not in name:
            batched_tokens = {llm_ctx: 16384, api_ctx: 8192}
            num_seqs = {llm_ctx: 1024, api_ctx: 1024}
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            batched_tokens = {llm_ctx: 8192, api_ctx: 2048}
            num_seqs = {llm_ctx: 256, api_ctx: 256}

        if current_platform.is_tpu():
            # Per-chip (LLM_CLASS, OPENAI_API_SERVER) token budgets; chips not
            # listed here keep the token budgets chosen above.
            tpu_budgets = {
                "V6E": (2048, 1024),
                "V5E": (1024, 512),
                "V5P": (512, 256),
            }
            chip_budget = tpu_budgets.get(current_platform.get_device_name())
            if chip_budget is not None:
                llm_tokens, api_tokens = chip_budget
                batched_tokens = {llm_ctx: llm_tokens, api_ctx: api_tokens}

        if current_platform.is_cpu():
            # CPU defaults scale linearly with the number of workers.
            batched_tokens = {llm_ctx: 4096 * world_size, api_ctx: 2048 * world_size}
            num_seqs = {llm_ctx: 256 * world_size, api_ctx: 128 * world_size}

        return batched_tokens, num_seqs

2262
2263
    def _set_default_chunked_prefill_and_prefix_caching_args(
        self, model_config: ModelConfig
    ) -> None:
        """Resolve ``enable_chunked_prefill`` and ``enable_prefix_caching``.

        A flag left as ``None`` takes the model's declared support
        (``model_config.is_chunked_prefill_supported`` /
        ``model_config.is_prefix_caching_supported``). An explicit user
        choice is kept, but a warning is logged when:

        - a "generate" model supports chunked prefill and it was disabled;
        - a "pooling" model does not support chunked prefill or prefix
          caching and the feature was enabled.

        Finally, both features are force-disabled on RISC-V CPUs, overriding
        any user or model-derived value.

        Args:
            model_config: Supplies the feature-support flags and runner type.
        """
        default_chunked_prefill = model_config.is_chunked_prefill_supported
        default_prefix_caching = model_config.is_prefix_caching_supported

        if self.enable_chunked_prefill is None:
            self.enable_chunked_prefill = default_chunked_prefill

            logger.debug(
                "%s chunked prefill by default",
                "Enabling" if default_chunked_prefill else "Disabling",
            )
        elif (
            model_config.runner_type == "generate"
            and not self.enable_chunked_prefill
            and default_chunked_prefill
        ):
            logger.warning_once(
                "This model does not officially support disabling chunked prefill. "
                "Disabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
            )
        elif (
            model_config.runner_type == "pooling"
            and self.enable_chunked_prefill
            and not default_chunked_prefill
        ):
            logger.warning_once(
                "This model does not officially support chunked prefill. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
            )

        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = default_prefix_caching

            logger.debug(
                "%s prefix caching by default",
                "Enabling" if default_prefix_caching else "Disabling",
            )
        elif (
            model_config.runner_type == "pooling"
            and self.enable_prefix_caching
            and not default_prefix_caching
        ):
            logger.warning_once(
                "This model does not officially support prefix caching. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
            )

        # Disable chunked prefill and prefix caching for RISC-V CPUs in V1,
        # regardless of what was requested or defaulted above.
        if (
            current_platform.is_cpu()
            and current_platform.get_cpu_architecture() == CpuArchEnum.RISCV
        ):
            # Adjacent literals are concatenated verbatim, so the trailing
            # space after "for" is required to avoid "forRISC-V".
            logger.info(
                "Chunked prefill is not supported for "
                "RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_chunked_prefill = False
            logger.info(
                "Prefix caching is not supported for "
                "RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_prefix_caching = False

2332
2333
2334
2335
2336
2337
2338
    def _set_default_reasoning_config_args(self):
        """Copy ``self.reasoning_parser`` into ``self.reasoning_config``.

        Does nothing when no reasoning parser is set (falsy). Otherwise a
        default ``ReasoningConfig()`` is created if none was supplied, and its
        ``reasoning_parser`` is overwritten with ``self.reasoning_parser``.
        """
        parser_name = self.reasoning_parser
        if parser_name:
            if self.reasoning_config is None:
                self.reasoning_config = ReasoningConfig()
            self.reasoning_config.reasoning_parser = parser_name

2339
    def _set_default_max_num_seqs_and_batched_tokens_args(
        self,
        usage_context: UsageContext | None,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ):
        """Fill in ``max_num_batched_tokens`` / ``max_num_seqs`` when unset.

        Only values left as ``None`` by the user are touched. The defaults
        come from ``get_batch_defaults`` for ``usage_context`` (or the
        ``SchedulerConfig`` fallbacks). Both values are doubled in
        ``"throughput"`` performance mode, then clamped against
        ``model_config.max_model_len`` and against each other.

        Args:
            usage_context: Selects the per-context defaults; ``None`` or an
                unknown context falls back to the ``SchedulerConfig`` defaults.
            model_config: Must have ``max_model_len`` resolved when
                ``max_num_batched_tokens`` is being defaulted.
            parallel_config: ``use_batched_dp_moe`` selects the dedicated
                batched-DP token default.
        """
        batched_token_defaults, num_seq_defaults = self.get_batch_defaults(
            self.pipeline_parallel_size * self.tensor_parallel_size
        )

        # Remember which limits the user left unset; only those are adjusted.
        tokens_defaulted = self.max_num_batched_tokens is None
        seqs_defaulted = self.max_num_seqs is None

        if tokens_defaulted:
            self.max_num_batched_tokens = (
                SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP
                if parallel_config.use_batched_dp_moe
                else batched_token_defaults.get(
                    usage_context,
                    SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
                )
            )

        if seqs_defaulted:
            self.max_num_seqs = num_seq_defaults.get(
                usage_context,
                SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
            )

        # Throughput mode doubles whichever limits were not set explicitly.
        throughput_mode = self.performance_mode == "throughput"
        if throughput_mode and tokens_defaulted:
            self.max_num_batched_tokens *= 2
        if throughput_mode and seqs_defaulted:
            self.max_num_seqs *= 2

        if tokens_defaulted:
            max_model_len = model_config.max_model_len
            assert max_model_len is not None, (
                "max_model_len must be set by this point"
            )
            token_budget = self.max_num_batched_tokens
            if not self.enable_chunked_prefill:
                # Without chunked prefill the budget never drops below
                # max_model_len; a short max_model_len keeps the default.
                token_budget = max(max_model_len, token_budget)
            # Never exceed what max_num_seqs full-length sequences could use;
            # some models (e.g. Whisper) have embeddings tied to max length.
            token_budget = min(self.max_num_seqs * max_model_len, token_budget)
            self.max_num_batched_tokens = token_budget

            logger.debug(
                "Defaulting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                usage_context.value if usage_context else None,
            )

        if seqs_defaulted:
            assert self.max_num_batched_tokens is not None  # For type checking
            # More sequences than tokens per step can never be scheduled.
            self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

            logger.debug(
                "Defaulting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                usage_context.value if usage_context else None,
            )
2412

2413

2414
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for the asynchronous vLLM engine.

    Extends :class:`EngineArgs` with options that only the async engine
    front end uses.
    """

    # Whether to log incoming request information; what gets logged depends
    # on the log level (see the ``--enable-log-requests`` help text).
    enable_log_requests: bool = False

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
        """Register the async-engine CLI options on ``parser``.

        Args:
            parser: Parser to extend.
            async_args_only: When True, register only the async-specific
                options. Otherwise the full ``EngineArgs`` option set is
                registered first.

        Returns:
            The parser carrying the registered options. This is the value
            returned by ``EngineArgs.add_cli_args`` when that is called.
        """
        # Load plugins before touching the parser: a plugin may, for example,
        # add a new quantization method to --quantization or a new device
        # to --device.
        load_general_plugins()

        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        log_requests_help = (
            "Enable logging request information, dependent on log level:\n"
            "- INFO: Request ID, parameters and LoRA request.\n"
            "- DEBUG: Prompt inputs (e.g: text, token IDs).\n"
            "You can set the minimum log level via `VLLM_LOGGING_LEVEL`."
        )
        parser.add_argument(
            "--enable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=AsyncEngineArgs.enable_log_requests,
            help=log_requests_help,
        )

        # Let the active platform register or adjust options on the parser.
        current_platform.pre_register_and_update(parser)
        return parser
2441
2442


2443
2444
2445
2446
2447
2448
def _raise_unsupported_error(feature_name: str):
    """Raise ``NotImplementedError`` stating that ``feature_name`` is unsupported.

    Args:
        feature_name: Human-readable feature name, interpolated twice into
            the error message.

    Raises:
        NotImplementedError: Always.
    """
    name = feature_name
    raise NotImplementedError(
        f"{name} is not supported. "
        f"We recommend to remove {name} from your config."
    )