arg_utils.py 100 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import copy
6
import dataclasses
7
import functools
8
import json
9
import sys
10
from collections.abc import Callable
11
from dataclasses import MISSING, asdict, dataclass, fields, is_dataclass
12
from itertools import permutations
13
from types import UnionType
14
15
16
17
18
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Literal,
19
    TypeAlias,
20
21
22
23
24
25
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
26

27
import huggingface_hub
28
import regex as re
29
import torch
30
from pydantic import TypeAdapter, ValidationError
31
from pydantic.fields import FieldInfo
32
from typing_extensions import TypeIs
33

34
import vllm.envs as envs
35
from vllm.config import (
36
    AttentionConfig,
37
38
39
40
    CacheConfig,
    CompilationConfig,
    ConfigType,
    DeviceConfig,
41
    ECTransferConfig,
42
    EPLBConfig,
43
    KernelConfig,
44
45
46
47
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LoRAConfig,
48
    MambaConfig,
49
    ModelConfig,
50
    MultiModalConfig,
51
    ObservabilityConfig,
52
    OffloadConfig,
53
54
    ParallelConfig,
    PoolerConfig,
55
    PrefetchOffloadConfig,
56
    ProfilerConfig,
57
    ReasoningConfig,
58
59
60
    SchedulerConfig,
    SpeculativeConfig,
    StructuredOutputsConfig,
61
    UVAOffloadConfig,
62
    VllmConfig,
63
    WeightTransferConfig,
64
65
    get_attr_docs,
)
66
67
68
from vllm.config.cache import (
    CacheDType,
    KVOffloadingBackend,
69
    MambaCacheMode,
70
71
72
    MambaDType,
    PrefixCachingHashAlgo,
)
73
from vllm.config.device import Device
74
from vllm.config.kernel import IrOpPriorityConfig, MoEBackend
75
from vllm.config.lora import MaxLoRARanks
76
from vllm.config.mamba import MambaBackendEnum
77
78
79
80
81
82
from vllm.config.model import (
    ConvertOption,
    HfOverrides,
    LogprobsMode,
    ModelDType,
    RunnerOption,
83
    TokenizerMode,
84
)
85
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MMTensorIPC
86
from vllm.config.observability import DetailedTraceModules
87
88
89
from vllm.config.parallel import (
    All2AllBackend,
    DataParallelBackend,
90
    DCPCommBackend,
91
92
93
    DistributedExecutorBackend,
    ExpertPlacementStrategy,
)
94
from vllm.config.scheduler import SchedulerPolicy
95
from vllm.config.utils import get_field
96
from vllm.config.vllm import OptimizationLevel, PerformanceMode
97
from vllm.logger import init_logger, suppress_logging
98
from vllm.platforms import CpuArchEnum, current_platform
99
from vllm.plugins import load_general_plugins
100
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
101
102
103
104
from vllm.transformers_utils.config import (
    is_interleaved,
    maybe_override_with_speculators,
)
105
from vllm.transformers_utils.gguf_utils import is_gguf
106
from vllm.transformers_utils.repo_utils import get_model_path
107
from vllm.transformers_utils.utils import is_cloud_storage
108
from vllm.utils.argparse_utils import FlexibleArgumentParser
109
from vllm.utils.mem_constants import GiB_bytes
110
from vllm.utils.network_utils import get_ip
111
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
112
from vllm.v1.attention.backends.registry import AttentionBackendEnum
113
from vllm.v1.sample.logits_processor import LogitsProcessor
114
from vllm.version import __version__ as VLLM_VERSION
115

116
if TYPE_CHECKING:
117
    from vllm.config.quantization import OnlineQuantizationConfigArgs
118
    from vllm.model_executor.layers.quantization import QuantizationMethods
119
    from vllm.model_executor.model_loader import LoadFormats
120
    from vllm.usage.usage_lib import UsageContext
121
    from vllm.v1.executor import Executor
122
else:
123
    Executor = Any
124
    QuantizationMethods = Any
125
    LoadFormats = Any
126
127
    UsageContext = Any

128

129
130
logger = init_logger(__name__)

131
132
# object is used to allow for special typing forms
T = TypeVar("T")
133
134
TypeHint: TypeAlias = type[Any] | object
TypeHintT: TypeAlias = type[T] | object
135

136

137
138
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap a converter so that failures are reported as argparse errors.

    Args:
        return_type: Callable that converts a command-line string into the
            desired value (e.g. `int` or `json.loads`).

    Returns:
        A callable suitable for argparse's `type=` that calls `return_type`
        on the raw string.

    Raises:
        argparse.ArgumentTypeError: Raised by the returned callable when
            `return_type` raises `ValueError` for the given string.
    """

    def _parse_type(val: str) -> T:
        try:
            return return_type(val)
        except ValueError as e:
            # Chain the original error so the root cause stays visible.
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}."
            ) from e

    return _parse_type


149
150
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], T | None]:
    """Wrap a converter so that an empty string or `"None"` maps to `None`.

    Args:
        return_type: Callable that converts a command-line string into the
            desired value.

    Returns:
        A callable suitable for argparse's `type=`. It returns `None` for
        `""` and `"None"`; every other string is converted through
        `parse_type(return_type)`.
    """

    def _optional_type(val: str) -> T | None:
        if val == "" or val == "None":
            return None
        return parse_type(return_type)(val)

    return _optional_type
156
157


158
def union_dict_and_str(val: str) -> str | dict[str, str] | None:
159
    if not re.match(r"(?s)^\s*{.*}\s*$", val):
160
        return str(val)
161
    return optional_type(json.loads)(val)
162
163


164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Tell whether `type_hint` is `type` itself or a parameterised form of it."""
    if type_hint is type:
        return True
    # Parameterised hints such as `list[int]` expose the bare type as origin.
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Tell whether any hint in `type_hints` matches `type` (see `is_type`)."""
    for candidate in type_hints:
        if is_type(candidate, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return a hint from `type_hints` that matches `type`, or `None` if none does."""
    for candidate in type_hints:
        if is_type(candidate, type):
            return candidate
    return None


179
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Get the `type` and `choices` from a `Literal` type hint in `type_hints`.

    If `type_hints` also contains `str`, we use `metavar` instead of `choices`.

    Args:
        type_hints: Set of type hints, one of which is a `Literal[...]`.

    Returns:
        A dict with `"type"` set to the Python type shared by the literal
        options and either `"choices"` or `"metavar"` set to the sorted
        options.

    Raises:
        ValueError: If the literal options are not all of the same type.
    """
    type_hint = get_type(type_hints, Literal)
    options = get_args(type_hint)
    option_type = type(options[0])
    if not all(isinstance(option, option_type) for option in options):
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}"
        )
    kwarg = "metavar" if contains_type(type_hints, str) else "choices"
    return {"type": option_type, kwarg: sorted(options)}
194
195


196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def collection_to_kwargs(type_hints: set[TypeHint], type: TypeHint) -> dict[str, Any]:
    """Derive argparse `type` and `nargs` from a collection hint in `type_hints`.

    `type` is the bare collection class to look for (e.g. `list` or `tuple`).
    All element types other than `...` must be identical. A union element
    type must include `str`, in which case elements are parsed as `str`.
    """
    types = get_args(get_type(type_hints, type))
    elem_type = types[0]

    # Ellipsis (as in `tuple[int, ...]`) is ignored when comparing elements.
    assert all(t is elem_type for t in types if t is not Ellipsis), (
        f"All non-Ellipsis elements must be of the same type. Got {types}."
    )

    # `Union` covers Union[X, Y]; `UnionType` covers X | Y.
    elem_origin = get_origin(elem_type)
    if elem_origin in {Union, UnionType}:
        assert str in get_args(elem_type), (
            "If element can have multiple types, one must be 'str' "
            f"(i.e. 'list[int | str]'). Got {elem_type}."
        )
        elem_type = str

    # Only a tuple without Ellipsis has a fixed number of arguments.
    fixed_length = type is tuple and Ellipsis not in types
    nargs = len(types) if fixed_length else "+"
    return {"type": elem_type, "nargs": nargs}


221
222
223
224
225
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True when `type_hint` is defined outside the `builtins` module."""
    defining_module = type_hint.__module__
    return defining_module != "builtins"


