arg_utils.py 91.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import copy
6
import dataclasses
7
import functools
8
import json
9
import sys
10
from collections.abc import Callable
11
from dataclasses import MISSING, dataclass, fields, is_dataclass
12
from itertools import permutations
13
from types import UnionType
14
15
16
17
18
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Literal,
19
    TypeAlias,
20
21
22
23
24
25
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
26

27
import huggingface_hub
28
import regex as re
29
import torch
30
from pydantic import TypeAdapter, ValidationError
31
from pydantic.fields import FieldInfo
32
from typing_extensions import TypeIs
33

34
import vllm.envs as envs
35
from vllm.config import (
36
    AttentionConfig,
37
38
39
40
    CacheConfig,
    CompilationConfig,
    ConfigType,
    DeviceConfig,
41
    ECTransferConfig,
42
    EPLBConfig,
43
    KernelConfig,
44
45
46
47
48
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LoRAConfig,
    ModelConfig,
49
    MultiModalConfig,
50
    ObservabilityConfig,
51
    OffloadConfig,
52
53
    ParallelConfig,
    PoolerConfig,
54
    PrefetchOffloadConfig,
55
    ProfilerConfig,
56
57
58
    SchedulerConfig,
    SpeculativeConfig,
    StructuredOutputsConfig,
59
    UVAOffloadConfig,
60
    VllmConfig,
61
    WeightTransferConfig,
62
63
    get_attr_docs,
)
64
from vllm.config.cache import (
65
    BlockSize,
66
67
    CacheDType,
    KVOffloadingBackend,
68
    MambaCacheMode,
69
70
71
    MambaDType,
    PrefixCachingHashAlgo,
)
72
from vllm.config.device import Device
73
from vllm.config.lora import MaxLoRARanks
74
75
76
77
78
79
from vllm.config.model import (
    ConvertOption,
    HfOverrides,
    LogprobsMode,
    ModelDType,
    RunnerOption,
80
    TokenizerMode,
81
82
83
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
from vllm.config.observability import DetailedTraceModules
84
85
86
87
88
89
from vllm.config.parallel import (
    All2AllBackend,
    DataParallelBackend,
    DistributedExecutorBackend,
    ExpertPlacementStrategy,
)
90
from vllm.config.scheduler import SchedulerPolicy
91
from vllm.config.utils import get_field
92
from vllm.config.vllm import OptimizationLevel
93
from vllm.logger import init_logger, suppress_logging
94
from vllm.platforms import CpuArchEnum, current_platform
95
from vllm.plugins import load_general_plugins
96
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
97
98
99
100
from vllm.transformers_utils.config import (
    is_interleaved,
    maybe_override_with_speculators,
)
101
from vllm.transformers_utils.gguf_utils import is_gguf
102
from vllm.transformers_utils.repo_utils import get_model_path
103
from vllm.transformers_utils.utils import is_cloud_storage
104
from vllm.utils.argparse_utils import FlexibleArgumentParser
105
from vllm.utils.mem_constants import GiB_bytes
106
from vllm.utils.network_utils import get_ip
107
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
108
from vllm.v1.attention.backends.registry import AttentionBackendEnum
109
from vllm.v1.sample.logits_processor import LogitsProcessor
110

111
112
if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
113
    from vllm.model_executor.model_loader import LoadFormats
114
    from vllm.usage.usage_lib import UsageContext
115
    from vllm.v1.executor import Executor
116
else:
117
    Executor = Any
118
    QuantizationMethods = Any
119
    LoadFormats = Any
120
121
    UsageContext = Any

122

123
124
# Module-level logger for this file, created through vLLM's `init_logger`.
logger = init_logger(__name__)

125
126
# object is used to allow for special typing forms
T = TypeVar("T")
127
128
TypeHint: TypeAlias = type[Any] | object
TypeHintT: TypeAlias = type[T] | object
129

130

131
132
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    def _parse_type(val: str) -> T:
133
134
135
136
        try:
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
137
138
                f"Value {val} cannot be converted to {return_type}."
            ) from e
139

140
141
142
    return _parse_type


143
144
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], T | None]:
    def _optional_type(val: str) -> T | None:
145
146
147
148
        if val == "" or val == "None":
            return None
        return parse_type(return_type)(val)

149
    return _optional_type
150
151


152
def union_dict_and_str(val: str) -> str | dict[str, str] | None:
153
    if not re.match(r"(?s)^\s*{.*}\s*$", val):
154
        return str(val)
155
    return optional_type(json.loads)(val)
156
157


158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return whether ``type_hint`` is ``type`` or a parametrisation of it.

    A hint matches when it is ``type`` itself, or when ``typing.get_origin``
    of it is ``type`` (so ``list[int]`` matches ``list``).
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return whether any hint in ``type_hints`` matches ``type`` via `is_type`."""
    for type_hint in type_hints:
        if is_type(type_hint, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT | None:
    """Return a hint from ``type_hints`` that matches ``type`` via `is_type`.

    Returns ``None`` when no hint matches (the annotation reflects this).
    Because ``type_hints`` is a set, which matching hint is returned is
    unspecified if more than one matches.
    """
    return next((th for th in type_hints if is_type(th, type)), None)


173
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Build argparse ``type`` and ``choices`` from a ``Literal`` in ``type_hints``.

    The argparse ``type`` is the Python type of the literal's first option,
    and the options are returned sorted. If ``type_hints`` also contains
    ``str``, the options are given as ``metavar`` instead of ``choices``.

    Raises:
        ValueError: if the literal's options are not all of the same type.
    """
    options = get_args(get_type(type_hints, Literal))
    option_type = type(options[0])
    if any(not isinstance(option, option_type) for option in options):
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}"
        )

    key = "choices"
    if contains_type(type_hints, str):
        # str also allowed: show the options without restricting to them
        key = "metavar"
    return {"type": option_type, key: sorted(options)}
188
189


190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def collection_to_kwargs(type_hints: set[TypeHint], type: TypeHint) -> dict[str, Any]:
    """Build argparse ``type`` and ``nargs`` for a collection hint.

    The hint in ``type_hints`` matching ``type`` (e.g. ``tuple``, ``list``,
    ``set``) supplies the element type. A union element type must include
    ``str`` and is then parsed as ``str``. ``nargs`` is the tuple length for
    fixed-length tuples, and ``"+"`` for ``tuple[X, ...]`` and every other
    collection.
    """
    types = get_args(get_type(type_hints, type))
    elem_type = types[0]

    # Every element (ignoring a trailing Ellipsis) must share one type
    assert all(t is elem_type for t in types if t is not Ellipsis), (
        f"All non-Ellipsis elements must be of the same type. Got {types}."
    )

    # Union for Union[X, Y] and UnionType for X | Y
    if get_origin(elem_type) in {Union, UnionType}:
        assert str in get_args(elem_type), (
            "If element can have multiple types, one must be 'str' "
            f"(i.e. 'list[int | str]'). Got {elem_type}."
        )
        elem_type = str

    if type is tuple and Ellipsis not in types:
        nargs = len(types)
    else:
        nargs = "+"
    return {"type": elem_type, "nargs": nargs}


215
216
217
218
219
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True unless ``type_hint`` is defined in the ``builtins`` module."""
    module_name = type_hint.__module__
    return module_name != "builtins"


220
221
222
223
224
225
226
227
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Flatten ``Annotated`` and union hints into their underlying hints.

    ``Annotated[X, ...]`` is reduced to the hints of ``X``. Each member of a
    union is flattened recursively. Any other hint is returned as-is in a
    one-element set; e.g. ``int | None`` gives ``{int, NoneType}``.
    """
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        # Only the underlying type matters; the metadata is ignored
        return get_type_hints(args[0])

    if origin in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
        collected: set[TypeHint] = set()
        for arg in args:
            collected |= get_type_hints(arg)
        return collected

    return {type_hint}


238
# True when help text will be rendered, in one of these invocations:
#   vllm SUBCOMMAND --help | mkdocs SUBCOMMAND | python -m mkdocs SUBCOMMAND
NEEDS_HELP = any("--help" in arg for arg in sys.argv) or (
    argv0 := sys.argv[0]
).endswith(("mkdocs", "mkdocs/__main__.py"))