226
227
228
229
230
231
232
233
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints."""
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
234
235
    elif origin in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
236
237
238
239
240
241
242
243
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


244
# True when help text is going to be rendered: either some command-line
# argument contains "--help", or the process was started through mkdocs.
NEEDS_HELP = (
    any("--help" in arg for arg in sys.argv)  # vllm SUBCOMMAND --help
    or (argv0 := sys.argv[0]).endswith("mkdocs")  # mkdocs SUBCOMMAND
    or argv0.endswith("mkdocs/__main__.py")  # python -m mkdocs SUBCOMMAND
)


251
252
253
254
255
256
257
258
def _maybe_add_docs_url(cls: Any) -> str:
    """Generate API docs URL for a vllm config class."""
    if not cls.__module__.startswith("vllm.config"):
        return ""
    version = f"v{VLLM_VERSION}" if "dev" not in VLLM_VERSION else "latest"
    return f"\n\nAPI docs: https://docs.vllm.ai/en/{version}/api/vllm/config/#vllm.config.{cls.__name__}"


259
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Build the argparse keyword arguments for every field of `cls`.

    Args:
        cls: A config dataclass whose fields are turned into CLI arguments.

    Returns:
        A mapping from field name to a kwargs dict for `add_argument`. Each
        entry has `"default"` and `"help"`; depending on the field's type
        hints it may also carry `"type"`, `"action"`, `"nargs"` and
        `"choices"` or `"metavar"`.

    Raises:
        ValueError: If a field's type hints match none of the supported
            categories.

    The result is memoised by `functools.lru_cache`, so the returned dict is
    shared between calls and must not be mutated by callers.
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}
    kwargs = {}
    for field in fields(cls):
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, we can use the model_validate_json
        generator = (th for th in type_hints if is_dataclass(th))
        dataclass_cls = next(generator, None)

        # Get the default value of the field
        # NOTE(review): a field with neither a default nor a default_factory
        # leaves `default` untouched here (unbound on the first field, stale
        # from the previous field otherwise) - confirm that every config
        # field declares a default.
        if field.default is not MISSING:
            default = field.default
            # Handle pydantic.Field defaults
            if isinstance(default, FieldInfo):
                if default.default_factory is None:
                    default = default.default
                else:
                    # VllmConfig's Fields have default_factory set to config classes.
                    # These could emit logs on init, which would be confusing.
                    with suppress_logging():
                        default = default.default_factory()  # type: ignore[call-arg]
        elif field.default_factory is not MISSING:
            default = field.default_factory()

        # Get the help text for the field
        name = field.name
        help = cls_docs.get(name, "").strip()
        # Escape % for argparse
        help = help.replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        kwargs[name] = {"default": default, "help": help}

        # Set other kwargs based on the type hints
        json_tip = (
            "Should either be a valid JSON string or JSON keys passed individually."
        )
        if dataclass_cls is not None:

            # `cls=dataclass_cls` binds the current class at definition time,
            # so each field's parser keeps its own dataclass.
            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            kwargs[name]["type"] = parse_dataclass
            kwargs[name]["help"] += _maybe_add_docs_url(dataclass_cls)
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            kwargs[name]["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            kwargs[name].update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            kwargs[name].update(collection_to_kwargs(type_hints, tuple))
        elif contains_type(type_hints, list):
            kwargs[name].update(collection_to_kwargs(type_hints, list))
        elif contains_type(type_hints, set):
            kwargs[name].update(collection_to_kwargs(type_hints, set))
        elif contains_type(type_hints, int):
            if name == "max_model_len":
                kwargs[name]["type"] = human_readable_int_or_auto
                kwargs[name]["help"] += f"\n\n{human_readable_int_or_auto.__doc__}"
            elif name in ("max_num_batched_tokens", "kv_cache_memory_bytes"):
                kwargs[name]["type"] = human_readable_int
                kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
            else:
                kwargs[name]["type"] = int
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
            kwargs[name]["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            kwargs[name]["type"] = parse_type(json.loads)
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
            kwargs[name]["type"] = str
        else:
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(kwargs[name].get("type")) is Literal:
            kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            kwargs[name]["type"] = optional_type(kwargs[name]["type"])
            if kwargs[name].get("choices"):
                kwargs[name]["choices"].append("None")
    return kwargs
360
361


362
def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Return argparse kwargs for the given Config dataclass.

    If `--help` or `mkdocs` are not present in the command line command, the
    attribute documentation will not be included in the help output.

    The heavy computation is cached via functools.lru_cache, and a deep copy
    is returned so callers can mutate the dictionary without affecting the
    cached version.
    """
    return copy.deepcopy(_compute_kwargs(cls))


375
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
376
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
377
    """Arguments for vLLM engine."""
378

379
    model: str = ModelConfig.model
380
    enable_return_routed_experts: bool = ModelConfig.enable_return_routed_experts
381
    model_weights: str = ModelConfig.model_weights
382
    served_model_name: str | list[str] | None = ModelConfig.served_model_name
383
    tokenizer: str | None = ModelConfig.tokenizer
384
    hf_config_path: str | None = ModelConfig.hf_config_path
385
386
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
387
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
388
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
389
    tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
390
    trust_remote_code: bool = ModelConfig.trust_remote_code
391
392
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
    allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains
393
    download_dir: str | None = LoadConfig.download_dir
394
    safetensors_load_strategy: str | None = LoadConfig.safetensors_load_strategy
395
    load_format: str | LoadFormats = LoadConfig.load_format
396
397
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
398
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
399
    seed: int = ModelConfig.seed
400
    max_model_len: int = ModelConfig.max_model_len
401
402
403
404
405
406
    cudagraph_capture_sizes: list[int] | None = (
        CompilationConfig.cudagraph_capture_sizes
    )
    max_cudagraph_capture_size: int | None = get_field(
        CompilationConfig, "max_cudagraph_capture_size"
    )
407
    ir_op_priority: IrOpPriorityConfig = get_field(KernelConfig, "ir_op_priority")
408
409
410
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
411
    distributed_executor_backend: (
412
        str | DistributedExecutorBackend | type[Executor] | None
413
    ) = ParallelConfig.distributed_executor_backend
414
    # number of P/D disaggregation (or other disaggregation) workers
415
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
416
417
418
419
    master_addr: str = ParallelConfig.master_addr
    master_port: int = ParallelConfig.master_port
    nnodes: int = ParallelConfig.nnodes
    node_rank: int = ParallelConfig.node_rank
420
    distributed_timeout_seconds: int | None = ParallelConfig.distributed_timeout_seconds
421
422
423
    numa_bind: bool = ParallelConfig.numa_bind
    numa_bind_nodes: list[int] | None = ParallelConfig.numa_bind_nodes
    numa_bind_cpus: list[str] | None = ParallelConfig.numa_bind_cpus
424
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
425
    prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
426
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
427
    dcp_comm_backend: DCPCommBackend = ParallelConfig.dcp_comm_backend
428
    dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
429
    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
430
    data_parallel_size: int = ParallelConfig.data_parallel_size
431
432
433
434
435
    data_parallel_rank: int | None = None
    data_parallel_start_rank: int | None = None
    data_parallel_size_local: int | None = None
    data_parallel_address: str | None = None
    data_parallel_rpc_port: int | None = None
436
    data_parallel_hybrid_lb: bool = False
437
    data_parallel_external_lb: bool = False
438
    data_parallel_backend: DataParallelBackend = ParallelConfig.data_parallel_backend
439
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
440
    enable_ep_weight_filter: bool = ParallelConfig.enable_ep_weight_filter
441
    moe_backend: MoEBackend = KernelConfig.moe_backend
442
    all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
443
    enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep
444
    enable_dbo: bool = ParallelConfig.enable_dbo
445
    ubatch_size: int = ParallelConfig.ubatch_size
446
447
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
448
    disable_nccl_for_dp_synchronization: bool | None = (
449
450
        ParallelConfig.disable_nccl_for_dp_synchronization
    )
451
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
452
    enable_eplb: bool = ParallelConfig.enable_eplb
453
    expert_placement_strategy: ExpertPlacementStrategy = (
454
        ParallelConfig.expert_placement_strategy
455
    )
456
457
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
458
    max_parallel_loading_workers: int | None = (
459
460
        ParallelConfig.max_parallel_loading_workers
    )
461
    block_size: int | None = None
462
    enable_prefix_caching: bool | None = None
463
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
464
        CacheConfig.prefix_caching_hash_algo
465
    )
466
467
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
468
469
470
471
472
473
474
    offload_backend: str = OffloadConfig.offload_backend
    cpu_offload_gb: float = UVAOffloadConfig.cpu_offload_gb
    cpu_offload_params: set[str] = get_field(UVAOffloadConfig, "cpu_offload_params")
    offload_group_size: int = PrefetchOffloadConfig.offload_group_size
    offload_num_in_group: int = PrefetchOffloadConfig.offload_num_in_group
    offload_prefetch_step: int = PrefetchOffloadConfig.offload_prefetch_step
    offload_params: set[str] = get_field(PrefetchOffloadConfig, "offload_params")
475
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
476
    kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
477
    max_num_batched_tokens: int | None = None
478
479
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
480
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
481
    max_num_seqs: int | None = None
482
    max_logprobs: int = ModelConfig.max_logprobs
483
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
484
    disable_log_stats: bool = False
485
    aggregate_engine_logging: bool = False
486
487
488
    revision: str | None = ModelConfig.revision
    code_revision: str | None = ModelConfig.code_revision
    hf_token: bool | str | None = ModelConfig.hf_token
489
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
490
    tokenizer_revision: str | None = ModelConfig.tokenizer_revision
491
    quantization: QuantizationMethods | str | None = ModelConfig.quantization
492
    quantization_config: "dict[str, Any] | OnlineQuantizationConfigArgs | None" = None
493
    allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
494
    enforce_eager: bool = ModelConfig.enforce_eager
495
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
496
    language_model_only: bool = MultiModalConfig.language_model_only
497
    limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
498
499
        MultiModalConfig, "limit_per_prompt"
    )
500
    enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
501
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
502
503
504
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
505
    mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
506
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
507
    mm_processor_cache_type: MMCacheType | None = (
508
        MultiModalConfig.mm_processor_cache_type
509
510
    )
    mm_shm_cache_max_object_size_mb: int = (
511
        MultiModalConfig.mm_shm_cache_max_object_size_mb
512
    )
513
    mm_encoder_only: bool = MultiModalConfig.mm_encoder_only
514
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
515
    mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
516
517
        MultiModalConfig.mm_encoder_attn_backend
    )
518
    io_processor_plugin: str | None = None
519
    renderer_num_workers: int = 1
520
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
521
    video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate
522
    mm_tensor_ipc: MMTensorIPC = MultiModalConfig.mm_tensor_ipc
523
    # LoRA fields
524
    enable_lora: bool = False
525
    max_loras: int = LoRAConfig.max_loras
526
    max_lora_rank: MaxLoRARanks = LoRAConfig.max_lora_rank
527
    default_mm_loras: dict[str, str] | None = LoRAConfig.default_mm_loras
528
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
529
530
    max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
    lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
531
    lora_target_modules: list[str] | None = LoRAConfig.target_modules
532
    enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
533
    specialize_active_lora: bool = LoRAConfig.specialize_active_lora
534

535
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
536
    num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
537
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
538
    ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
539

540
    enable_chunked_prefill: bool | None = None
541
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
542

543
544
    scheduler_reserve_full_isl: bool = SchedulerConfig.scheduler_reserve_full_isl

545
    disable_hybrid_kv_cache_manager: bool | None = (
546
547
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
548

549
    structured_outputs_config: StructuredOutputsConfig = get_field(
550
551
        VllmConfig, "structured_outputs_config"
    )
552
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
553
    reasoning_parser_plugin: str | None = None
554

555
    speculative_config: dict[str, Any] | None = None
556

557
    show_hidden_metrics_for_version: str | None = (
558
        ObservabilityConfig.show_hidden_metrics_for_version
559
    )
560
561
    otlp_traces_endpoint: str | None = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: list[DetailedTraceModules] | None = (
562
        ObservabilityConfig.collect_detailed_traces
563
    )
564
565
566
567
    kv_cache_metrics: bool = ObservabilityConfig.kv_cache_metrics
    kv_cache_metrics_sample: float = get_field(
        ObservabilityConfig, "kv_cache_metrics_sample"
    )
568
    cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
569
570
571
    enable_layerwise_nvtx_tracing: bool = (
        ObservabilityConfig.enable_layerwise_nvtx_tracing
    )
572
    enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
573
574
575
    enable_logging_iteration_details: bool = (
        ObservabilityConfig.enable_logging_iteration_details
    )
576
    enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
577
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
578
    scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
579

580
    pooler_config: PoolerConfig | None = ModelConfig.pooler_config
581
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
582
    attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
583
    mamba_config: MambaConfig = get_field(VllmConfig, "mamba_config")
584
585
586
587
    kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config")
    enable_flashinfer_autotune: bool = get_field(
        KernelConfig, "enable_flashinfer_autotune"
    )
588
589
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
590

591
592
    profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")

593
594
    kv_transfer_config: KVTransferConfig | None = None
    kv_events_config: KVEventsConfig | None = None
595

596
    ec_transfer_config: ECTransferConfig | None = None
597
    reasoning_config: ReasoningConfig = get_field(VllmConfig, "reasoning_config")
598

599
600
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
601
602
603
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
604
    model_impl: str = ModelConfig.model_impl
605
    override_attention_dtype: str | None = ModelConfig.override_attention_dtype
606
    attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
607

608
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
609
610
611
    kv_cache_dtype_skip_layers: list[str] = get_field(
        CacheConfig, "kv_cache_dtype_skip_layers"
    )
612
613
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
614
    mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
615
    mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
616
617

    mamba_backend: MambaBackendEnum = MambaBackendEnum.TRITON
618
    enable_mamba_cache_stochastic_rounding: bool = (
619
        MambaConfig.enable_stochastic_rounding
620
    )
621
    mamba_cache_philox_rounds: int = MambaConfig.stochastic_rounding_philox_rounds
622

623
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
624

625
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
626
    pt_load_map_location: str | dict[str, str] = LoadConfig.pt_load_map_location
627

628
    logits_processors: list[str | type[LogitsProcessor]] | None = (
629
630
        ModelConfig.logits_processors
    )
631
632
    """Custom logitproc types"""

633
    async_scheduling: bool | None = SchedulerConfig.async_scheduling
634

635
636
    stream_interval: int = SchedulerConfig.stream_interval

637
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
638
    optimization_level: OptimizationLevel = VllmConfig.optimization_level
639
    performance_mode: PerformanceMode = VllmConfig.performance_mode
640

641
    kv_offloading_size: float | None = CacheConfig.kv_offloading_size
642
    kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
643
    tokens_only: bool = False
644

645
646
    shutdown_timeout: int = 0

647
648
649
650
    weight_transfer_config: WeightTransferConfig | None = get_field(
        VllmConfig,
        "weight_transfer_config",
    )
651

652
    fail_on_environ_validation: bool = False
653
    gdn_prefill_backend: Literal["flashinfer", "triton"] | None = None
654

655
    def __post_init__(self):
        """Normalise fields after dataclass initialisation.

        Converts dict-valued config fields into their config objects,
        resolves the quantization config, loads the general plugins and,
        when `HF_HUB_OFFLINE` is set, replaces the model and tokenizer ids
        with the paths returned by `get_model_path`.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(**self.compilation_config)
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)
        if isinstance(self.mamba_config, dict):
            self.mamba_config = MambaConfig(**self.mamba_config)
        if isinstance(self.kernel_config, dict):
            self.kernel_config = KernelConfig(**self.kernel_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)
        if isinstance(self.weight_transfer_config, dict):
            self.weight_transfer_config = WeightTransferConfig(
                **self.weight_transfer_config
            )
        if isinstance(self.ir_op_priority, dict):
            self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority)

        from vllm.config.quantization import resolve_online_quant_config

        self.quantization_config = resolve_online_quant_config(
            self.quantization, self.quantization_config
        )

        # Setup plugins
        from vllm.plugins import load_general_plugins

        load_general_plugins()
        # When HF is used offline, replace the model and tokenizer ids with
        # the local model path.
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
            # Compare by value: an equal but distinct string object must not
            # be reported as a replacement.
            if model_id != self.model:
                logger.info(
                    "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
                    model_id,
                    self.model,
                )
            if self.tokenizer is not None:
                tokenizer_id = self.tokenizer
                self.tokenizer = get_model_path(self.tokenizer, self.tokenizer_revision)
                if tokenizer_id != self.tokenizer:
                    logger.info(
                        "HF_HUB_OFFLINE is True, replace tokenizer_id [%s] "
                        "to tokenizer_path [%s]",
                        tokenizer_id,
                        self.tokenizer,
                    )
706
707

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Register the vLLM engine's command-line options on ``parser``.

        Most options take their argparse keyword arguments from
        ``get_kwargs(<ConfigClass>)`` and are added to an argument group
        whose title is that config class's name; a handful are declared by
        hand. Options are registered in the order they are listed below.

        Args:
            parser: Parser to extend in place.

        Returns:
            The same ``parser`` object, with the engine arguments added.
        """

        def register(group, *specs):
            # Each spec is ``(*flags, kwargs)``: one or more option strings
            # followed by the keyword arguments for ``add_argument``.
            for *flags, kwargs in specs:
                group.add_argument(*flags, **kwargs)

        # Model arguments
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
        # `--model` is left out when the command line contains both `serve`
        # and `--help`.
        cli_tokens = sys.argv[1:]
        if not ("serve" in cli_tokens and "--help" in cli_tokens):
            model_group.add_argument("--model", **model_kwargs["model"])
        register(
            model_group,
            ("--runner", model_kwargs["runner"]),
            ("--convert", model_kwargs["convert"]),
            ("--tokenizer", model_kwargs["tokenizer"]),
            ("--tokenizer-mode", model_kwargs["tokenizer_mode"]),
            ("--trust-remote-code", model_kwargs["trust_remote_code"]),
            ("--dtype", model_kwargs["dtype"]),
            ("--seed", model_kwargs["seed"]),
            ("--hf-config-path", model_kwargs["hf_config_path"]),
            (
                "--allowed-local-media-path",
                model_kwargs["allowed_local_media_path"],
            ),
            ("--allowed-media-domains", model_kwargs["allowed_media_domains"]),
            ("--revision", model_kwargs["revision"]),
            ("--code-revision", model_kwargs["code_revision"]),
            ("--tokenizer-revision", model_kwargs["tokenizer_revision"]),
            ("--max-model-len", model_kwargs["max_model_len"]),
            ("--quantization", "-q", model_kwargs["quantization"]),
            (
                "--allow-deprecated-quantization",
                model_kwargs["allow_deprecated_quantization"],
            ),
            ("--enforce-eager", model_kwargs["enforce_eager"]),
            (
                "--enable-return-routed-experts",
                model_kwargs["enable_return_routed_experts"],
            ),
            ("--max-logprobs", model_kwargs["max_logprobs"]),
            ("--logprobs-mode", model_kwargs["logprobs_mode"]),
            ("--disable-sliding-window", model_kwargs["disable_sliding_window"]),
            ("--disable-cascade-attn", model_kwargs["disable_cascade_attn"]),
            ("--skip-tokenizer-init", model_kwargs["skip_tokenizer_init"]),
            ("--enable-prompt-embeds", model_kwargs["enable_prompt_embeds"]),
            ("--served-model-name", model_kwargs["served_model_name"]),
            ("--config-format", model_kwargs["config_format"]),
            # This one is a special case because it can be bool
            # or str. TODO: Handle this in get_kwargs
            (
                "--hf-token",
                {
                    "type": str,
                    "nargs": "?",
                    "const": True,
                    "default": model_kwargs["hf_token"]["default"],
                    "help": model_kwargs["hf_token"]["help"],
                },
            ),
            ("--hf-overrides", model_kwargs["hf_overrides"]),
            ("--pooler-config", model_kwargs["pooler_config"]),
            ("--generation-config", model_kwargs["generation_config"]),
            (
                "--override-generation-config",
                model_kwargs["override_generation_config"],
            ),
            ("--enable-sleep-mode", model_kwargs["enable_sleep_mode"]),
            ("--model-impl", model_kwargs["model_impl"]),
            (
                "--override-attention-dtype",
                model_kwargs["override_attention_dtype"],
            ),
            ("--logits-processors", model_kwargs["logits_processors"]),
            ("--io-processor-plugin", model_kwargs["io_processor_plugin"]),
            ("--renderer-num-workers", model_kwargs["renderer_num_workers"]),
        )

        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
        register(
            load_group,
            ("--load-format", load_kwargs["load_format"]),
            ("--download-dir", load_kwargs["download_dir"]),
            (
                "--safetensors-load-strategy",
                load_kwargs["safetensors_load_strategy"],
            ),
            (
                "--model-loader-extra-config",
                load_kwargs["model_loader_extra_config"],
            ),
            ("--ignore-patterns", load_kwargs["ignore_patterns"]),
            ("--use-tqdm-on-load", load_kwargs["use_tqdm_on_load"]),
            ("--pt-load-map-location", load_kwargs["pt_load_map_location"]),
        )

        # Attention arguments
        attention_kwargs = get_kwargs(AttentionConfig)
        attention_group = parser.add_argument_group(
            title="AttentionConfig",
            description=AttentionConfig.__doc__,
        )
        register(
            attention_group,
            ("--attention-backend", attention_kwargs["backend"]),
        )

        # Mamba arguments
        mamba_kwargs = get_kwargs(MambaConfig)
        mamba_group = parser.add_argument_group(
            title="MambaConfig",
            description=MambaConfig.__doc__,
        )
        register(
            mamba_group,
            ("--mamba-backend", mamba_kwargs["backend"]),
            (
                "--enable-mamba-cache-stochastic-rounding",
                mamba_kwargs["enable_stochastic_rounding"],
            ),
            (
                "--mamba-cache-philox-rounds",
                mamba_kwargs["stochastic_rounding_philox_rounds"],
            ),
        )

        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
        )
        register(
            structured_outputs_group,
            # Choices need to be validated after parsing to include plugins
            ("--reasoning-parser", structured_outputs_kwargs["reasoning_parser"]),
            (
                "--reasoning-parser-plugin",
                structured_outputs_kwargs["reasoning_parser_plugin"],
            ),
        )

        # Parallel arguments
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        register(
            parallel_group,
            (
                "--distributed-executor-backend",
                parallel_kwargs["distributed_executor_backend"],
            ),
            (
                "--pipeline-parallel-size",
                "-pp",
                parallel_kwargs["pipeline_parallel_size"],
            ),
            ("--master-addr", parallel_kwargs["master_addr"]),
            ("--master-port", parallel_kwargs["master_port"]),
            ("--nnodes", "-n", parallel_kwargs["nnodes"]),
            ("--node-rank", "-r", parallel_kwargs["node_rank"]),
            (
                "--distributed-timeout-seconds",
                parallel_kwargs["distributed_timeout_seconds"],
            ),
            ("--numa-bind", parallel_kwargs["numa_bind"]),
            ("--numa-bind-nodes", parallel_kwargs["numa_bind_nodes"]),
            ("--numa-bind-cpus", parallel_kwargs["numa_bind_cpus"]),
            (
                "--tensor-parallel-size",
                "-tp",
                parallel_kwargs["tensor_parallel_size"],
            ),
            (
                "--decode-context-parallel-size",
                "-dcp",
                parallel_kwargs["decode_context_parallel_size"],
            ),
            ("--dcp-comm-backend", parallel_kwargs["dcp_comm_backend"]),
            (
                "--dcp-kv-cache-interleave-size",
                parallel_kwargs["dcp_kv_cache_interleave_size"],
            ),
            (
                "--cp-kv-cache-interleave-size",
                parallel_kwargs["cp_kv_cache_interleave_size"],
            ),
            (
                "--prefill-context-parallel-size",
                "-pcp",
                parallel_kwargs["prefill_context_parallel_size"],
            ),
            ("--data-parallel-size", "-dp", parallel_kwargs["data_parallel_size"]),
            (
                "--data-parallel-rank",
                "-dpn",
                {
                    "type": int,
                    "help": "Data parallel rank of this instance. "
                    "When set, enables external load balancer mode.",
                },
            ),
            (
                "--data-parallel-start-rank",
                "-dpr",
                {
                    "type": int,
                    "help": "Starting data parallel rank for secondary nodes.",
                },
            ),
            (
                "--data-parallel-size-local",
                "-dpl",
                {
                    "type": int,
                    "help": "Number of data parallel replicas to run on this node.",
                },
            ),
            (
                "--data-parallel-address",
                "-dpa",
                {
                    "type": str,
                    "help": "Address of data parallel cluster head-node.",
                },
            ),
            (
                "--data-parallel-rpc-port",
                "-dpp",
                {
                    "type": int,
                    "help": "Port for data parallel RPC communication.",
                },
            ),
            (
                "--data-parallel-backend",
                "-dpb",
                {
                    "type": str,
                    "default": "mp",
                    "help": 'Backend for data parallel, either "mp" or "ray".',
                },
            ),
            (
                "--data-parallel-hybrid-lb",
                "-dph",
                parallel_kwargs["data_parallel_hybrid_lb"],
            ),
            (
                "--data-parallel-external-lb",
                "-dpe",
                parallel_kwargs["data_parallel_external_lb"],
            ),
            (
                "--enable-expert-parallel",
                "-ep",
                parallel_kwargs["enable_expert_parallel"],
            ),
            (
                "--enable-ep-weight-filter",
                parallel_kwargs["enable_ep_weight_filter"],
            ),
            ("--all2all-backend", parallel_kwargs["all2all_backend"]),
            ("--enable-dbo", parallel_kwargs["enable_dbo"]),
            ("--ubatch-size", parallel_kwargs["ubatch_size"]),
            ("--enable-elastic-ep", parallel_kwargs["enable_elastic_ep"]),
            (
                "--dbo-decode-token-threshold",
                parallel_kwargs["dbo_decode_token_threshold"],
            ),
            (
                "--dbo-prefill-token-threshold",
                parallel_kwargs["dbo_prefill_token_threshold"],
            ),
            (
                "--disable-nccl-for-dp-synchronization",
                parallel_kwargs["disable_nccl_for_dp_synchronization"],
            ),
            ("--enable-eplb", parallel_kwargs["enable_eplb"]),
            ("--eplb-config", parallel_kwargs["eplb_config"]),
            (
                "--expert-placement-strategy",
                parallel_kwargs["expert_placement_strategy"],
            ),
            (
                "--max-parallel-loading-workers",
                parallel_kwargs["max_parallel_loading_workers"],
            ),
            (
                "--ray-workers-use-nsight",
                parallel_kwargs["ray_workers_use_nsight"],
            ),
            (
                "--disable-custom-all-reduce",
                parallel_kwargs["disable_custom_all_reduce"],
            ),
            ("--worker-cls", parallel_kwargs["worker_cls"]),
            ("--worker-extension-cls", parallel_kwargs["worker_extension_cls"]),
        )

        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
        )
        register(
            cache_group,
            ("--block-size", cache_kwargs["block_size"]),
            ("--gpu-memory-utilization", cache_kwargs["gpu_memory_utilization"]),
            ("--kv-cache-memory-bytes", cache_kwargs["kv_cache_memory_bytes"]),
            ("--kv-cache-dtype", cache_kwargs["cache_dtype"]),
            ("--num-gpu-blocks-override", cache_kwargs["num_gpu_blocks_override"]),
            # The CLI default is overridden to None for this option.
            (
                "--enable-prefix-caching",
                {**cache_kwargs["enable_prefix_caching"], "default": None},
            ),
            (
                "--prefix-caching-hash-algo",
                cache_kwargs["prefix_caching_hash_algo"],
            ),
            ("--calculate-kv-scales", cache_kwargs["calculate_kv_scales"]),
            (
                "--kv-cache-dtype-skip-layers",
                cache_kwargs["kv_cache_dtype_skip_layers"],
            ),
            ("--kv-sharing-fast-prefill", cache_kwargs["kv_sharing_fast_prefill"]),
            ("--mamba-cache-dtype", cache_kwargs["mamba_cache_dtype"]),
            ("--mamba-ssm-cache-dtype", cache_kwargs["mamba_ssm_cache_dtype"]),
            ("--mamba-block-size", cache_kwargs["mamba_block_size"]),
            ("--mamba-cache-mode", cache_kwargs["mamba_cache_mode"]),
            ("--kv-offloading-size", cache_kwargs["kv_offloading_size"]),
            ("--kv-offloading-backend", cache_kwargs["kv_offloading_backend"]),
        )

        # Model weight offload related configs
        offload_kwargs = get_kwargs(OffloadConfig)
        uva_kwargs = get_kwargs(UVAOffloadConfig)
        prefetch_kwargs = get_kwargs(PrefetchOffloadConfig)
        offload_group = parser.add_argument_group(
            title="OffloadConfig",
            description=OffloadConfig.__doc__,
        )
        register(
            offload_group,
            ("--offload-backend", offload_kwargs["offload_backend"]),
            ("--cpu-offload-gb", uva_kwargs["cpu_offload_gb"]),
            ("--cpu-offload-params", uva_kwargs["cpu_offload_params"]),
            ("--offload-group-size", prefetch_kwargs["offload_group_size"]),
            ("--offload-num-in-group", prefetch_kwargs["offload_num_in_group"]),
            ("--offload-prefetch-step", prefetch_kwargs["offload_prefetch_step"]),
            ("--offload-params", prefetch_kwargs["offload_params"]),
        )

        # Multimodal related configs
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
        register(
            multimodal_group,
            ("--language-model-only", multimodal_kwargs["language_model_only"]),
            ("--limit-mm-per-prompt", multimodal_kwargs["limit_per_prompt"]),
            ("--enable-mm-embeds", multimodal_kwargs["enable_mm_embeds"]),
            ("--media-io-kwargs", multimodal_kwargs["media_io_kwargs"]),
            ("--mm-processor-kwargs", multimodal_kwargs["mm_processor_kwargs"]),
            ("--mm-processor-cache-gb", multimodal_kwargs["mm_processor_cache_gb"]),
            (
                "--mm-processor-cache-type",
                multimodal_kwargs["mm_processor_cache_type"],
            ),
            (
                "--mm-shm-cache-max-object-size-mb",
                multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
            ),
            ("--mm-encoder-only", multimodal_kwargs["mm_encoder_only"]),
            ("--mm-encoder-tp-mode", multimodal_kwargs["mm_encoder_tp_mode"]),
            (
                "--mm-encoder-attn-backend",
                multimodal_kwargs["mm_encoder_attn_backend"],
            ),
            ("--interleave-mm-strings", multimodal_kwargs["interleave_mm_strings"]),
            ("--skip-mm-profiling", multimodal_kwargs["skip_mm_profiling"]),
            ("--video-pruning-rate", multimodal_kwargs["video_pruning_rate"]),
            ("--mm-tensor-ipc", multimodal_kwargs["mm_tensor_ipc"]),
        )

        # LoRA related configs
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        register(
            lora_group,
            (
                "--enable-lora",
                {
                    "action": argparse.BooleanOptionalAction,
                    "help": "If True, enable handling of LoRA adapters.",
                },
            ),
            ("--max-loras", lora_kwargs["max_loras"]),
            ("--max-lora-rank", lora_kwargs["max_lora_rank"]),
            ("--lora-dtype", lora_kwargs["lora_dtype"]),
            (
                "--enable-tower-connector-lora",
                lora_kwargs["enable_tower_connector_lora"],
            ),
            ("--max-cpu-loras", lora_kwargs["max_cpu_loras"]),
            ("--fully-sharded-loras", lora_kwargs["fully_sharded_loras"]),
            ("--lora-target-modules", lora_kwargs["target_modules"]),
            ("--default-mm-loras", lora_kwargs["default_mm_loras"]),
            ("--specialize-active-lora", lora_kwargs["specialize_active_lora"]),
        )

        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        # TODO: generalise this special case
        # The metavar is built from the base choices only; the accepted
        # choices are then extended with every ordered, comma-joined pair.
        trace_kwargs = observability_kwargs["collect_detailed_traces"]
        base_choices = trace_kwargs["choices"]
        trace_kwargs["metavar"] = f"{{{','.join(base_choices)}}}"
        trace_kwargs["choices"] += [
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
        ]
        register(
            observability_group,
            (
                "--show-hidden-metrics-for-version",
                observability_kwargs["show_hidden_metrics_for_version"],
            ),
            (
                "--otlp-traces-endpoint",
                observability_kwargs["otlp_traces_endpoint"],
            ),
            ("--collect-detailed-traces", trace_kwargs),
            ("--kv-cache-metrics", observability_kwargs["kv_cache_metrics"]),
            (
                "--kv-cache-metrics-sample",
                observability_kwargs["kv_cache_metrics_sample"],
            ),
            ("--cudagraph-metrics", observability_kwargs["cudagraph_metrics"]),
            (
                "--enable-layerwise-nvtx-tracing",
                observability_kwargs["enable_layerwise_nvtx_tracing"],
            ),
            ("--enable-mfu-metrics", observability_kwargs["enable_mfu_metrics"]),
            (
                "--enable-logging-iteration-details",
                observability_kwargs["enable_logging_iteration_details"],
            ),
        )

        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        register(
            scheduler_group,
            # Entries merged with ``"default": None`` override the CLI default.
            (
                "--max-num-batched-tokens",
                {**scheduler_kwargs["max_num_batched_tokens"], "default": None},
            ),
            (
                "--max-num-seqs",
                {**scheduler_kwargs["max_num_seqs"], "default": None},
            ),
            (
                "--max-num-partial-prefills",
                scheduler_kwargs["max_num_partial_prefills"],
            ),
            (
                "--max-long-partial-prefills",
                scheduler_kwargs["max_long_partial_prefills"],
            ),
            (
                "--long-prefill-token-threshold",
                scheduler_kwargs["long_prefill_token_threshold"],
            ),
            # multi-step scheduling has been removed; corresponding arguments
            # are no longer supported.
            ("--scheduling-policy", scheduler_kwargs["policy"]),
            (
                "--enable-chunked-prefill",
                {**scheduler_kwargs["enable_chunked_prefill"], "default": None},
            ),
            (
                "--disable-chunked-mm-input",
                scheduler_kwargs["disable_chunked_mm_input"],
            ),
            ("--scheduler-cls", scheduler_kwargs["scheduler_cls"]),
            (
                "--scheduler-reserve-full-isl",
                scheduler_kwargs["scheduler_reserve_full_isl"],
            ),
            (
                "--disable-hybrid-kv-cache-manager",
                scheduler_kwargs["disable_hybrid_kv_cache_manager"],
            ),
            ("--async-scheduling", scheduler_kwargs["async_scheduling"]),
            ("--stream-interval", scheduler_kwargs["stream_interval"]),
        )

        # Compilation arguments
        compilation_kwargs = get_kwargs(CompilationConfig)
        compilation_group = parser.add_argument_group(
            title="CompilationConfig",
            description=CompilationConfig.__doc__,
        )
        register(
            compilation_group,
            (
                "--cudagraph-capture-sizes",
                compilation_kwargs["cudagraph_capture_sizes"],
            ),
            (
                "--max-cudagraph-capture-size",
                compilation_kwargs["max_cudagraph_capture_size"],
            ),
        )

        # Kernel arguments
        kernel_kwargs = get_kwargs(KernelConfig)
        kernel_group = parser.add_argument_group(
            title="KernelConfig",
            description=KernelConfig.__doc__,
        )
        # The value given to `--moe-backend` is lower-cased and has "-"
        # replaced by "_" before argparse processes it further.
        moe_backend_kwargs = kernel_kwargs["moe_backend"]
        moe_backend_kwargs["type"] = lambda s: s.lower().replace("-", "_")
        register(
            kernel_group,
            ("--ir-op-priority", kernel_kwargs["ir_op_priority"]),
            (
                "--enable-flashinfer-autotune",
                kernel_kwargs["enable_flashinfer_autotune"],
            ),
            ("--moe-backend", moe_backend_kwargs),
        )

        # vLLM arguments
        vllm_kwargs = get_kwargs(VllmConfig)
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
        register(
            vllm_group,
            ("--speculative-config", "-sc", vllm_kwargs["speculative_config"]),
            ("--kv-transfer-config", vllm_kwargs["kv_transfer_config"]),
            ("--kv-events-config", vllm_kwargs["kv_events_config"]),
            ("--ec-transfer-config", vllm_kwargs["ec_transfer_config"]),
            ("--compilation-config", "-cc", vllm_kwargs["compilation_config"]),
            ("--attention-config", "-ac", vllm_kwargs["attention_config"]),
            ("--reasoning-config", vllm_kwargs["reasoning_config"]),
            ("--kernel-config", vllm_kwargs["kernel_config"]),
            ("--additional-config", vllm_kwargs["additional_config"]),
            (
                "--structured-outputs-config",
                vllm_kwargs["structured_outputs_config"],
            ),
            ("--profiler-config", vllm_kwargs["profiler_config"]),
            ("--optimization-level", vllm_kwargs["optimization_level"]),
            ("--performance-mode", vllm_kwargs["performance_mode"]),
            ("--weight-transfer-config", vllm_kwargs["weight_transfer_config"]),
        )

        # Other arguments (added to the parser itself, not to a group)
        register(
            parser,
            (
                "--disable-log-stats",
                {
                    "action": "store_true",
                    "help": "Disable logging statistics.",
                },
            ),
            (
                "--aggregate-engine-logging",
                {
                    "action": "store_true",
                    "help": "Log aggregate rather than per-engine statistics "
                    "when using data parallelism.",
                },
            ),
            (
                "--fail-on-environ-validation",
                {
                    "help": "If set, the engine will raise an error if "
                    "environment validation fails.",
                    "default": False,
                    "action": argparse.BooleanOptionalAction,
                },
            ),
            (
                "--shutdown-timeout",
                {
                    "type": int,
                    "default": 0,
                    "help": "Shutdown timeout in seconds. 0 = abort, >0 = wait.",
                },
            ),
            (
                "--gdn-prefill-backend",
                {
                    "dest": "gdn_prefill_backend",
                    "choices": ["flashinfer", "triton"],
                    "default": None,
                    "help": "Select GDN prefill backend.",
                },
            ),
        )
        return parser