245
def _get_field_default(field: dataclasses.Field) -> Any:
    """Return the default value declared for the dataclass ``field``.

    Handles plain defaults, ``default_factory``, and ``pydantic.Field``
    defaults (whose own ``default_factory`` is called with logging
    suppressed). Returns ``None`` when the field declares no default at all.
    """
    if field.default is not MISSING:
        default = field.default
        # Handle pydantic.Field defaults
        if isinstance(default, FieldInfo):
            if default.default_factory is None:
                return default.default
            # VllmConfig's Fields have default_factory set to config classes.
            # These could emit logs on init, which would be confusing.
            with suppress_logging():
                return default.default_factory()  # type: ignore[call-arg]
        return default
    if field.default_factory is not MISSING:
        return field.default_factory()
    return None


@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Build argparse ``add_argument`` kwargs for every field of ``cls``.

    For each dataclass field this derives ``default`` and ``help`` and, from
    the field's type hints, ``type``/``action``/``choices``/``metavar``/
    ``nargs``. The result is memoised with ``lru_cache``; callers should use
    `get_kwargs`, which returns a deep copy, instead of mutating this one.

    Raises:
        ValueError: if a field's type hints cannot be mapped to an argparse
            argument type.
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}
    json_tip = (
        "Should either be a valid JSON string or JSON keys passed individually."
    )

    kwargs: dict[str, dict[str, Any]] = {}
    for field in fields(cls):
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, it is parsed from JSON with pydantic
        generator = (th for th in type_hints if is_dataclass(th))
        dataclass_cls = next(generator, None)

        # Resolved per field, so a field with no default gets None instead of
        # reusing the default computed for the previous field.
        default = _get_field_default(field)

        # Get the help text for the field, escaping % for argparse
        name = field.name
        help_text = cls_docs.get(name, "").strip().replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        kwargs[name] = {"default": default, "help": help_text}

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:

            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            kwargs[name]["type"] = parse_dataclass
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            kwargs[name]["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            kwargs[name].update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            kwargs[name].update(collection_to_kwargs(type_hints, tuple))
        elif contains_type(type_hints, list):
            kwargs[name].update(collection_to_kwargs(type_hints, list))
        elif contains_type(type_hints, set):
            kwargs[name].update(collection_to_kwargs(type_hints, set))
        elif contains_type(type_hints, int):
            if name == "max_model_len":
                kwargs[name]["type"] = human_readable_int_or_auto
                kwargs[name]["help"] += f"\n\n{human_readable_int_or_auto.__doc__}"
            elif name in ("max_num_batched_tokens", "kv_cache_memory_bytes"):
                kwargs[name]["type"] = human_readable_int
                kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
            else:
                kwargs[name]["type"] = int
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
            kwargs[name]["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            kwargs[name]["type"] = parse_type(json.loads)
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
            kwargs[name]["type"] = str
        else:
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(kwargs[name].get("type")) is Literal:
            kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            kwargs[name]["type"] = optional_type(kwargs[name]["type"])
            if kwargs[name].get("choices"):
                kwargs[name]["choices"].append("None")
    return kwargs
345
346


347
def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Return the argparse kwargs for every field of the Config dataclass ``cls``.

    Attribute documentation is only included in the ``help`` entries when
    ``--help`` or ``mkdocs`` appears on the command line.

    The expensive work happens in the ``lru_cache``-wrapped `_compute_kwargs`.
    This wrapper returns a deep copy, so callers may mutate the result freely.
    """
    cached_kwargs = _compute_kwargs(cls)
    # Never hand out the cached object itself
    return copy.deepcopy(cached_kwargs)


360
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
361
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
362
    """Arguments for vLLM engine."""
363

364
    model: str = ModelConfig.model
365
    enable_return_routed_experts: bool = ModelConfig.enable_return_routed_experts
366
    model_weights: str = ModelConfig.model_weights
367
    served_model_name: str | list[str] | None = ModelConfig.served_model_name
368
    tokenizer: str | None = ModelConfig.tokenizer
369
    hf_config_path: str | None = ModelConfig.hf_config_path
370
371
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
372
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
373
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
374
    tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
375
    trust_remote_code: bool = ModelConfig.trust_remote_code
376
377
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
    allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains
378
    download_dir: str | None = LoadConfig.download_dir
379
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
380
    load_format: str | LoadFormats = LoadConfig.load_format
381
382
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
383
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
384
    seed: int = ModelConfig.seed
385
    max_model_len: int = ModelConfig.max_model_len
386
387
388
389
390
391
    cudagraph_capture_sizes: list[int] | None = (
        CompilationConfig.cudagraph_capture_sizes
    )
    max_cudagraph_capture_size: int | None = get_field(
        CompilationConfig, "max_cudagraph_capture_size"
    )
392
393
394
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
395
    distributed_executor_backend: (
396
        str | DistributedExecutorBackend | type[Executor] | None
397
    ) = ParallelConfig.distributed_executor_backend
398
    # number of P/D disaggregation (or other disaggregation) workers
399
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
400
401
402
403
    master_addr: str = ParallelConfig.master_addr
    master_port: int = ParallelConfig.master_port
    nnodes: int = ParallelConfig.nnodes
    node_rank: int = ParallelConfig.node_rank
404
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
405
    prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
406
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
407
    dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
408
    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
409
    data_parallel_size: int = ParallelConfig.data_parallel_size
410
411
412
413
414
    data_parallel_rank: int | None = None
    data_parallel_start_rank: int | None = None
    data_parallel_size_local: int | None = None
    data_parallel_address: str | None = None
    data_parallel_rpc_port: int | None = None
415
    data_parallel_hybrid_lb: bool = False
416
    data_parallel_external_lb: bool = False
417
    data_parallel_backend: DataParallelBackend = ParallelConfig.data_parallel_backend
418
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
419
    all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
420
    enable_dbo: bool = ParallelConfig.enable_dbo
421
    ubatch_size: int = ParallelConfig.ubatch_size
422
423
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
424
    disable_nccl_for_dp_synchronization: bool | None = (
425
426
        ParallelConfig.disable_nccl_for_dp_synchronization
    )
427
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
428
    enable_eplb: bool = ParallelConfig.enable_eplb
429
    expert_placement_strategy: ExpertPlacementStrategy = (
430
        ParallelConfig.expert_placement_strategy
431
    )
432
433
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
434
    max_parallel_loading_workers: int | None = (
435
436
        ParallelConfig.max_parallel_loading_workers
    )
437
    block_size: BlockSize = CacheConfig.block_size
438
    enable_prefix_caching: bool | None = None
439
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
440
        CacheConfig.prefix_caching_hash_algo
441
    )
442
443
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
444
    swap_space: float = CacheConfig.swap_space
445
446
447
448
449
450
451
    offload_backend: str = OffloadConfig.offload_backend
    cpu_offload_gb: float = UVAOffloadConfig.cpu_offload_gb
    cpu_offload_params: set[str] = get_field(UVAOffloadConfig, "cpu_offload_params")
    offload_group_size: int = PrefetchOffloadConfig.offload_group_size
    offload_num_in_group: int = PrefetchOffloadConfig.offload_num_in_group
    offload_prefetch_step: int = PrefetchOffloadConfig.offload_prefetch_step
    offload_params: set[str] = get_field(PrefetchOffloadConfig, "offload_params")
452
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
453
    kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
454
    max_num_batched_tokens: int | None = None
455
456
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
457
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
458
    max_num_seqs: int | None = None
459
    max_logprobs: int = ModelConfig.max_logprobs
460
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
461
    disable_log_stats: bool = False
462
    aggregate_engine_logging: bool = False
463
464
465
    revision: str | None = ModelConfig.revision
    code_revision: str | None = ModelConfig.code_revision
    hf_token: bool | str | None = ModelConfig.hf_token
466
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
467
    tokenizer_revision: str | None = ModelConfig.tokenizer_revision
468
    quantization: QuantizationMethods | str | None = ModelConfig.quantization
469
    allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
470
    enforce_eager: bool = ModelConfig.enforce_eager
471
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
472
    language_model_only: bool = MultiModalConfig.language_model_only
473
    limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
474
475
        MultiModalConfig, "limit_per_prompt"
    )
476
    enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
477
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
478
479
480
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
481
    mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
482
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
483
    mm_processor_cache_type: MMCacheType | None = (
484
        MultiModalConfig.mm_processor_cache_type
485
486
    )
    mm_shm_cache_max_object_size_mb: int = (
487
        MultiModalConfig.mm_shm_cache_max_object_size_mb
488
    )
489
    mm_encoder_only: bool = MultiModalConfig.mm_encoder_only
490
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
491
    mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
492
493
        MultiModalConfig.mm_encoder_attn_backend
    )
494
    io_processor_plugin: str | None = None
495
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
496
    video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate
497
    # LoRA fields
498
    enable_lora: bool = False
499
    max_loras: int = LoRAConfig.max_loras
500
    max_lora_rank: MaxLoRARanks = LoRAConfig.max_lora_rank
501
    default_mm_loras: dict[str, str] | None = LoRAConfig.default_mm_loras
502
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
503
504
    max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
    lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
505
    enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
506
    specialize_active_lora: bool = LoRAConfig.specialize_active_lora
507

508
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
509
    num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
510
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
511
    ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
512

513
    enable_chunked_prefill: bool | None = None
514
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
515

516
    disable_hybrid_kv_cache_manager: bool | None = (
517
518
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
519

520
    structured_outputs_config: StructuredOutputsConfig = get_field(
521
522
        VllmConfig, "structured_outputs_config"
    )
523
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
524
    reasoning_parser_plugin: str | None = None
525

526
    speculative_config: dict[str, Any] | None = None
527

528
    show_hidden_metrics_for_version: str | None = (
529
        ObservabilityConfig.show_hidden_metrics_for_version
530
    )
531
532
    otlp_traces_endpoint: str | None = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: list[DetailedTraceModules] | None = (
533
        ObservabilityConfig.collect_detailed_traces
534
    )
535
536
537
538
    kv_cache_metrics: bool = ObservabilityConfig.kv_cache_metrics
    kv_cache_metrics_sample: float = get_field(
        ObservabilityConfig, "kv_cache_metrics_sample"
    )
539
    cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
540
541
542
    enable_layerwise_nvtx_tracing: bool = (
        ObservabilityConfig.enable_layerwise_nvtx_tracing
    )
543
    enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
544
545
546
    enable_logging_iteration_details: bool = (
        ObservabilityConfig.enable_logging_iteration_details
    )
547
    enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
548
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
549
    scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
550

551
    pooler_config: PoolerConfig | None = ModelConfig.pooler_config
552
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
553
    attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
554
555
556
557
    kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config")
    enable_flashinfer_autotune: bool = get_field(
        KernelConfig, "enable_flashinfer_autotune"
    )
558
559
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
560

561
562
    profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")

563
564
    kv_transfer_config: KVTransferConfig | None = None
    kv_events_config: KVEventsConfig | None = None
565

566
567
    ec_transfer_config: ECTransferConfig | None = None

568
569
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
570
571
572
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
573
    model_impl: str = ModelConfig.model_impl
574
    override_attention_dtype: str | None = ModelConfig.override_attention_dtype
575
    attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
576

577
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
578
579
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
580
    mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
581
    mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
582

583
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
584

585
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
586
    pt_load_map_location: str | dict[str, str] = LoadConfig.pt_load_map_location
587

588
    logits_processors: list[str | type[LogitsProcessor]] | None = (
589
590
        ModelConfig.logits_processors
    )
591
592
    """Custom logitproc types"""

593
    async_scheduling: bool | None = SchedulerConfig.async_scheduling
594

595
596
    stream_interval: int = SchedulerConfig.stream_interval

597
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
598
    optimization_level: OptimizationLevel = VllmConfig.optimization_level
599

600
    kv_offloading_size: float | None = CacheConfig.kv_offloading_size
601
    kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
602
    tokens_only: bool = False
603

604
605
606
607
    weight_transfer_config: WeightTransferConfig | None = get_field(
        VllmConfig,
        "weight_transfer_config",
    )
608

609
610
    fail_on_environ_validation: bool = False

611
    def __post_init__(self):
        """Normalise nested config fields and resolve offline model paths.

        Dict values given for ``compilation_config``, ``attention_config``,
        ``kernel_config``, ``eplb_config`` and ``weight_transfer_config`` are
        converted into their config classes, e.g. so that
        ``EngineArgs(compilation_config={...})`` works without constructing a
        ``CompilationConfig`` by hand. General plugins are loaded next. When
        ``huggingface_hub.constants.HF_HUB_OFFLINE`` is set, ``model`` (and
        ``tokenizer`` if given) are passed through ``get_model_path``. A
        message is logged whenever that changes the value.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(**self.compilation_config)
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)
        if isinstance(self.kernel_config, dict):
            self.kernel_config = KernelConfig(**self.kernel_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)
        if isinstance(self.weight_transfer_config, dict):
            self.weight_transfer_config = WeightTransferConfig(
                **self.weight_transfer_config
            )
        # Setup plugins
        from vllm.plugins import load_general_plugins

        load_general_plugins()
        # When HF offline mode is enabled, replace the model and tokenizer ids
        # with local model paths.
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
            # Compare by value: an identity check (`is not`) would also log
            # when get_model_path returns an equal but distinct str object.
            if model_id != self.model:
                logger.info(
                    "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
                    model_id,
                    self.model,
                )
            if self.tokenizer is not None:
                tokenizer_id = self.tokenizer
                self.tokenizer = get_model_path(self.tokenizer, self.tokenizer_revision)
                if tokenizer_id != self.tokenizer:
                    logger.info(
                        "HF_HUB_OFFLINE is True, replace tokenizer_id [%s] "
                        "to tokenizer_path [%s]",
                        tokenizer_id,
                        self.tokenizer,
                    )
651
652

    @staticmethod
653
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
654
        """Shared CLI arguments for vLLM engine."""
655

656
        # Model arguments
657
658
659
660
661
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
662
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
663
            model_group.add_argument("--model", **model_kwargs["model"])
664
665
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
666
667
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
668
669
670
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
671
672
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
673
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
674
675
676
677
678
679
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
680
        model_group.add_argument("--revision", **model_kwargs["revision"])
681
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
682
683
684
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
685
686
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
687
688
689
690
        model_group.add_argument(
            "--allow-deprecated-quantization",
            **model_kwargs["allow_deprecated_quantization"],
        )
691
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
692
693
694
695
        model_group.add_argument(
            "--enable-return-routed-experts",
            **model_kwargs["enable_return_routed_experts"],
        )
696
697
698
699
700
701
702
703
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
704
705
706
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
707
708
709
710
711
712
713
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
714
715
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
735
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
736
737
738
739
740
741
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
742
743
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
744
        )