1423
1424

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of this dataclass from parsed CLI arguments.

        Every dataclass field whose name is also an attribute of `args` is
        forwarded to the constructor as a keyword argument. Fields that are
        missing from `args` are simply not passed, and attributes of `args`
        that are not dataclass fields are ignored.
        """
        init_kwargs = {}
        for spec in dataclasses.fields(cls):
            # Only forward what the namespace actually carries.
            if hasattr(args, spec.name):
                init_kwargs[spec.name] = getattr(args, spec.name)
        return cls(**init_kwargs)
1433

1434
    def create_model_config(self) -> ModelConfig:
        """Assemble a `ModelConfig` from this object's model-related fields.

        Side effects:
            * When `is_gguf(self.model)` is true, both `self.quantization`
              and `self.load_format` are overwritten with "gguf" before the
              config is built.
            * When `envs.VLLM_ENABLE_V1_MULTIPROCESSING` is falsy, a warning
              mentioning `self.seed` is logged.

        Returns:
            The newly constructed `ModelConfig`.
        """
        # A gguf file needs a specific model loader.
        if is_gguf(self.model):
            self.quantization = "gguf"
            self.load_format = "gguf"

        multiprocessing_enabled = envs.VLLM_ENABLE_V1_MULTIPROCESSING
        if not multiprocessing_enabled:
            logger.warning(
                "The global random seed is set to %d. Since "
                "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
                "affect the random state of the Python process that "
                "launched vLLM.",
                self.seed,
            )

        # Every keyword below is a one-to-one pass-through of the
        # same-named attribute on this object.
        model_config = ModelConfig(
            model=self.model,
            model_weights=self.model_weights,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            tokenizer=self.tokenizer,  # type: ignore[arg-type]
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_config=self.quantization_config,
            allow_deprecated_quantization=self.allow_deprecated_quantization,
            enforce_eager=self.enforce_eager,
            enable_return_routed_experts=self.enable_return_routed_experts,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            language_model_only=self.language_model_only,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            enable_mm_embeds=self.enable_mm_embeds,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_only=self.mm_encoder_only,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
            pooler_config=self.pooler_config,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            video_pruning_rate=self.video_pruning_rate,
            mm_tensor_ipc=self.mm_tensor_ipc,
            io_processor_plugin=self.io_processor_plugin,
            renderer_num_workers=self.renderer_num_workers,
        )
        return model_config
1505

1506
    def validate_tensorizer_args(self):
        """Mirror top-level tensorizer options into the nested config.

        Each key of `self.model_loader_extra_config` that is also listed in
        `TensorizerConfig._fields` is copied, with its value, into the
        nested `self.model_loader_extra_config["tensorizer_config"]`
        mapping. That nested mapping is expected to exist already whenever
        such a key is present.
        """
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

        extra_config = self.model_loader_extra_config
        for option, value in extra_config.items():
            if option not in TensorizerConfig._fields:
                continue
            extra_config["tensorizer_config"][option] = value
1514

1515
    def create_load_config(self) -> LoadConfig:
        """Assemble a `LoadConfig` from this object's loading-related fields.

        Side effects:
            * `self.load_format` is forced to "bitsandbytes" when
              `self.quantization` is "bitsandbytes".
            * For the "tensorizer" load format,
              `self.model_loader_extra_config` is replaced by its
              `to_serializable()` form when it offers that method, its
              "tensorizer_config" entry is reset to a mapping whose
              "tensorizer_dir" is `self.model`, and
              `validate_tensorizer_args()` is then run on it.

        Returns:
            The newly constructed `LoadConfig`.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            extra_config = self.model_loader_extra_config
            if hasattr(extra_config, "to_serializable"):
                extra_config = extra_config.to_serializable()
                self.model_loader_extra_config = extra_config
            # Start from a fresh nested config that points at the model.
            extra_config["tensorizer_config"] = {"tensorizer_dir": self.model}
            self.validate_tensorizer_args()

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )
        return load_config
1539

1540
1541
1542
1543
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
    ) -> SpeculativeConfig | None:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        `speculative_config` can either be provided as a JSON string input
        via CLI arguments or directly as a dictionary from the engine. Its
        entries are used as keyword arguments for `SpeculativeConfig`.

        Args:
            target_model_config: Passed to `SpeculativeConfig` as
                `target_model_config`.
            target_parallel_config: Passed to `SpeculativeConfig` as
                `target_parallel_config`.

        Returns:
            `None` when `self.speculative_config` is `None`, otherwise the
            constructed `SpeculativeConfig`.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        #
        # Merge into a fresh dict instead of updating in place: the stored
        # dictionary may be the very object the caller handed in, and it
        # should not end up holding the target config objects afterwards.
        speculative_kwargs = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
        }
        return SpeculativeConfig(**speculative_kwargs)
1566

1567
1568
    def create_engine_config(
        self,
1569
        usage_context: UsageContext | None = None,
1570
        headless: bool = False,
1571
1572
1573
1574
    ) -> VllmConfig:
        """
        Create the VllmConfig.

1575
        NOTE: If VllmConfig is incompatible, we raise an error.
1576
        """
1577
        current_platform.pre_register_and_update()
1578

1579
        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
1580

1581
1582
        envs.validate_environ(self.fail_on_environ_validation)