745

746
747
748
749
750
751
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
752
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
753
754
755
756
757
758
759
760
761
762
763
764
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )
765

766
767
768
769
770
771
772
773
774
775
        # Attention arguments
        attention_kwargs = get_kwargs(AttentionConfig)
        attention_group = parser.add_argument_group(
            title="AttentionConfig",
            description=AttentionConfig.__doc__,
        )
        attention_group.add_argument(
            "--attention-backend", **attention_kwargs["backend"]
        )

776
777
778
779
780
        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
781
        )
782
        structured_outputs_group.add_argument(
783
            "--reasoning-parser",
784
            # Choices need to be validated after parsing to include plugins
785
786
            **structured_outputs_kwargs["reasoning_parser"],
        )
787
788
789
790
        structured_outputs_group.add_argument(
            "--reasoning-parser-plugin",
            **structured_outputs_kwargs["reasoning_parser_plugin"],
        )
791

792
        # Parallel arguments
793
794
795
796
797
798
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
799
            "--distributed-executor-backend",
800
801
            **parallel_kwargs["distributed_executor_backend"],
        )
802
        parallel_group.add_argument(
803
804
805
806
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
807
808
809
810
        parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"])
        parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
        parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
        parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
811
        parallel_group.add_argument(
812
813
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
814
        parallel_group.add_argument(
815
816
817
818
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
819
820
821
822
        parallel_group.add_argument(
            "--dcp-kv-cache-interleave-size",
            **parallel_kwargs["dcp_kv_cache_interleave_size"],
        )
823
824
825
826
827
828
829
830
831
        parallel_group.add_argument(
            "--cp-kv-cache-interleave-size",
            **parallel_kwargs["cp_kv_cache_interleave_size"],
        )
        parallel_group.add_argument(
            "--prefill-context-parallel-size",
            "-pcp",
            **parallel_kwargs["prefill_context_parallel_size"],
        )
832
833
834
835
836
837
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
        parallel_group.add_argument(
            "--data-parallel-rank",
            "-dpn",
838
            type=int,
839
840
841
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
842
        parallel_group.add_argument(
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
            "--data-parallel-start-rank",
            "-dpr",
            type=int,
            help="Starting data parallel rank for secondary nodes.",
        )
        parallel_group.add_argument(
            "--data-parallel-size-local",
            "-dpl",
            type=int,
            help="Number of data parallel replicas to run on this node.",
        )
        parallel_group.add_argument(
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
            type=int,
            help="Port for data parallel RPC communication.",
        )
        parallel_group.add_argument(
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
873
        parallel_group.add_argument(
874
875
876
877
878
879
880
881
            "--data-parallel-hybrid-lb",
            "-dph",
            **parallel_kwargs["data_parallel_hybrid_lb"],
        )
        parallel_group.add_argument(
            "--data-parallel-external-lb",
            "-dpe",
            **parallel_kwargs["data_parallel_external_lb"],
882
883
        )
        parallel_group.add_argument(
884
885
886
            "--enable-expert-parallel",
            "-ep",
            **parallel_kwargs["enable_expert_parallel"],
887
        )
888
889
890
        parallel_group.add_argument(
            "--all2all-backend", **parallel_kwargs["all2all_backend"]
        )
891
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
892
893
894
895
        parallel_group.add_argument(
            "--ubatch-size",
            **parallel_kwargs["ubatch_size"],
        )
896
897
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
898
899
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
900
901
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
902
903
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
904
905
906
907
        parallel_group.add_argument(
            "--disable-nccl-for-dp-synchronization",
            **parallel_kwargs["disable_nccl_for_dp_synchronization"],
        )
908
909
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
910
911
        parallel_group.add_argument(
            "--expert-placement-strategy",
912
913
            **parallel_kwargs["expert_placement_strategy"],
        )
914

915
        parallel_group.add_argument(
916
            "--max-parallel-loading-workers",
917
918
            **parallel_kwargs["max_parallel_loading_workers"],
        )
919
        parallel_group.add_argument(
920
921
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
922
        parallel_group.add_argument(
923
            "--disable-custom-all-reduce",
924
925
926
927
928
929
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )
930

931
932
933
934
935
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
936
        )
937
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
938
939
940
941
942
943
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
944
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
945
946
947
948
949
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        cache_group.add_argument(
950
951
952
953
954
            "--enable-prefix-caching",
            **{
                **cache_kwargs["enable_prefix_caching"],
                "default": None,
            },
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
971
972
973
        cache_group.add_argument(
            "--mamba-block-size", **cache_kwargs["mamba_block_size"]
        )
974
975
976
        cache_group.add_argument(
            "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
        )
977
978
979
980
981
982
        cache_group.add_argument(
            "--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
        )
        cache_group.add_argument(
            "--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
        )
983

984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
        # Model weight offload related configs
        offload_kwargs = get_kwargs(OffloadConfig)
        uva_kwargs = get_kwargs(UVAOffloadConfig)
        prefetch_kwargs = get_kwargs(PrefetchOffloadConfig)
        offload_group = parser.add_argument_group(
            title="OffloadConfig",
            description=OffloadConfig.__doc__,
        )
        offload_group.add_argument(
            "--offload-backend", **offload_kwargs["offload_backend"]
        )
        offload_group.add_argument("--cpu-offload-gb", **uva_kwargs["cpu_offload_gb"])
        offload_group.add_argument(
            "--cpu-offload-params", **uva_kwargs["cpu_offload_params"]
        )
        offload_group.add_argument(
            "--offload-group-size",
            **prefetch_kwargs["offload_group_size"],
        )
        offload_group.add_argument(
            "--offload-num-in-group",
            **prefetch_kwargs["offload_num_in_group"],
        )
        offload_group.add_argument(
            "--offload-prefetch-step",
            **prefetch_kwargs["offload_prefetch_step"],
        )
        offload_group.add_argument(
            "--offload-params", **prefetch_kwargs["offload_params"]
        )

1015
        # Multimodal related configs
1016
1017
1018
1019
1020
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
1021
1022
1023
        multimodal_group.add_argument(
            "--language-model-only", **multimodal_kwargs["language_model_only"]
        )
1024
        multimodal_group.add_argument(
1025
1026
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
1027
1028
1029
        multimodal_group.add_argument(
            "--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
        )
1030
1031
1032
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
1033
1034
1035
1036
1037
1038
        multimodal_group.add_argument(
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
1039
        multimodal_group.add_argument(
1040
1041
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
1042
1043
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
1044
1045
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
1046
1047
1048
        multimodal_group.add_argument(
            "--mm-encoder-only", **multimodal_kwargs["mm_encoder_only"]
        )
1049
        multimodal_group.add_argument(
1050
1051
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
1052
1053
1054
1055
        multimodal_group.add_argument(
            "--mm-encoder-attn-backend",
            **multimodal_kwargs["mm_encoder_attn_backend"],
        )
1056
1057
1058
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
1059
        multimodal_group.add_argument(
1060
1061
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
1062

1063
        multimodal_group.add_argument(
1064
1065
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )
1066

1067
        # LoRA related configs
1068
1069
1070
1071
1072
1073
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
1074
            "--enable-lora",
1075
            action=argparse.BooleanOptionalAction,
1076
1077
            help="If True, enable handling of LoRA adapters.",
        )
1078
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
1079
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
1080
        lora_group.add_argument(
1081
            "--lora-dtype",
1082
1083
            **lora_kwargs["lora_dtype"],
        )
1084
1085
1086
1087
        lora_group.add_argument(
            "--enable-tower-connector-lora",
            **lora_kwargs["enable_tower_connector_lora"],
        )
1088
1089
1090
1091
1092
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
1093
1094
1095
        lora_group.add_argument(
            "--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
        )
1096

1097
1098
1099
1100
1101
1102
1103
1104
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
1105
1106
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
1107
        observability_group.add_argument(
1108
1109
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
1110
1111
1112
1113
1114
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
1115
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
1116
1117
1118
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
1119
1120
            **observability_kwargs["collect_detailed_traces"],
        )
1121
1122
1123
1124
1125
1126
1127
        observability_group.add_argument(
            "--kv-cache-metrics", **observability_kwargs["kv_cache_metrics"]
        )
        observability_group.add_argument(
            "--kv-cache-metrics-sample",
            **observability_kwargs["kv_cache_metrics_sample"],
        )
1128
1129
1130
1131
        observability_group.add_argument(
            "--cudagraph-metrics",
            **observability_kwargs["cudagraph_metrics"],
        )
1132
1133
1134
1135
        observability_group.add_argument(
            "--enable-layerwise-nvtx-tracing",
            **observability_kwargs["enable_layerwise_nvtx_tracing"],
        )
1136
1137
1138
1139
        observability_group.add_argument(
            "--enable-mfu-metrics",
            **observability_kwargs["enable_mfu_metrics"],
        )
1140
1141
1142
1143
        observability_group.add_argument(
            "--enable-logging-iteration-details",
            **observability_kwargs["enable_logging_iteration_details"],
        )
1144

1145
1146
1147
1148
1149
1150
1151
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
1152
1153
1154
1155
1156
            "--max-num-batched-tokens",
            **{
                **scheduler_kwargs["max_num_batched_tokens"],
                "default": None,
            },
1157
        )
1158
        scheduler_group.add_argument(
1159
1160
1161
1162
1163
            "--max-num-seqs",
            **{
                **scheduler_kwargs["max_num_seqs"],
                "default": None,
            },
1164
1165
1166
1167
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
1168
1169
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
1170
1171
            **scheduler_kwargs["max_long_partial_prefills"],
        )
1172
1173
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
1174
1175
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
1176
1177
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
1178
        scheduler_group.add_argument(
1179
1180
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
1181
        scheduler_group.add_argument(
1182
1183
1184
1185
1186
            "--enable-chunked-prefill",
            **{
                **scheduler_kwargs["enable_chunked_prefill"],
                "default": None,
            },
1187
1188
1189
1190
1191
1192
1193
        )
        scheduler_group.add_argument(
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
1194
1195
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
1196
1197
1198
1199
1200
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
1201
1202
1203
        scheduler_group.add_argument(
            "--stream-interval", **scheduler_kwargs["stream_interval"]
        )
1204

1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
        # Compilation arguments
        compilation_kwargs = get_kwargs(CompilationConfig)
        compilation_group = parser.add_argument_group(
            title="CompilationConfig",
            description=CompilationConfig.__doc__,
        )
        compilation_group.add_argument(
            "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
        )
        compilation_group.add_argument(
            "--max-cudagraph-capture-size",
            **compilation_kwargs["max_cudagraph_capture_size"],
        )

1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
        # Kernel arguments
        kernel_kwargs = get_kwargs(KernelConfig)
        kernel_group = parser.add_argument_group(
            title="KernelConfig",
            description=KernelConfig.__doc__,
        )
        kernel_group.add_argument(
            "--enable-flashinfer-autotune",
            **kernel_kwargs["enable_flashinfer_autotune"],
        )

1230
        # vLLM arguments
1231
        vllm_kwargs = get_kwargs(VllmConfig)
1232
1233
1234
1235
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
1236
1237
1238
1239
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
1240
1241
1242
1243
1244
1245
1246
        vllm_group.add_argument(
            "--speculative-config", **vllm_kwargs["speculative_config"]
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
1247
1248
1249
        vllm_group.add_argument(
            "--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
        )
1250
        vllm_group.add_argument(
1251
            "--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
1252
        )
1253
1254
1255
        vllm_group.add_argument(
            "--attention-config", "-ac", **vllm_kwargs["attention_config"]
        )
1256
        vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"])
1257
1258
1259
1260
1261
1262
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
1263
        vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
1264
1265
1266
        vllm_group.add_argument(
            "--optimization-level", **vllm_kwargs["optimization_level"]
        )
1267
1268
1269
        vllm_group.add_argument(
            "--weight-transfer-config", **vllm_kwargs["weight_transfer_config"]
        )
1270

1271
        # Other arguments
1272
1273
1274
1275
1276
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
1277

1278
1279
1280
1281
1282
1283
        parser.add_argument(
            "--aggregate-engine-logging",
            action="store_true",
            help="Log aggregate rather than per-engine statistics "
            "when using data parallelism.",
        )
1284
1285
1286
1287
1288
1289
1290
1291

        parser.add_argument(
            "--fail-on-environ-validation",
            help="If set, the engine will raise an error if "
            "environment validation fails.",
            default=False,
            action=argparse.BooleanOptionalAction,
        )
1292
        return parser
1293
1294

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of this dataclass from parsed CLI arguments.

        Each dataclass field that is accepted by ``__init__`` and has a
        same-named attribute on ``args`` is forwarded to the constructor.
        Fields absent from ``args`` are left to the dataclass's own defaults,
        and attributes on ``args`` that are not fields are ignored.

        Args:
            args: The namespace produced by ``argparse`` parsing.

        Returns:
            A new instance of ``cls``.
        """
        init_kwargs = {
            f.name: getattr(args, f.name)
            for f in dataclasses.fields(cls)
            # ``init=False`` fields cannot be passed to ``__init__``;
            # forwarding one would raise ``TypeError``.
            if f.init and hasattr(args, f.name)
        }
        return cls(**init_kwargs)

    def create_model_config(self) -> ModelConfig:
        """Build the `ModelConfig` from this object's model-related arguments.

        Side effects:
            If ``is_gguf(self.model)`` is true, both ``self.quantization`` and
            ``self.load_format`` are overwritten with ``"gguf"`` before the
            config is built. ``create_load_config`` reads ``self.load_format``,
            so a later call to it also picks up the GGUF loader.

            If ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` is false, a warning is
            logged that the global random seed may affect the random state of
            the Python process that launched vLLM.

        Returns:
            A new `ModelConfig`. Every argument is forwarded one-to-one by
            keyword from the attribute of the same name on ``self``.
        """
        # GGUF files need a specific model loader, so force both the
        # quantization method and the load format to "gguf".
        if is_gguf(self.model):
            self.quantization = self.load_format = "gguf"

        # With V1 multiprocessing disabled, seeding may affect the RNG state of
        # the launching process (see the message text), so warn the user.
        if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.warning(
                "The global random seed is set to %d. Since "
                "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
                "affect the random state of the Python process that "
                "launched vLLM.",
                self.seed,
            )

        return ModelConfig(
            model=self.model,
            model_weights=self.model_weights,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            tokenizer=self.tokenizer,  # type: ignore[arg-type]
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            allow_deprecated_quantization=self.allow_deprecated_quantization,
            enforce_eager=self.enforce_eager,
            enable_return_routed_experts=self.enable_return_routed_experts,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            language_model_only=self.language_model_only,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            enable_mm_embeds=self.enable_mm_embeds,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_only=self.mm_encoder_only,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
            pooler_config=self.pooler_config,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            video_pruning_rate=self.video_pruning_rate,
            io_processor_plugin=self.io_processor_plugin,
        )

    def validate_tensorizer_args(self):
        """Copy tensorizer settings into the ``"tensorizer_config"`` sub-dict.

        Despite its name, this performs no validation. Every top-level key of
        ``self.model_loader_extra_config`` that appears in
        ``TensorizerConfig._fields`` is copied into
        ``self.model_loader_extra_config["tensorizer_config"]``.

        The ``"tensorizer_config"`` entry must already exist (as
        ``create_load_config`` ensures before calling this). Otherwise a
        ``KeyError`` is raised as soon as a matching key is found.
        """
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

        extra_config = self.model_loader_extra_config
        # Only the nested dict is written to, so iterating the outer mapping
        # while copying is safe.
        for key, value in extra_config.items():
            if key in TensorizerConfig._fields:
                extra_config["tensorizer_config"][key] = value

    def create_load_config(self) -> LoadConfig:
        """Build the `LoadConfig` from this object's loading arguments.

        Side effects (``self`` is mutated before the config is built):
            * When ``quantization == "bitsandbytes"``, ``load_format`` is
              forced to ``"bitsandbytes"``.
            * When ``load_format == "tensorizer"``:
                * ``model_loader_extra_config`` is first converted with
                  ``to_serializable()`` if it provides that method.
                * Its ``"tensorizer_config"`` entry is replaced by
                  ``{"tensorizer_dir": self.model}``.
                * Matching top-level keys are then copied into that entry by
                  ``validate_tensorizer_args``.

        Returns:
            A new `LoadConfig`.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            # The extra config may be an object rather than a plain dict.
            # Normalize it through its own serializer when one is available.
            if hasattr(self.model_loader_extra_config, "to_serializable"):
                self.model_loader_extra_config = (
                    self.model_loader_extra_config.to_serializable()
                )
            # Start from a fresh sub-config that records the model path as
            # `tensorizer_dir`; any previous entry is discarded.
            self.model_loader_extra_config["tensorizer_config"] = {
                "tensorizer_dir": self.model
            }
            self.validate_tensorizer_args()

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )

    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
    ) -> SpeculativeConfig | None:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        This function utilizes `speculative_config` to create a
        SpeculativeConfig object. The `speculative_config` can either be
        provided as a JSON string input via CLI arguments or directly as a
        dictionary from the engine.

        Args:
            target_model_config: Model config of the target (verifier) model.
            target_parallel_config: Parallel config of the target model.

        Returns:
            A ``SpeculativeConfig`` built from ``self.speculative_config`` plus
            the two target configs, or ``None`` when no speculative config was
            given.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # Merge them into a fresh dict rather than updating
        # ``self.speculative_config`` in place. In-place updates would leak the
        # config objects into the caller-supplied mapping (which may be the
        # user's own dict).
        speculative_kwargs = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
        }
        return SpeculativeConfig(**speculative_kwargs)
1433

1434
1435
    def create_engine_config(
        self,
        usage_context: UsageContext | None = None,
        headless: bool = False,
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        Builds every sub-config (device, model, cache, parallel, speculative,
        scheduler, LoRA, attention, kernel, load, observability, compilation,
        offload) from these engine arguments and bundles them into a
        ``VllmConfig``.

        NOTE: If VllmConfig is incompatible, we raise an error.

        Args:
            usage_context: Forwarded to
                ``_set_default_max_num_seqs_and_batched_tokens_args`` to pick
                batch-size defaults.
            headless: When True, hybrid data-parallel load balancing is
                rejected and is never inferred from ``data_parallel_start_rank``.

        Raises:
            ValueError: if default multimodal LoRAs are given for a
                non-multimodal model, if LoRA plus speculative decoding leaves
                ``max_num_batched_tokens`` too small, or if a top-level override
                conflicts with the same field in its nested config.
            AssertionError: on inconsistent data-parallel settings.

        Note:
            This method mutates ``self`` (e.g. ``model``, ``tokenizer``,
            ``model_weights``, ``data_parallel_rank``,
            ``data_parallel_size_local``, ``data_parallel_hybrid_lb``,
            ``quantization``, ``load_format``).
        """
        current_platform.pre_register_and_update()

        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))

        envs.validate_environ(self.fail_on_environ_validation)

        # Check if the model is a speculator and override model/tokenizer/config
        # BEFORE creating ModelConfig, so the config is created with the target model
        # Skip speculator detection for cloud storage models (eg: S3, GCS) since
        # HuggingFace cannot load configs directly from S3 URLs. S3 models can still
        # use speculators with explicit --speculative-config.
        if not is_cloud_storage(self.model):
            (self.model, self.tokenizer, self.speculative_config) = (
                maybe_override_with_speculators(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    revision=self.revision,
                    trust_remote_code=self.trust_remote_code,
                    vllm_speculative_config=self.speculative_config,
                )
            )

        model_config = self.create_model_config()
        self.model = model_config.model
        self.model_weights = model_config.model_weights
        self.tokenizer = model_config.tokenizer

        self._check_feature_supported()
        self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
        self._set_default_max_num_seqs_and_batched_tokens_args(
            usage_context, model_config
        )

        sliding_window: int | None = None
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

        # Resolve "auto" kv_cache_dtype to actual value from model config
        resolved_cache_dtype = resolve_kv_cache_dtype_string(
            self.kv_cache_dtype, model_config
        )

        assert self.enable_prefix_caching is not None, (
            "enable_prefix_caching must be set by this point"
        )

        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
            swap_space=self.swap_space,
            cache_dtype=resolved_cache_dtype,  # type: ignore[arg-type]
            is_attention_free=model_config.is_attention_free,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=sliding_window,
            enable_prefix_caching=self.enable_prefix_caching,
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
            calculate_kv_scales=self.calculate_kv_scales,
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
            mamba_block_size=self.mamba_block_size,
            mamba_cache_mode=self.mamba_cache_mode,
            kv_offloading_size=self.kv_offloading_size,
            kv_offloading_backend=self.kv_offloading_backend,
        )

        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray

            ray_runtime_env = ray.get_runtime_context().runtime_env
            # Avoid logging sensitive environment variables
            sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {}
            if "env_vars" in sanitized_env:
                sanitized_env["env_vars"] = {
                    k: "***" for k in sanitized_env["env_vars"]
                }
            logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env)

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

        assert not headless or not self.data_parallel_hybrid_lb, (
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
        assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
            "data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
        )
        assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
            "nnodes > 1 is only supported with data_parallel_backend=mp"
        )
        inferred_data_parallel_rank = 0
        if self.nnodes > 1:
            world_size = (
                self.data_parallel_size
                * self.pipeline_parallel_size
                * self.tensor_parallel_size
            )
            world_size_within_dp = (
                self.pipeline_parallel_size * self.tensor_parallel_size
            )
            local_world_size = world_size // self.nnodes
            assert world_size % self.nnodes == 0, (
                f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
            )
            assert self.node_rank < self.nnodes, (
                f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
            )
            # First DP rank hosted on this node, assuming ranks are laid out
            # contiguously across nodes.
            inferred_data_parallel_rank = (
                self.node_rank * local_world_size
            ) // world_size_within_dp
            if self.data_parallel_size > 1 and self.data_parallel_external_lb:
                self.data_parallel_rank = inferred_data_parallel_rank
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d for external lb",
                    self.data_parallel_rank,
                    self.node_rank,
                )
            elif self.data_parallel_size_local is None:
                # Infer data parallel size local for internal dplb:
                self.data_parallel_size_local = max(
                    local_world_size // world_size_within_dp, 1
                )
        data_parallel_external_lb = (
            self.data_parallel_external_lb or self.data_parallel_rank is not None
        )
        # Local DP rank = 1, use pure-external LB.
        if data_parallel_external_lb:
            assert self.data_parallel_rank is not None, (
                "data_parallel_rank or node_rank must be specified if "
                "data_parallel_external_lb is enable."
            )
            assert self.data_parallel_size_local in (1, None), (
                "data_parallel_size_local must be 1 or None when data_parallel_rank "
                "is set"
            )
            data_parallel_size_local = 1
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
                logger.warning(
                    "data_parallel_hybrid_lb is not eligible when "
                    "data_parallel_size_local = 1, autoswitch to "
                    "data_parallel_external_lb."
                )
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

            self.data_parallel_rank = (
                self.data_parallel_start_rank or inferred_data_parallel_rank
            )
            if self.nnodes > 1:
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d",
                    self.data_parallel_rank,
                    self.node_rank,
                )
        else:
            assert not self.data_parallel_hybrid_lb, (
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )

            if self.data_parallel_backend == "ray" and (
                envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
            ):
                # Data parallel size defaults to 1 if DP ranks are spanning
                # multiple nodes
                data_parallel_size_local = 1
            else:
                # Otherwise local DP size defaults to global DP size if not set
                data_parallel_size_local = self.data_parallel_size

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
                data_parallel_address = host_ip
            else:
                # The assert message must be a single formatted string; a
                # ("... %s", value) tuple would be shown verbatim, unformatted.
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got "
                    f"{self.data_parallel_backend}"
                )
                data_parallel_address = (
                    self.master_addr or ParallelConfig.data_parallel_master_ip
                )
        else:
            data_parallel_address = self.data_parallel_address

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
        data_parallel_rpc_port = (
            self.data_parallel_rpc_port
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )

        if self.tokens_only and not model_config.skip_tokenizer_init:
            model_config.skip_tokenizer_init = True
            logger.info("Skipping tokenizer initialization for tokens-only mode.")

        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            prefill_context_parallel_size=self.prefill_context_parallel_size,
            data_parallel_size=self.data_parallel_size,
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
            data_parallel_size_local=data_parallel_size_local,
            master_addr=self.master_addr,
            master_port=self.master_port,
            nnodes=self.nnodes,
            node_rank=self.node_rank,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
            data_parallel_backend=self.data_parallel_backend,
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
            is_moe_model=model_config.is_moe,
            enable_expert_parallel=self.enable_expert_parallel,
            all2all_backend=self.all2all_backend,
            enable_dbo=self.enable_dbo,
            ubatch_size=self.ubatch_size,
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
            enable_eplb=self.enable_eplb,
            eplb_config=self.eplb_config,
            expert_placement_strategy=self.expert_placement_strategy,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            ray_runtime_env=ray_runtime_env,
            placement_group=placement_group,
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
            worker_extension_cls=self.worker_extension_cls,
            decode_context_parallel_size=self.decode_context_parallel_size,
            dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
        )

        speculative_config = self.create_speculative_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
        )

        assert self.max_num_batched_tokens is not None, (
            "max_num_batched_tokens must be set by this point"
        )
        assert self.max_num_seqs is not None, "max_num_seqs must be set by this point"
        assert self.enable_chunked_prefill is not None, (
            "enable_chunked_prefill must be set by this point"
        )
        assert model_config.max_model_len is not None, (
            "max_model_len must be set by this point"
        )
        scheduler_config = SchedulerConfig(
            runner_type=model_config.runner_type,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_chunked_mm_input=self.disable_chunked_mm_input,
            is_multimodal_model=model_config.is_multimodal_model,
            is_encoder_decoder=model_config.is_encoder_decoder,
            policy=self.scheduling_policy,
            scheduler_cls=self.scheduler_cls,
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
            async_scheduling=self.async_scheduling,
            stream_interval=self.stream_interval,
        )

        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_dtype=self.lora_dtype,
                enable_tower_connector_lora=self.enable_tower_connector_lora,
                specialize_active_lora=self.specialize_active_lora,
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )

        # With LoRA and speculative decoding together, every sequence may need
        # (num_speculative_tokens + 1) token slots in a single step.
        if (
            lora_config is not None
            and speculative_config is not None
            and scheduler_config.max_num_batched_tokens
            < (
                scheduler_config.max_num_seqs
                * (speculative_config.num_speculative_tokens + 1)
            )
        ):
            raise ValueError(
                "Consider increasing max_num_batched_tokens or "
                "decreasing num_speculative_tokens"
            )

        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

        # Attention config overrides
        attention_config = copy.deepcopy(self.attention_config)
        if self.attention_backend is not None:
            if attention_config.backend is not None:
                raise ValueError(
                    "attention_backend and attention_config.backend "
                    "are mutually exclusive"
                )
            # Convert string to enum if needed (CLI parsing returns a string)
            if isinstance(self.attention_backend, str):
                attention_config.backend = AttentionBackendEnum[
                    self.attention_backend.upper()
                ]
            else:
                attention_config.backend = self.attention_backend

        # Kernel config overrides
        kernel_config = copy.deepcopy(self.kernel_config)
        if self.enable_flashinfer_autotune is not None:
            if kernel_config.enable_flashinfer_autotune is not None:
                raise ValueError(
                    "enable_flashinfer_autotune and "
                    "kernel_config.enable_flashinfer_autotune "
                    "are mutually exclusive"
                )
            kernel_config.enable_flashinfer_autotune = self.enable_flashinfer_autotune

        load_config = self.create_load_config()

        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser

        if self.reasoning_parser_plugin:
            self.structured_outputs_config.reasoning_parser_plugin = (
                self.reasoning_parser_plugin
            )

        observability_config = ObservabilityConfig(
            show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_detailed_traces=self.collect_detailed_traces,
            kv_cache_metrics=self.kv_cache_metrics,
            kv_cache_metrics_sample=self.kv_cache_metrics_sample,
            cudagraph_metrics=self.cudagraph_metrics,
            enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
            enable_mfu_metrics=self.enable_mfu_metrics,
            enable_mm_processor_stats=self.enable_mm_processor_stats,
            enable_logging_iteration_details=self.enable_logging_iteration_details,
        )

        # Compilation config overrides
        compilation_config = copy.deepcopy(self.compilation_config)
        if self.cudagraph_capture_sizes is not None:
            if compilation_config.cudagraph_capture_sizes is not None:
                raise ValueError(
                    "cudagraph_capture_sizes and compilation_config."
                    "cudagraph_capture_sizes are mutually exclusive"
                )
            compilation_config.cudagraph_capture_sizes = self.cudagraph_capture_sizes
        if self.max_cudagraph_capture_size is not None:
            if compilation_config.max_cudagraph_capture_size is not None:
                raise ValueError(
                    "max_cudagraph_capture_size and compilation_config."
                    "max_cudagraph_capture_size are mutually exclusive"
                )
            compilation_config.max_cudagraph_capture_size = (
                self.max_cudagraph_capture_size
            )

        offload_config = OffloadConfig(
            offload_backend=self.offload_backend,
            uva=UVAOffloadConfig(
                cpu_offload_gb=self.cpu_offload_gb,
                cpu_offload_params=self.cpu_offload_params,
            ),
            prefetch=PrefetchOffloadConfig(
                offload_group_size=self.offload_group_size,
                offload_num_in_group=self.offload_num_in_group,
                offload_prefetch_step=self.offload_prefetch_step,
                offload_params=self.offload_params,
            ),
        )

        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            load_config=load_config,
            offload_config=offload_config,
            attention_config=attention_config,
            kernel_config=kernel_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            structured_outputs_config=self.structured_outputs_config,
            observability_config=observability_config,
            compilation_config=compilation_config,
            kv_transfer_config=self.kv_transfer_config,
            kv_events_config=self.kv_events_config,
            ec_transfer_config=self.ec_transfer_config,
            profiler_config=self.profiler_config,
            additional_config=self.additional_config,
            optimization_level=self.optimization_level,
            weight_transfer_config=self.weight_transfer_config,
        )

        return config

1902
    def _check_feature_supported(self):
        """Reject engine-argument combinations that are not implemented.

        Raises (via ``_raise_unsupported_error``) when the partial-prefill
        limits differ from the ``SchedulerConfig`` defaults. It also raises when
        pipeline parallelism is requested with an executor backend that neither
        advertises ``supports_pp`` nor is one of the built-in backends.
        """
        # Concurrent partial prefills are not available yet, so any
        # non-default partial-prefill limit is an error.
        partial_prefill_customized = (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        )
        if partial_prefill_customized:
            _raise_unsupported_error(feature_name="Concurrent Partial Prefill")

        if self.pipeline_parallel_size <= 1:
            return

        backend = self.distributed_executor_backend
        backend_handles_pp = getattr(
            backend, "supports_pp", False
        ) or backend in (
            ParallelConfig.distributed_executor_backend,
            "ray",
            "mp",
            "external_launcher",
        )
        if backend_handles_pp:
            return

        _raise_unsupported_error(
            feature_name=(
                "Pipeline Parallelism without Ray distributed "
                "executor or multiprocessing executor or external "
                "launcher"
            )
        )
1928

1929
1930
1931
1932
1933
1934
    @classmethod
    def get_batch_defaults(
        cls,
        world_size: int,
    ) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]:
        """Return platform-specific defaults for batch sizing.

        Returns a pair ``(max_num_batched_tokens, max_num_seqs)`` of dicts,
        each keyed by ``UsageContext.LLM_CLASS`` and
        ``UsageContext.OPENAI_API_SERVER``. ``world_size`` only matters on CPU,
        where both defaults scale linearly with it.
        """
        from vllm.usage.usage_lib import UsageContext

        # Querying the device may fail when the process importing vLLM is not
        # the one running it (e.g. a GPU-less Ray driver). In that case fall
        # back to the values used for non-H100/H200 GPUs.
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # Only consulted for the token/seq defaults below.
            device_memory, device_name = 0, ""

        # NOTE(Kuntai): a large `max_num_batched_tokens` reduces throughput on
        # A100 (see PR #17885), so A100 is excluded from the large-GPU tier
        # regardless of its memory size.
        large_gpu = device_memory >= 70 * GiB_bytes and "a100" not in device_name

        # Each pair is (LLM_CLASS value, OPENAI_API_SERVER value).
        if large_gpu:
            # GPUs such as H100 and MI300x get larger defaults.
            token_pair = (16384, 8192)
            seq_pair = (1024, 1024)
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            token_pair = (8192, 2048)
            seq_pair = (256, 256)

        if current_platform.is_tpu():
            # Known TPU generations override only the batched-token budget.
            tpu_token_pairs = {
                "V6E": (2048, 1024),
                "V5E": (1024, 512),
                "V5P": (512, 256),
            }
            token_pair = tpu_token_pairs.get(
                current_platform.get_device_name(), token_pair
            )

        if current_platform.is_cpu():
            token_pair = (4096 * world_size, 2048 * world_size)
            seq_pair = (256 * world_size, 128 * world_size)

        default_max_num_batched_tokens: dict[UsageContext | None, int] = dict(
            zip((UsageContext.LLM_CLASS, UsageContext.OPENAI_API_SERVER), token_pair)
        )
        default_max_num_seqs: dict[UsageContext | None, int] = dict(
            zip((UsageContext.LLM_CLASS, UsageContext.OPENAI_API_SERVER), seq_pair)
        )
        return default_max_num_batched_tokens, default_max_num_seqs

2013
2014
    def _set_default_chunked_prefill_and_prefix_caching_args(
        self, model_config: ModelConfig
    ) -> None:
        """Resolve ``enable_chunked_prefill`` and ``enable_prefix_caching``.

        A flag left as ``None`` takes the model's support
        (``model_config.is_chunked_prefill_supported`` /
        ``is_prefix_caching_supported``) as its value. An explicit user choice
        that goes against that support triggers a one-time local warning.
        On POWER and RISC-V CPUs, both features are then forced off.
        """
        chunked_supported = model_config.is_chunked_prefill_supported
        prefix_supported = model_config.is_prefix_caching_supported

        if self.enable_chunked_prefill is None:
            self.enable_chunked_prefill = chunked_supported
            logger.debug(
                "%s chunked prefill by default",
                "Enabling" if chunked_supported else "Disabling",
            )
        elif model_config.runner_type == "generate":
            # Generative models that expect chunked prefill but had it
            # switched off by the user.
            if chunked_supported and not self.enable_chunked_prefill:
                logger.warning_once(
                    "This model does not officially support disabling chunked prefill. "
                    "Disabling this manually may cause the engine to crash "
                    "or produce incorrect outputs.",
                    scope="local",
                )
        elif model_config.runner_type == "pooling":
            # Pooling models that do not support chunked prefill but had it
            # switched on by the user.
            if self.enable_chunked_prefill and not chunked_supported:
                logger.warning_once(
                    "This model does not officially support chunked prefill. "
                    "Enabling this manually may cause the engine to crash "
                    "or produce incorrect outputs.",
                    scope="local",
                )

        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = prefix_supported
            logger.debug(
                "%s prefix caching by default",
                "Enabling" if prefix_supported else "Disabling",
            )
        elif model_config.runner_type == "pooling":
            if self.enable_prefix_caching and not prefix_supported:
                logger.warning_once(
                    "This model does not officially support prefix caching. "
                    "Enabling this manually may cause the engine to crash "
                    "or produce incorrect outputs.",
                    scope="local",
                )

        # Neither feature is supported on POWER (ppc64le) or RISC-V CPUs in
        # V1, so both are turned off there regardless of the values above.
        if not current_platform.is_cpu():
            return
        if current_platform.get_cpu_architecture() not in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.RISCV,
        ):
            return

        logger.info(
            "Chunked prefill is not supported for POWER, "
            "and RISC-V CPUs; "
            "disabling it for V1 backend."
        )
        self.enable_chunked_prefill = False
        logger.info(
            "Prefix caching is not supported for POWER, "
            "and RISC-V CPUs; "
            "disabling it for V1 backend."
        )
        self.enable_prefix_caching = False

    def _set_default_max_num_seqs_and_batched_tokens_args(
        self,
        usage_context: UsageContext | None,
        model_config: ModelConfig,
    ):
        """Fill in ``max_num_batched_tokens`` and ``max_num_seqs`` when unset.

        Defaults come from ``self.get_batch_defaults`` for a world size of
        ``pipeline_parallel_size * tensor_parallel_size``, looked up by
        ``usage_context``. Contexts without an entry fall back to
        ``SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS`` /
        ``SchedulerConfig.DEFAULT_MAX_NUM_SEQS``. Values the user set
        explicitly (non-``None``) are left untouched.

        A defaulted token budget is then adjusted:

        * raised to at least ``model_config.max_model_len`` when
          ``enable_chunked_prefill`` is falsy (``None`` counts as disabled);
        * capped at ``max_num_seqs * max_model_len``.

        A defaulted ``max_num_seqs`` is finally capped at the token budget.
        ``model_config.max_model_len`` must already be set (asserted).
        """
        world_size = self.pipeline_parallel_size * self.tensor_parallel_size
        token_defaults, seq_defaults = self.get_batch_defaults(world_size)

        # Remember which values came from the user before filling defaults in.
        tokens_were_unset = self.max_num_batched_tokens is None
        seqs_were_unset = self.max_num_seqs is None

        if tokens_were_unset:
            self.max_num_batched_tokens = token_defaults.get(
                usage_context, SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS
            )
        if seqs_were_unset:
            self.max_num_seqs = seq_defaults.get(
                usage_context, SchedulerConfig.DEFAULT_MAX_NUM_SEQS
            )

        # Both defaults must be in place first: the token cap below reads
        # max_num_seqs, and the seq cap reads max_num_batched_tokens.
        if tokens_were_unset:
            max_len = model_config.max_model_len
            assert max_len is not None, "max_model_len must be set by this point"

            budget = self.max_num_batched_tokens
            if not self.enable_chunked_prefill:
                # Without chunked prefill the budget is raised to at least
                # max_model_len (presumably so a full-length prompt fits in one
                # step); a larger default is kept for throughput.
                budget = max(max_len, budget)
            # Some models (e.g. Whisper) have embeddings tied to the max
            # length, so a defaulted budget never exceeds seqs * max_model_len.
            self.max_num_batched_tokens = min(self.max_num_seqs * max_len, budget)

            logger.debug(
                "Defaulting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                usage_context.value if usage_context else None,
            )

        if seqs_were_unset:
            assert self.max_num_batched_tokens is not None  # For type checking
            self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

            logger.debug(
                "Defaulting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                usage_context.value if usage_context else None,
            )
2147

2148

2149
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments for the asynchronous vLLM engine.

    Extends ``EngineArgs`` with options that only the async engine uses.
    """

    # Whether incoming requests are logged (CLI: --enable-log-requests).
    enable_log_requests: bool = False

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
        """Register the async-engine CLI options on ``parser``.

        Args:
            parser: Parser to extend.
            async_args_only: When falsy, the shared ``EngineArgs`` options are
                registered first; when truthy, only the async-specific flags
                are added.

        Returns:
            The extended parser (the one returned by
            ``EngineArgs.add_cli_args`` when the shared options are added).
        """
        # Plugins go first: a plugin may extend the parser, e.g. by adding a
        # new quantization method to --quantization or a device to --device.
        load_general_plugins()

        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        log_requests_default = AsyncEngineArgs.enable_log_requests
        parser.add_argument(
            "--enable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=log_requests_default,
            help="Enable logging requests.",
        )
        # Deprecated inverse spelling, still accepted; deprecated=True is
        # forwarded to the parser.
        parser.add_argument(
            "--disable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=not log_requests_default,
            help="[DEPRECATED] Disable logging requests.",
            deprecated=True,
        )

        # Hand the fully-populated parser to the platform hook.
        current_platform.pre_register_and_update(parser)
        return parser
2180
2181


2182
2183
2184
2185
2186
2187
def _raise_unsupported_error(feature_name: str):
    """Raise ``NotImplementedError`` for an unsupported feature.

    The message names ``feature_name`` and asks the user to drop it from
    their config. This function never returns normally.
    """
    raise NotImplementedError(
        f"{feature_name} is not supported. We recommend to "
        f"remove {feature_name} from your config."
    )
2188
2189


2190
def human_readable_int(value: str) -> int:
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.

    Lowercase suffixes (``k``, ``m``, ``g``, ``t``) are decimal multipliers
    (powers of 1000) and accept a fractional mantissa. Uppercase suffixes
    (``K``, ``M``, ``G``, ``T``) are binary multipliers (powers of 1024) and
    require an integer mantissa. Leading/trailing whitespace is ignored, and a
    value without a suffix is parsed with ``int()``. A fractional result is
    truncated (e.g. '1.0005k' -> 1,000).

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '1.001m' -> 1,001,000

    Raises:
        argparse.ArgumentTypeError: a fractional mantissa is combined with a
            binary suffix (e.g. '1.5K').
        ValueError: the value is neither a suffixed number nor something
            ``int()`` accepts.
    """
    from fractions import Fraction

    value = value.strip()

    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
    if match:
        decimal_multiplier = {
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
            "t": 10**12,
        }
        binary_multiplier = {
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
            "T": 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            # Exact rational arithmetic: going through float can land just
            # below the true product (float("1.001") * 10**6 is slightly less
            # than 1001000), which int() would truncate to the wrong value;
            # large mantissas would also lose precision as floats.
            return int(Fraction(number) * decimal_multiplier[suffix])
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                whole = int(number)
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?"
                ) from e
            return whole * mult

    # Regular plain number.
    return int(value)
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252


def human_readable_int_or_auto(value: str) -> int:
    """Like ``human_readable_int``, but also accept an auto-detect sentinel.

    After stripping surrounding whitespace, ``'-1'`` or ``'auto'`` (any case)
    returns ``-1``, the special value for auto-detection. Anything else is
    delegated to ``human_readable_int``.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '-1' or 'auto' -> -1 (special value for auto-detection)
    """
    token = value.strip()
    wants_auto = token == "-1" or token.lower() == "auto"
    return -1 if wants_auto else human_readable_int(token)