1583
1584
        # Check if the model is a speculator and override model/tokenizer/config
        # BEFORE creating ModelConfig, so the config is created with the target model
1585
1586
1587
1588
        # Skip speculator detection for cloud storage models (eg: S3, GCS) since
        # HuggingFace cannot load configs directly from S3 URLs. S3 models can still
        # use speculators with explicit --speculative-config.
        if not is_cloud_storage(self.model):
1589
1590
1591
1592
1593
1594
1595
            (self.model, self.tokenizer, self.speculative_config) = (
                maybe_override_with_speculators(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    revision=self.revision,
                    trust_remote_code=self.trust_remote_code,
                    vllm_speculative_config=self.speculative_config,
1596
                    hf_token=self.hf_token,
1597
1598
1599
                )
            )

1600
        model_config = self.create_model_config()
1601
        self.model = model_config.model
1602
        self.model_weights = model_config.model_weights
1603
1604
        self.tokenizer = model_config.tokenizer

1605
        self._check_feature_supported()
1606
        self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
1607
        self._set_default_reasoning_config_args()
1608
        sliding_window: int | None = None
1609
1610
1611
1612
1613
1614
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1615
1616
1617
1618
1619
        # Resolve "auto" kv_cache_dtype to actual value from model config
        resolved_cache_dtype = resolve_kv_cache_dtype_string(
            self.kv_cache_dtype, model_config
        )

1620
1621
1622
1623
        assert self.enable_prefix_caching is not None, (
            "enable_prefix_caching must be set by this point"
        )

1624
        cache_config = CacheConfig(
1625
            block_size=self.block_size,  # type: ignore[arg-type]
1626
            gpu_memory_utilization=self.gpu_memory_utilization,
1627
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1628
            cache_dtype=resolved_cache_dtype,  # type: ignore[arg-type]
1629
            is_attention_free=model_config.is_attention_free,
1630
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1631
            sliding_window=sliding_window,
1632
            enable_prefix_caching=self.enable_prefix_caching,
1633
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1634
            calculate_kv_scales=self.calculate_kv_scales,
1635
            kv_cache_dtype_skip_layers=self.kv_cache_dtype_skip_layers,
1636
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1637
1638
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1639
            mamba_block_size=self.mamba_block_size,
1640
            mamba_cache_mode=self.mamba_cache_mode,
1641
1642
            kv_offloading_size=self.kv_offloading_size,
            kv_offloading_backend=self.kv_offloading_backend,
1643
        )
1644

1645
1646
1647
1648
1649
1650
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
1651

1652
            ray_runtime_env = ray.get_runtime_context().runtime_env
1653
1654
1655
1656
1657
1658
1659
            # Avoid logging sensitive environment variables
            sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {}
            if "env_vars" in sanitized_env:
                sanitized_env["env_vars"] = {
                    k: "***" for k in sanitized_env["env_vars"]
                }
            logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env)
1660

1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1672
        assert not headless or not self.data_parallel_hybrid_lb, (
1673
1674
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
        assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
            "data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
        )
        assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
            "nnodes > 1 is only supported with data_parallel_backend=mp"
        )
        inferred_data_parallel_rank = 0
        if self.nnodes > 1:
            world_size = (
                self.data_parallel_size
                * self.pipeline_parallel_size
                * self.tensor_parallel_size
            )
            world_size_within_dp = (
                self.pipeline_parallel_size * self.tensor_parallel_size
            )
            local_world_size = world_size // self.nnodes
            assert world_size % self.nnodes == 0, (
                f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
            )
            assert self.node_rank < self.nnodes, (
                f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
            )
            inferred_data_parallel_rank = (
                self.node_rank * local_world_size
            ) // world_size_within_dp
            if self.data_parallel_size > 1 and self.data_parallel_external_lb:
                self.data_parallel_rank = inferred_data_parallel_rank
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d for external lb",
                    self.data_parallel_rank,
                    self.node_rank,
                )
            elif self.data_parallel_size_local is None:
                # Infer data parallel size local for internal dplb:
                self.data_parallel_size_local = max(
                    local_world_size // world_size_within_dp, 1
                )
        data_parallel_external_lb = (
            self.data_parallel_external_lb or self.data_parallel_rank is not None
        )
1716
        # Local DP rank = 1, use pure-external LB.
1717
        if data_parallel_external_lb:
1718
            assert self.data_parallel_rank is not None, (
1719
                "data_parallel_rank or node_rank must be specified if "
1720
1721
                "data_parallel_external_lb is enable."
            )
1722
            assert self.data_parallel_size_local in (1, None), (
1723
1724
                "data_parallel_size_local must be 1 or None when data_parallel_rank "
                "is set"
1725
            )
1726
            data_parallel_size_local = 1
1727
1728
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1729
1730
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1731
1732
1733
1734
1735
1736
1737

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
1738
1739
1740
1741
1742
                logger.warning(
                    "data_parallel_hybrid_lb is not eligible when "
                    "data_parallel_size_local = 1, autoswitch to "
                    "data_parallel_external_lb."
                )
1743
1744
1745
1746
1747
1748
1749
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

1750
1751
1752
1753
1754
1755
1756
1757
1758
            self.data_parallel_rank = (
                self.data_parallel_start_rank or inferred_data_parallel_rank
            )
            if self.nnodes > 1:
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d",
                    self.data_parallel_rank,
                    self.node_rank,
                )
1759
        else:
1760
            assert not self.data_parallel_hybrid_lb, (
1761
1762
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )
1763

1764
1765
1766
1767
1768
1769
1770
1771
1772
            if self.data_parallel_backend == "ray" and (
                envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
            ):
                # Data parallel size defaults to 1 if DP ranks are spanning
                # multiple nodes
                data_parallel_size_local = 1
            else:
                # Otherwise local DP size defaults to global DP size if not set
                data_parallel_size_local = self.data_parallel_size
1773
1774
1775

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1776
1777
1778
1779
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
1780
1781
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1782
1783
1784
1785
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
1786
1787
                    self.data_parallel_backend,
                )
1788
1789
1790
                data_parallel_address = (
                    self.master_addr or ParallelConfig.data_parallel_master_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1791
1792
        else:
            data_parallel_address = self.data_parallel_address
1793
1794
1795

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
1796
        data_parallel_rpc_port = (
1797
            self.data_parallel_rpc_port
1798
1799
1800
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )
1801

1802
1803
1804
1805
        if self.tokens_only and not model_config.skip_tokenizer_init:
            model_config.skip_tokenizer_init = True
            logger.info("Skipping tokenizer initialization for tokens-only mode.")

1806
        parallel_config = ParallelConfig(
1807
1808
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1809
            prefill_context_parallel_size=self.prefill_context_parallel_size,
1810
            data_parallel_size=self.data_parallel_size,
1811
1812
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1813
            data_parallel_size_local=data_parallel_size_local,
1814
1815
1816
1817
            master_addr=self.master_addr,
            master_port=self.master_port,
            nnodes=self.nnodes,
            node_rank=self.node_rank,
1818
            distributed_timeout_seconds=self.distributed_timeout_seconds,
1819
1820
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1821
            data_parallel_backend=self.data_parallel_backend,
1822
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1823
            is_moe_model=model_config.is_moe,
1824
            enable_expert_parallel=self.enable_expert_parallel,
1825
            enable_ep_weight_filter=self.enable_ep_weight_filter,
1826
            all2all_backend=self.all2all_backend,
1827
            enable_elastic_ep=self.enable_elastic_ep,
1828
            enable_dbo=self.enable_dbo,
1829
            ubatch_size=self.ubatch_size,
1830
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
1831
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
1832
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
1833
            enable_eplb=self.enable_eplb,
1834
            eplb_config=self.eplb_config,
1835
            expert_placement_strategy=self.expert_placement_strategy,
1836
1837
1838
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1839
            ray_runtime_env=ray_runtime_env,
1840
            placement_group=placement_group,
1841
1842
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1843
            worker_extension_cls=self.worker_extension_cls,
1844
            decode_context_parallel_size=self.decode_context_parallel_size,
1845
            dcp_comm_backend=self.dcp_comm_backend,
1846
            dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
1847
            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
1848
1849
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
1850
1851
1852
            numa_bind=self.numa_bind,
            numa_bind_nodes=self.numa_bind_nodes,
            numa_bind_cpus=self.numa_bind_cpus,
1853
        )
1854

1855
        speculative_config = self.create_speculative_config(
1856
1857
1858
1859
            target_model_config=model_config,
            target_parallel_config=parallel_config,
        )

1860
1861
1862
1863
1864
1865
        self._set_default_max_num_seqs_and_batched_tokens_args(
            usage_context,
            model_config,
            parallel_config,
        )

1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
        assert self.max_num_batched_tokens is not None, (
            "max_num_batched_tokens must be set by this point"
        )
        assert self.max_num_seqs is not None, "max_num_seqs must be set by this point"
        assert self.enable_chunked_prefill is not None, (
            "enable_chunked_prefill must be set by this point"
        )
        assert model_config.max_model_len is not None, (
            "max_model_len must be set by this point"
        )
1876
        scheduler_config = SchedulerConfig(
1877
            runner_type=model_config.runner_type,
1878
1879
1880
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1881
            enable_chunked_prefill=self.enable_chunked_prefill,
1882
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1883
            is_multimodal_model=model_config.is_multimodal_model,
1884
            is_encoder_decoder=model_config.is_encoder_decoder,
1885
            policy=self.scheduling_policy,
1886
            scheduler_cls=self.scheduler_cls,
1887
1888
1889
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1890
            scheduler_reserve_full_isl=self.scheduler_reserve_full_isl,
1891
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
1892
            async_scheduling=self.async_scheduling,
1893
            stream_interval=self.stream_interval,
1894
        )
1895

1896
1897
1898
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_dtype=self.lora_dtype,
1909
                target_modules=self.lora_target_modules,
1910
                enable_tower_connector_lora=self.enable_tower_connector_lora,
1911
                specialize_active_lora=self.specialize_active_lora,
1912
1913
1914
1915
1916
1917
1918
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )
1919

1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
        if (
            lora_config is not None
            and speculative_config is not None
            and scheduler_config.max_num_batched_tokens
            < (
                scheduler_config.max_num_seqs
                * (speculative_config.num_speculative_tokens + 1)
            )
        ):
            raise ValueError(
                "Consider increasing max_num_batched_tokens or "
                "decreasing num_speculative_tokens"
            )

1934
1935
1936
1937
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1938
1939
1940
1941
1942
1943
1944
1945
        # Attention config overrides
        attention_config = copy.deepcopy(self.attention_config)
        if self.attention_backend is not None:
            if attention_config.backend is not None:
                raise ValueError(
                    "attention_backend and attention_config.backend "
                    "are mutually exclusive"
                )
1946
1947
1948
1949
            # Reuse the validator to handle "auto" and string-to-enum conversion
            attention_config.backend = AttentionConfig.validate_backend_before(
                self.attention_backend
            )
1950

1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
        # Mamba config overrides
        mamba_config = copy.deepcopy(self.mamba_config)
        # Convert string to enum if needed (CLI parsing returns a string)
        if isinstance(self.mamba_backend, str):
            mamba_config.backend = MambaBackendEnum[self.mamba_backend.upper()]
        else:
            mamba_config.backend = self.mamba_backend
        if self.enable_mamba_cache_stochastic_rounding:
            mamba_config.enable_stochastic_rounding = (
                self.enable_mamba_cache_stochastic_rounding
            )
        if self.mamba_cache_philox_rounds:
            mamba_config.stochastic_rounding_philox_rounds = (
                self.mamba_cache_philox_rounds
            )

1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
        # Kernel config overrides
        kernel_config = copy.deepcopy(self.kernel_config)
        if self.enable_flashinfer_autotune is not None:
            if kernel_config.enable_flashinfer_autotune is not None:
                raise ValueError(
                    "enable_flashinfer_autotune and "
                    "kernel_config.enable_flashinfer_autotune "
                    "are mutually exclusive"
                )
            kernel_config.enable_flashinfer_autotune = self.enable_flashinfer_autotune
1977
1978
        if self.moe_backend != "auto":
            kernel_config.moe_backend = self.moe_backend
1979

1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
        # Transfer top-level ir_op_priority into KernelConfig.ir_op_priority
        for op_name, op_priority in asdict(self.ir_op_priority).items():
            # Empty means unset
            if not op_priority:
                continue

            # Priority cannot be set 2x for the same op
            if getattr(kernel_config.ir_op_priority, op_name):
                raise ValueError(
                    f"Op priority for {op_name} specified via both ir_op_priority "
                    f"and KernelConfig.ir_op_priority, only one allowed at a time."
                )

            # Set the attribute
            setattr(kernel_config.ir_op_priority, op_name, op_priority)

1996
        load_config = self.create_load_config()
1997

1998
1999
        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
2000
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser
2001

2002
2003
2004
2005
2006
        if self.reasoning_parser_plugin:
            self.structured_outputs_config.reasoning_parser_plugin = (
                self.reasoning_parser_plugin
            )

2007
        observability_config = ObservabilityConfig(
2008
            show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
2009
            otlp_traces_endpoint=self.otlp_traces_endpoint,
2010
            collect_detailed_traces=self.collect_detailed_traces,
2011
2012
            kv_cache_metrics=self.kv_cache_metrics,
            kv_cache_metrics_sample=self.kv_cache_metrics_sample,
2013
            cudagraph_metrics=self.cudagraph_metrics,
2014
            enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
2015
            enable_mfu_metrics=self.enable_mfu_metrics,
2016
            enable_mm_processor_stats=self.enable_mm_processor_stats,
2017
            enable_logging_iteration_details=self.enable_logging_iteration_details,
2018
        )
2019

2020
        # Compilation config overrides
2021
        compilation_config = copy.deepcopy(self.compilation_config)
2022
        if self.cudagraph_capture_sizes is not None:
2023
            if compilation_config.cudagraph_capture_sizes is not None:
2024
2025
2026
2027
                raise ValueError(
                    "cudagraph_capture_sizes and compilation_config."
                    "cudagraph_capture_sizes are mutually exclusive"
                )
2028
            compilation_config.cudagraph_capture_sizes = self.cudagraph_capture_sizes
2029
        if self.max_cudagraph_capture_size is not None:
2030
            if compilation_config.max_cudagraph_capture_size is not None:
2031
2032
2033
2034
                raise ValueError(
                    "max_cudagraph_capture_size and compilation_config."
                    "max_cudagraph_capture_size are mutually exclusive"
                )
2035
            compilation_config.max_cudagraph_capture_size = (
2036
2037
                self.max_cudagraph_capture_size
            )
2038
2039

        offload_config = OffloadConfig(
2040
            offload_backend=self.offload_backend,
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
            uva=UVAOffloadConfig(
                cpu_offload_gb=self.cpu_offload_gb,
                cpu_offload_params=self.cpu_offload_params,
            ),
            prefetch=PrefetchOffloadConfig(
                offload_group_size=self.offload_group_size,
                offload_num_in_group=self.offload_num_in_group,
                offload_prefetch_step=self.offload_prefetch_step,
                offload_params=self.offload_params,
            ),
        )

2053
2054
2055
        if self.gdn_prefill_backend is not None:
            self.additional_config["gdn_prefill_backend"] = self.gdn_prefill_backend

2056
        config = VllmConfig(
2057
2058
2059
2060
2061
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
2062
            load_config=load_config,
2063
            offload_config=offload_config,
2064
            attention_config=attention_config,
2065
            mamba_config=mamba_config,
2066
            kernel_config=kernel_config,
2067
2068
            lora_config=lora_config,
            speculative_config=speculative_config,
2069
            structured_outputs_config=self.structured_outputs_config,
2070
            observability_config=observability_config,
2071
            compilation_config=compilation_config,
2072
            kv_transfer_config=self.kv_transfer_config,
2073
            kv_events_config=self.kv_events_config,
2074
            ec_transfer_config=self.ec_transfer_config,
2075
            reasoning_config=self.reasoning_config,
2076
            profiler_config=self.profiler_config,
2077
            additional_config=self.additional_config,
2078
            optimization_level=self.optimization_level,
2079
            performance_mode=self.performance_mode,
2080
            weight_transfer_config=self.weight_transfer_config,
2081
            shutdown_timeout=self.shutdown_timeout,
2082
        )
2083

2084
2085
        return config

2086
    def _check_feature_supported(self):
        """Raise an error if a requested feature is not supported.

        Two configurations are rejected:

        * Concurrent partial prefills, i.e. `max_num_partial_prefills` or
          `max_long_partial_prefills` differing from the `SchedulerConfig`
          class defaults.
        * Pipeline parallelism (`pipeline_parallel_size > 1`) with a
          distributed executor backend that neither exposes a truthy
          `supports_pp` attribute nor is one of: the `ParallelConfig`
          default, "ray", "mp" or "external_launcher".

        Raises:
            NotImplementedError: Raised through `_raise_unsupported_error`
                when one of the configurations above is detected.
        """
        # No Concurrent Partial Prefills so far.
        if (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        ):
            _raise_unsupported_error(feature_name="Concurrent Partial Prefill")

        if self.pipeline_parallel_size > 1:
            # A custom executor class can opt in through a `supports_pp`
            # attribute; plain string backends fall back to the allow-list.
            supports_pp = getattr(
                self.distributed_executor_backend, "supports_pp", False
            )
            if not supports_pp and self.distributed_executor_backend not in (
                ParallelConfig.distributed_executor_backend,
                "ray",
                "mp",
                "external_launcher",
            ):
                name = (
                    "Pipeline Parallelism without Ray distributed "
                    "executor or multiprocessing executor or external "
                    "launcher"
                )
                _raise_unsupported_error(feature_name=name)
2112

2113
2114
2115
2116
2117
2118
    @classmethod
    def get_batch_defaults(
        cls,
        world_size: int,
    ) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]:
        """Return the hardware-dependent defaults used for batch sizing.

        Args:
            world_size: Multiplier applied to the CPU-platform defaults.

        Returns:
            A pair `(default_max_num_batched_tokens, default_max_num_seqs)`.
            Each element maps a usage context to the default value used when
            the user did not provide an override.
        """
        from vllm.usage.usage_lib import UsageContext

        default_max_num_batched_tokens: dict[UsageContext | None, int]
        default_max_num_seqs: dict[UsageContext | None, int]

        # When no user override, set the default values based on the usage
        # context.
        # Use different default values for different hardware.

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # This is only used to set default_max_num_batched_tokens
            device_memory = 0
            device_name = ""

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
            # For GPUs like H100 and MI300x, use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 1024,
                UsageContext.OPENAI_API_SERVER: 1024,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256,
                UsageContext.OPENAI_API_SERVER: 256,
            }

        # tpu specific default values.
        # Only the batched-token defaults are overridden here; chips other
        # than the three listed keep the values chosen above.
        if current_platform.is_tpu():
            chip_name = current_platform.get_device_name()

            if chip_name == "V6E":
                default_max_num_batched_tokens = {
                    UsageContext.LLM_CLASS: 2048,
                    UsageContext.OPENAI_API_SERVER: 1024,
                }
            elif chip_name == "V5E":
                default_max_num_batched_tokens = {
                    UsageContext.LLM_CLASS: 1024,
                    UsageContext.OPENAI_API_SERVER: 512,
                }
            elif chip_name == "V5P":
                default_max_num_batched_tokens = {
                    UsageContext.LLM_CLASS: 512,
                    UsageContext.OPENAI_API_SERVER: 256,
                }

        # cpu specific default values.
        if current_platform.is_cpu():
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 4096 * world_size,
                UsageContext.OPENAI_API_SERVER: 2048 * world_size,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256 * world_size,
                UsageContext.OPENAI_API_SERVER: 128 * world_size,
            }

        return default_max_num_batched_tokens, default_max_num_seqs

2197
2198
    def _set_default_chunked_prefill_and_prefix_caching_args(
        self, model_config: ModelConfig
    ) -> None:
        """Resolve `enable_chunked_prefill` and `enable_prefix_caching`.

        When either flag is still `None` it is set from the model's
        `is_chunked_prefill_supported` / `is_prefix_caching_supported`
        value. When the user set a flag explicitly, a one-time warning is
        logged if the choice contradicts what the model supports:

        * "generate" runner with chunked prefill disabled although the model
          supports it;
        * "pooling" runner with chunked prefill or prefix caching enabled
          although the model does not support it.

        On CPU platforms with a RISC-V architecture both flags are finally
        forced to `False`, regardless of the values chosen above.

        Args:
            model_config: Model configuration providing the support flags and
                the runner type.
        """
        default_chunked_prefill = model_config.is_chunked_prefill_supported
        default_prefix_caching = model_config.is_prefix_caching_supported

        if self.enable_chunked_prefill is None:
            self.enable_chunked_prefill = default_chunked_prefill

            logger.debug(
                "%s chunked prefill by default",
                "Enabling" if default_chunked_prefill else "Disabling",
            )
        elif (
            model_config.runner_type == "generate"
            and not self.enable_chunked_prefill
            and default_chunked_prefill
        ):
            logger.warning_once(
                "This model does not officially support disabling chunked prefill. "
                "Disabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )
        elif (
            model_config.runner_type == "pooling"
            and self.enable_chunked_prefill
            and not default_chunked_prefill
        ):
            logger.warning_once(
                "This model does not officially support chunked prefill. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )

        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = default_prefix_caching

            logger.debug(
                "%s prefix caching by default",
                "Enabling" if default_prefix_caching else "Disabling",
            )
        elif (
            model_config.runner_type == "pooling"
            and self.enable_prefix_caching
            and not default_prefix_caching
        ):
            logger.warning_once(
                "This model does not officially support prefix caching. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )

        # Disable chunked prefill and prefix caching for:
        # RISCV CPUs in V1
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.RISCV,
        ):
            # Adjacent string literals are concatenated verbatim, so each
            # fragment must carry its own trailing space.
            logger.info(
                "Chunked prefill is not supported for "
                "RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_chunked_prefill = False
            logger.info(
                "Prefix caching is not supported for "
                "RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_prefix_caching = False

2270
2271
2272
2273
2274
2275
2276
    def _set_default_reasoning_config_args(self):
        """Copy `reasoning_parser` into `reasoning_config`.

        Does nothing when `reasoning_parser` is falsy. Otherwise a
        `ReasoningConfig` is created if none exists yet, and its
        `reasoning_parser` attribute is set to `self.reasoning_parser`.
        """
        parser_name = self.reasoning_parser
        if parser_name:
            config = self.reasoning_config
            if config is None:
                # Create the config lazily so an existing one is reused.
                config = self.reasoning_config = ReasoningConfig()
            config.reasoning_parser = parser_name

2277
    def _set_default_max_num_seqs_and_batched_tokens_args(
        self,
        usage_context: UsageContext | None,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ):
        """Fill in `max_num_batched_tokens` and `max_num_seqs` when unset.

        Values supplied by the user are never modified. Defaults are taken
        from `get_batch_defaults` for the given usage context (or from the
        `SchedulerConfig` fallbacks), doubled in "throughput" performance
        mode, and then clamped:

        * without chunked prefill, `max_num_batched_tokens` is raised to at
          least `model_config.max_model_len`;
        * `max_num_batched_tokens` is capped at
          `max_num_seqs * max_model_len`;
        * `max_num_seqs` is capped at `max_num_batched_tokens`.

        Args:
            usage_context: Usage context used to look up the defaults; may be
                `None`.
            model_config: Model configuration; `max_model_len` must already
                be set when `max_num_batched_tokens` is defaulted.
            parallel_config: Parallel configuration; `use_batched_dp_moe`
                selects a dedicated batched-token default.
        """
        world_size = self.pipeline_parallel_size * self.tensor_parallel_size
        (
            default_max_num_batched_tokens,
            default_max_num_seqs,
        ) = self.get_batch_defaults(world_size)

        # Remember what the user passed so later steps only touch defaults.
        orig_max_num_batched_tokens = self.max_num_batched_tokens
        orig_max_num_seqs = self.max_num_seqs

        if self.max_num_batched_tokens is None:
            if parallel_config.use_batched_dp_moe:
                self.max_num_batched_tokens = (
                    SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP
                )
            else:
                self.max_num_batched_tokens = default_max_num_batched_tokens.get(
                    usage_context,
                    SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
                )

        if self.max_num_seqs is None:
            self.max_num_seqs = default_max_num_seqs.get(
                usage_context,
                SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
            )

        # If throughput mode is set, double max_num_batched_tokens and max_num_seqs.
        if self.performance_mode == "throughput":
            if orig_max_num_batched_tokens is None:
                self.max_num_batched_tokens *= 2
            if orig_max_num_seqs is None:
                self.max_num_seqs *= 2

        if orig_max_num_batched_tokens is None:
            assert model_config.max_model_len is not None, (
                "max_model_len must be set by this point"
            )
            if not self.enable_chunked_prefill:
                # If max_model_len is too short, use the default for higher throughput.
                self.max_num_batched_tokens = max(
                    model_config.max_model_len,
                    self.max_num_batched_tokens,
                )

            # When using default settings,
            # Ensure max_num_batched_tokens does not exceed model limit.
            # Some models (e.g., Whisper) have embeddings tied to max length.
            self.max_num_batched_tokens = min(
                self.max_num_seqs * model_config.max_model_len,
                self.max_num_batched_tokens,
            )

            logger.debug(
                "Defaulting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                usage_context.value if usage_context else None,
            )

        if orig_max_num_seqs is None:
            assert self.max_num_batched_tokens is not None  # For type checking
            self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

            logger.debug(
                "Defaulting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                usage_context.value if usage_context else None,
            )
2350

2351

2352
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""

    # Whether request information is logged; exposed on the CLI as
    # --enable-log-requests / --no-enable-log-requests.
    enable_log_requests: bool = False

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
        """Register the async engine CLI arguments on `parser`.

        Args:
            parser: Parser to extend.
            async_args_only: When true, the `EngineArgs` arguments are not
                added and only the async-specific ones are registered.

        Returns:
            The parser with the arguments added.
        """
        # Initialize plugin to update the parser, for example, the plugin may
        # add a new kind of quantization method to --quantization argument or
        # a new device to --device argument.
        load_general_plugins()
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument(
            "--enable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=AsyncEngineArgs.enable_log_requests,
            help="Enable logging request information, dependent on log level:\n"
            "- INFO: Request ID, parameters and LoRA request.\n"
            "- DEBUG: Prompt inputs (e.g: text, token IDs).\n"
            "You can set the minimum log level via `VLLM_LOGGING_LEVEL`.",
        )
        current_platform.pre_register_and_update(parser)
        return parser
2379
2380


2381
2382
2383
2384
2385
2386
def _raise_unsupported_error(feature_name: str):
    """Raise `NotImplementedError` stating that `feature_name` is unsupported.

    Args:
        feature_name: Human-readable name of the feature, inserted into the
            error message.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        f"{feature_name} is not supported. We recommend to "
        f"remove {feature_name} from your config."
    )
2387
2388


2389
def human_readable_int(value: str) -> int:
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.

    Lowercase suffixes (k, m, g, t) are powers of 1000 and accept a decimal
    number; uppercase suffixes (K, M, G, T) are powers of 1024 and accept
    whole numbers only. A value without a suffix is parsed with `int`.
    Results with a remaining fractional part are truncated toward zero.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '1.005k' -> 1,005

    Raises:
        argparse.ArgumentTypeError: If a decimal number is combined with a
            binary (uppercase) suffix.
        ValueError: If the value has no recognised suffix and is not a
            plain integer.
    """
    # Imported locally: only this parser needs exact rational arithmetic.
    from fractions import Fraction

    value = value.strip()

    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
    if match:
        decimal_multiplier = {
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
            "t": 10**12,
        }
        binary_multiplier = {
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
            "T": 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            # Multiply exactly: with binary floats 1.005 * 1000 evaluates to
            # 1004.9999999999999, which int() would truncate to 1004.
            return int(Fraction(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?"
                ) from e

    # Regular plain number.
    return int(value)
2433
2434
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
2449
2450
2451


def human_readable_int_or_auto(value: str) -> int:
    """Parse a human-readable integer, treating '-1' and 'auto' specially.

    Surrounding whitespace is ignored. The literal '-1' and any casing of
    'auto' yield -1, the marker for auto-detection; every other input is
    delegated to `human_readable_int`.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '-1' or 'auto' -> -1 (special value for auto-detection)
    """
    text = value.strip()
    wants_auto = text == "-1" or text.lower() == "auto"
    return -1 if wants_auto else human_readable_int(text)