arg_utils.py 93.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import copy
6
import dataclasses
7
import functools
8
import json
9
import sys
10
from collections.abc import Callable
11
from dataclasses import MISSING, dataclass, fields, is_dataclass
12
from itertools import permutations
13
from types import UnionType
14
15
16
17
18
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Literal,
19
    TypeAlias,
20
21
22
23
24
25
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
26

27
import huggingface_hub
28
import regex as re
29
import torch
30
from pydantic import TypeAdapter, ValidationError
31
from pydantic.fields import FieldInfo
32
from typing_extensions import TypeIs
33

34
import vllm.envs as envs
35
from vllm.config import (
36
    AttentionConfig,
37
38
39
40
    CacheConfig,
    CompilationConfig,
    ConfigType,
    DeviceConfig,
41
    ECTransferConfig,
42
    EPLBConfig,
43
    KernelConfig,
44
45
46
47
48
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LoRAConfig,
    ModelConfig,
49
    MultiModalConfig,
50
    ObservabilityConfig,
51
    OffloadConfig,
52
53
    ParallelConfig,
    PoolerConfig,
54
    PrefetchOffloadConfig,
55
    ProfilerConfig,
56
57
58
    SchedulerConfig,
    SpeculativeConfig,
    StructuredOutputsConfig,
59
    UVAOffloadConfig,
60
    VllmConfig,
61
    WeightTransferConfig,
62
63
    get_attr_docs,
)
64
from vllm.config.cache import (
65
    BlockSize,
66
67
    CacheDType,
    KVOffloadingBackend,
68
    MambaCacheMode,
69
70
71
    MambaDType,
    PrefixCachingHashAlgo,
)
72
from vllm.config.device import Device
73
from vllm.config.kernel import MoEBackend
74
from vllm.config.lora import MaxLoRARanks
75
76
77
78
79
80
from vllm.config.model import (
    ConvertOption,
    HfOverrides,
    LogprobsMode,
    ModelDType,
    RunnerOption,
81
    TokenizerMode,
82
83
84
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
from vllm.config.observability import DetailedTraceModules
85
86
87
from vllm.config.parallel import (
    All2AllBackend,
    DataParallelBackend,
88
    DCPCommBackend,
89
90
91
    DistributedExecutorBackend,
    ExpertPlacementStrategy,
)
92
from vllm.config.scheduler import SchedulerPolicy
93
from vllm.config.utils import get_field
94
from vllm.config.vllm import OptimizationLevel, PerformanceMode
95
from vllm.logger import init_logger, suppress_logging
96
from vllm.platforms import CpuArchEnum, current_platform
97
from vllm.plugins import load_general_plugins
98
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
99
100
101
102
from vllm.transformers_utils.config import (
    is_interleaved,
    maybe_override_with_speculators,
)
103
from vllm.transformers_utils.gguf_utils import is_gguf
104
from vllm.transformers_utils.repo_utils import get_model_path
105
from vllm.transformers_utils.utils import is_cloud_storage
106
from vllm.utils.argparse_utils import FlexibleArgumentParser
107
from vllm.utils.mem_constants import GiB_bytes
108
from vllm.utils.network_utils import get_ip
109
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
110
from vllm.v1.attention.backends.registry import AttentionBackendEnum
111
from vllm.v1.sample.logits_processor import LogitsProcessor
112

113
114
if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
115
    from vllm.model_executor.model_loader import LoadFormats
116
    from vllm.usage.usage_lib import UsageContext
117
    from vllm.v1.executor import Executor
118
else:
119
    Executor = Any
120
    QuantizationMethods = Any
121
    LoadFormats = Any
122
123
    UsageContext = Any

124

125
126
# Module-level logger for this file, obtained from vLLM's `init_logger` helper.
logger = init_logger(__name__)

127
128
# object is used to allow for special typing forms
T = TypeVar("T")
129
130
TypeHint: TypeAlias = type[Any] | object
TypeHintT: TypeAlias = type[T] | object
131

132

133
134
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Adapt `return_type` into a converter suitable for argparse's `type=`.

    The returned callable applies `return_type` to the raw command-line string.
    Any `ValueError` it raises is re-raised as `argparse.ArgumentTypeError`,
    chained to the original exception, so argparse can report it.
    """

    def _parse_type(val: str) -> T:
        try:
            result = return_type(val)
        except ValueError as exc:
            message = f"Value {val} cannot be converted to {return_type}."
            raise argparse.ArgumentTypeError(message) from exc
        return result

    return _parse_type


145
146
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], T | None]:
    """Like `parse_type`, but the CLI strings "" and "None" become `None`.

    Any other string is converted with `parse_type(return_type)`.
    """
    none_tokens = ("", "None")

    def _optional_type(val: str) -> T | None:
        if val in none_tokens:
            return None
        converter = parse_type(return_type)
        return converter(val)

    return _optional_type
152
153


154
def union_dict_and_str(val: str) -> str | dict[str, str] | None:
    """Decode `val` as JSON if it looks like an object, otherwise keep it as text.

    `val` counts as a JSON object when, ignoring surrounding whitespace, it
    starts with `{` and ends with `}`. It is then decoded via
    `optional_type(json.loads)`. Any other value is returned as a string.
    """
    if re.match(r"(?s)^\s*{.*}\s*$", val):
        return optional_type(json.loads)(val)
    return str(val)
158
159


160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return whether `type_hint` is `type` itself or a parametrisation of it.

    For example, both `list` and `list[int]` match `list`, and
    `Literal["a", "b"]` matches `Literal`.
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return whether any hint in `type_hints` matches `type` per `is_type`."""
    for hint in type_hints:
        if is_type(hint, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return the first hint in `type_hints` that matches `type`, else `None`.

    Matching follows `is_type`. "First" means first in the iteration order of
    `type_hints`.
    """
    for hint in type_hints:
        if is_type(hint, type):
            return hint
    return None


175
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Build argparse `type` plus `choices`/`metavar` from a `Literal` hint.

    The `Literal` hint is looked up in `type_hints` with `get_type`. Its
    options must all be instances of the type of the first option, which
    becomes the argparse `type`. The sorted options go under `choices`, or
    under `metavar` when `type_hints` also contains `str` (so arbitrary
    strings stay accepted).

    Raises:
        ValueError: if the literal options are not all of the same type.
    """
    options = get_args(get_type(type_hints, Literal))
    first_type = type(options[0])
    mismatched = [opt for opt in options if not isinstance(opt, first_type)]
    if mismatched:
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}"
        )
    if contains_type(type_hints, str):
        return {"type": first_type, "metavar": sorted(options)}
    return {"type": first_type, "choices": sorted(options)}
190
191


192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
def collection_to_kwargs(type_hints: set[TypeHint], type: TypeHint) -> dict[str, Any]:
    """Get the argparse `type` and `nargs` for a collection type hint.

    The hint matching `type` (`list`, `set` or `tuple`) is looked up in
    `type_hints` with `get_type`.

    - Fixed-length tuples (e.g. `tuple[int, int]`) get `nargs` equal to their
      length.
    - Everything else (lists, sets, `tuple[int, ...]`) gets `nargs="+"`.
    - A union element type must include `str`, and is then parsed as `str`.

    Raises:
        ValueError: if the collection hint has no type arguments, its
            non-Ellipsis element types differ, or its union element type does
            not include `str`.
    """
    type_hint = get_type(type_hints, type)
    types = get_args(type_hint)
    # A bare `list`/`set`/`tuple` has no element type to parse with.
    if not types:
        raise ValueError(
            f"Collection type hint {type_hint} must specify its element type "
            "(i.e. 'list[int]')."
        )
    elem_type = types[0]

    # Handle Ellipsis (e.g. `tuple[int, ...]`). Raise instead of `assert` so
    # the check still runs under `python -O`.
    if not all(t is elem_type for t in types if t is not Ellipsis):
        raise ValueError(
            f"All non-Ellipsis elements must be of the same type. Got {types}."
        )

    # Handle Union types
    if get_origin(elem_type) in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
        if str not in get_args(elem_type):
            raise ValueError(
                "If element can have multiple types, one must be 'str' "
                f"(i.e. 'list[int | str]'). Got {elem_type}."
            )
        elem_type = str

    return {
        "type": elem_type,
        "nargs": "+" if type is not tuple or Ellipsis in types else len(types),
    }


217
218
219
220
221
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True when `type_hint` reports a module other than `builtins`.

    The check reads `type_hint.__module__` directly.
    """
    defining_module = type_hint.__module__
    return defining_module != "builtins"


222
223
224
225
226
227
228
229
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints."""
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
230
231
    elif origin in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
232
233
234
235
236
237
238
239
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


240
def is_help_requested(argv: list[str]) -> bool:
    """Return whether `argv` indicates that help text is being generated.

    True when any argument contains `--help` or equals `-h`
    (`vllm SUBCOMMAND --help` / `-h`), or when the program is mkdocs
    (`mkdocs SUBCOMMAND` or `python -m mkdocs SUBCOMMAND`). An empty `argv`
    is handled instead of raising `IndexError`.
    """
    if any("--help" in arg or arg == "-h" for arg in argv):
        return True
    if not argv:
        return False
    return argv[0].endswith(("mkdocs", "mkdocs/__main__.py"))


# Computed once at import time; used to skip the comparatively expensive
# attribute-doc extraction unless help output is being produced.
NEEDS_HELP = is_help_requested(sys.argv)


247
def _get_field_default(field: dataclasses.Field) -> Any:
    """Return the argparse default for the dataclass `field`.

    Resolves plain defaults, `default_factory` callables, and pydantic
    `Field(...)` (`FieldInfo`) defaults. A field with neither a default nor a
    default factory gets `None`, which is also argparse's own default for
    options that are not given.
    """
    if field.default is not MISSING:
        default = field.default
        # Handle pydantic.Field defaults
        if isinstance(default, FieldInfo):
            if default.default_factory is None:
                return default.default
            # VllmConfig's Fields have default_factory set to config classes.
            # These could emit logs on init, which would be confusing.
            with suppress_logging():
                return default.default_factory()  # type: ignore[call-arg]
        return default
    if field.default_factory is not MISSING:
        return field.default_factory()
    # Required field: return an explicit value so a field never inherits the
    # default computed for a previous field.
    return None


@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Build `add_argument` keyword arguments for every field of `cls`.

    Results are cached per `cls` with `functools.lru_cache(maxsize=30)`.
    Callers that mutate the result should go through `get_kwargs`, which
    returns a deep copy.

    Returns:
        A mapping from field name to argparse kwargs. Each entry always has
        `default` and `help`, plus `type`/`action`/`choices`/`metavar`/`nargs`
        depending on the field's type hints. `help` is empty unless
        `NEEDS_HELP` is set.

    Raises:
        ValueError: if a field's type hints are not supported.
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}

    kwargs: dict[str, dict[str, Any]] = {}
    for field in fields(cls):
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, CLI values are parsed from JSON with
        # pydantic's TypeAdapter (see `parse_dataclass` below)
        generator = (th for th in type_hints if is_dataclass(th))
        dataclass_cls = next(generator, None)

        # Get the default value of the field
        default = _get_field_default(field)

        # Get the help text for the field
        name = field.name
        help_text = cls_docs.get(name, "").strip()
        # Escape % for argparse
        help_text = help_text.replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        kwargs[name] = {"default": default, "help": help_text}

        # Set other kwargs based on the type hints
        json_tip = (
            "Should either be a valid JSON string or JSON keys passed individually."
        )
        if dataclass_cls is not None:

            # `cls=dataclass_cls` binds the current class at definition time,
            # avoiding late binding of the loop variable.
            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            kwargs[name]["type"] = parse_dataclass
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            kwargs[name]["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            kwargs[name].update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            kwargs[name].update(collection_to_kwargs(type_hints, tuple))
        elif contains_type(type_hints, list):
            kwargs[name].update(collection_to_kwargs(type_hints, list))
        elif contains_type(type_hints, set):
            kwargs[name].update(collection_to_kwargs(type_hints, set))
        elif contains_type(type_hints, int):
            if name == "max_model_len":
                kwargs[name]["type"] = human_readable_int_or_auto
                kwargs[name]["help"] += f"\n\n{human_readable_int_or_auto.__doc__}"
            elif name in ("max_num_batched_tokens", "kv_cache_memory_bytes"):
                kwargs[name]["type"] = human_readable_int
                kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
            else:
                kwargs[name]["type"] = int
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
            kwargs[name]["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            kwargs[name]["type"] = parse_type(json.loads)
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
            kwargs[name]["type"] = str
        else:
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(kwargs[name].get("type")) is Literal:
            kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            kwargs[name]["type"] = optional_type(kwargs[name]["type"])
            if kwargs[name].get("choices"):
                kwargs[name]["choices"].append("None")
    return kwargs
347
348


349
def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Return argparse kwargs for every field of the Config dataclass `cls`.

    Field documentation is only included in the `help` entries when the
    command line requests help (`--help`) or the program is `mkdocs`.

    The kwargs are computed once per class by the `lru_cache`-decorated
    `_compute_kwargs`. This function hands back a deep copy, so callers may
    mutate the result without corrupting the cached entry.
    """
    cached_kwargs = _compute_kwargs(cls)
    return copy.deepcopy(cached_kwargs)


362
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
363
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
364
    """Arguments for vLLM engine."""
365

366
    model: str = ModelConfig.model
367
    enable_return_routed_experts: bool = ModelConfig.enable_return_routed_experts
368
    model_weights: str = ModelConfig.model_weights
369
    served_model_name: str | list[str] | None = ModelConfig.served_model_name
370
    tokenizer: str | None = ModelConfig.tokenizer
371
    hf_config_path: str | None = ModelConfig.hf_config_path
372
373
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
374
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
375
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
376
    tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
377
    trust_remote_code: bool = ModelConfig.trust_remote_code
378
379
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
    allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains
380
    download_dir: str | None = LoadConfig.download_dir
381
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
382
    load_format: str | LoadFormats = LoadConfig.load_format
383
384
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
385
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
386
    seed: int = ModelConfig.seed
387
    max_model_len: int = ModelConfig.max_model_len
388
389
390
391
392
393
    cudagraph_capture_sizes: list[int] | None = (
        CompilationConfig.cudagraph_capture_sizes
    )
    max_cudagraph_capture_size: int | None = get_field(
        CompilationConfig, "max_cudagraph_capture_size"
    )
394
395
396
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
397
    distributed_executor_backend: (
398
        str | DistributedExecutorBackend | type[Executor] | None
399
    ) = ParallelConfig.distributed_executor_backend
400
    # number of P/D disaggregation (or other disaggregation) workers
401
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
402
403
404
405
    master_addr: str = ParallelConfig.master_addr
    master_port: int = ParallelConfig.master_port
    nnodes: int = ParallelConfig.nnodes
    node_rank: int = ParallelConfig.node_rank
406
    distributed_timeout_seconds: int | None = ParallelConfig.distributed_timeout_seconds
407
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
408
    prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
409
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
410
    dcp_comm_backend: DCPCommBackend = ParallelConfig.dcp_comm_backend
411
    dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
412
    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
413
    data_parallel_size: int = ParallelConfig.data_parallel_size
414
415
416
417
418
    data_parallel_rank: int | None = None
    data_parallel_start_rank: int | None = None
    data_parallel_size_local: int | None = None
    data_parallel_address: str | None = None
    data_parallel_rpc_port: int | None = None
419
    data_parallel_hybrid_lb: bool = False
420
    data_parallel_external_lb: bool = False
421
    data_parallel_backend: DataParallelBackend = ParallelConfig.data_parallel_backend
422
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
423
    moe_backend: MoEBackend = KernelConfig.moe_backend
424
    all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
425
    enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep
426
    enable_dbo: bool = ParallelConfig.enable_dbo
427
    ubatch_size: int = ParallelConfig.ubatch_size
428
429
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
430
    disable_nccl_for_dp_synchronization: bool | None = (
431
432
        ParallelConfig.disable_nccl_for_dp_synchronization
    )
433
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
434
    enable_eplb: bool = ParallelConfig.enable_eplb
435
    expert_placement_strategy: ExpertPlacementStrategy = (
436
        ParallelConfig.expert_placement_strategy
437
    )
438
439
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
440
    max_parallel_loading_workers: int | None = (
441
442
        ParallelConfig.max_parallel_loading_workers
    )
443
    block_size: BlockSize = CacheConfig.block_size
444
    enable_prefix_caching: bool | None = None
445
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
446
        CacheConfig.prefix_caching_hash_algo
447
    )
448
449
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
450
    swap_space: float = CacheConfig.swap_space
451
452
453
454
455
456
457
    offload_backend: str = OffloadConfig.offload_backend
    cpu_offload_gb: float = UVAOffloadConfig.cpu_offload_gb
    cpu_offload_params: set[str] = get_field(UVAOffloadConfig, "cpu_offload_params")
    offload_group_size: int = PrefetchOffloadConfig.offload_group_size
    offload_num_in_group: int = PrefetchOffloadConfig.offload_num_in_group
    offload_prefetch_step: int = PrefetchOffloadConfig.offload_prefetch_step
    offload_params: set[str] = get_field(PrefetchOffloadConfig, "offload_params")
458
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
459
    kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
460
    max_num_batched_tokens: int | None = None
461
462
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
463
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
464
    max_num_seqs: int | None = None
465
    max_logprobs: int = ModelConfig.max_logprobs
466
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
467
    disable_log_stats: bool = False
468
    aggregate_engine_logging: bool = False
469
470
471
    revision: str | None = ModelConfig.revision
    code_revision: str | None = ModelConfig.code_revision
    hf_token: bool | str | None = ModelConfig.hf_token
472
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
473
    tokenizer_revision: str | None = ModelConfig.tokenizer_revision
474
    quantization: QuantizationMethods | str | None = ModelConfig.quantization
475
    allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
476
    enforce_eager: bool = ModelConfig.enforce_eager
477
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
478
    language_model_only: bool = MultiModalConfig.language_model_only
479
    limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
480
481
        MultiModalConfig, "limit_per_prompt"
    )
482
    enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
483
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
484
485
486
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
487
    mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
488
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
489
    mm_processor_cache_type: MMCacheType | None = (
490
        MultiModalConfig.mm_processor_cache_type
491
492
    )
    mm_shm_cache_max_object_size_mb: int = (
493
        MultiModalConfig.mm_shm_cache_max_object_size_mb
494
    )
495
    mm_encoder_only: bool = MultiModalConfig.mm_encoder_only
496
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
497
    mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
498
499
        MultiModalConfig.mm_encoder_attn_backend
    )
500
    io_processor_plugin: str | None = None
501
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
502
    video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate
503
    # LoRA fields
504
    enable_lora: bool = False
505
    max_loras: int = LoRAConfig.max_loras
506
    max_lora_rank: MaxLoRARanks = LoRAConfig.max_lora_rank
507
    default_mm_loras: dict[str, str] | None = LoRAConfig.default_mm_loras
508
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
509
510
    max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
    lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
511
    enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
512
    specialize_active_lora: bool = LoRAConfig.specialize_active_lora
513

514
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
515
    num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
516
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
517
    ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
518

519
    enable_chunked_prefill: bool | None = None
520
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
521

522
    disable_hybrid_kv_cache_manager: bool | None = (
523
524
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
525

526
    structured_outputs_config: StructuredOutputsConfig = get_field(
527
528
        VllmConfig, "structured_outputs_config"
    )
529
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
530
    reasoning_parser_plugin: str | None = None
531

532
    speculative_config: dict[str, Any] | None = None
533

534
    show_hidden_metrics_for_version: str | None = (
535
        ObservabilityConfig.show_hidden_metrics_for_version
536
    )
537
538
    otlp_traces_endpoint: str | None = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: list[DetailedTraceModules] | None = (
539
        ObservabilityConfig.collect_detailed_traces
540
    )
541
542
543
544
    kv_cache_metrics: bool = ObservabilityConfig.kv_cache_metrics
    kv_cache_metrics_sample: float = get_field(
        ObservabilityConfig, "kv_cache_metrics_sample"
    )
545
    cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
546
547
548
    enable_layerwise_nvtx_tracing: bool = (
        ObservabilityConfig.enable_layerwise_nvtx_tracing
    )
549
    enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
550
551
552
    enable_logging_iteration_details: bool = (
        ObservabilityConfig.enable_logging_iteration_details
    )
553
    enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
554
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
555
    scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
556

557
    pooler_config: PoolerConfig | None = ModelConfig.pooler_config
558
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
559
    attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
560
561
562
563
    kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config")
    enable_flashinfer_autotune: bool = get_field(
        KernelConfig, "enable_flashinfer_autotune"
    )
564
565
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
566

567
568
    profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")

569
570
    kv_transfer_config: KVTransferConfig | None = None
    kv_events_config: KVEventsConfig | None = None
571

572
573
    ec_transfer_config: ECTransferConfig | None = None

574
575
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
576
577
578
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
579
    model_impl: str = ModelConfig.model_impl
580
    override_attention_dtype: str | None = ModelConfig.override_attention_dtype
581
    attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
582

583
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
584
585
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
586
    mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
587
    mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
588

589
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
590

591
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
592
    pt_load_map_location: str | dict[str, str] = LoadConfig.pt_load_map_location
593

594
    logits_processors: list[str | type[LogitsProcessor]] | None = (
595
596
        ModelConfig.logits_processors
    )
597
598
    """Custom logitproc types"""

599
    async_scheduling: bool | None = SchedulerConfig.async_scheduling
600

601
602
    stream_interval: int = SchedulerConfig.stream_interval

603
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
604
    optimization_level: OptimizationLevel = VllmConfig.optimization_level
605
    performance_mode: PerformanceMode = VllmConfig.performance_mode
606

607
    kv_offloading_size: float | None = CacheConfig.kv_offloading_size
608
    kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
609
    tokens_only: bool = False
610

611
612
    shutdown_timeout: int = 0

613
614
615
616
    weight_transfer_config: WeightTransferConfig | None = get_field(
        VllmConfig,
        "weight_transfer_config",
    )
617

618
619
    fail_on_environ_validation: bool = False

620
    def __post_init__(self):
        """Normalise dict-valued config fields, load plugins, resolve paths.

        - Dict values given for `compilation_config`, `attention_config`,
          `kernel_config`, `eplb_config` and `weight_transfer_config` are
          converted into `CompilationConfig`, `AttentionConfig`,
          `KernelConfig`, `EPLBConfig` and `WeightTransferConfig`. This allows
          e.g. `EngineArgs(compilation_config={...})`.
        - General plugins are loaded via `load_general_plugins()`.
        - When `huggingface_hub.constants.HF_HUB_OFFLINE` is set, `model` and
          (if given) `tokenizer` are replaced with the result of
          `get_model_path`, and each change is logged.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(**self.compilation_config)
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)
        if isinstance(self.kernel_config, dict):
            self.kernel_config = KernelConfig(**self.kernel_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)
        if isinstance(self.weight_transfer_config, dict):
            self.weight_transfer_config = WeightTransferConfig(
                **self.weight_transfer_config
            )
        # Setup plugins
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        # In HF offline mode, replace the model and tokenizer ids with the
        # local model paths returned by `get_model_path`.
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
            # Compare by value: an identity check (`is not`) would report a
            # replacement for an equal string that is a different object.
            if model_id != self.model:
                logger.info(
                    "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
                    model_id,
                    self.model,
                )
            if self.tokenizer is not None:
                tokenizer_id = self.tokenizer
                self.tokenizer = get_model_path(self.tokenizer, self.tokenizer_revision)
                if tokenizer_id != self.tokenizer:
                    logger.info(
                        "HF_HUB_OFFLINE is True, replace tokenizer_id [%s] "
                        "to tokenizer_path [%s]",
                        tokenizer_id,
                        self.tokenizer,
                    )
660
661

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Register the shared vLLM engine CLI arguments on ``parser``.

        Arguments are organised into one argparse group per config class
        (``ModelConfig``, ``LoadConfig``, ``ParallelConfig``, ...). For most
        options the ``add_argument`` keyword arguments (type, default, help,
        choices, ...) come from ``get_kwargs(<ConfigClass>)``; a handful of
        options are declared by hand below.

        Args:
            parser: The parser to extend. It is mutated in place.

        Returns:
            The same ``parser`` instance.
        """

        # Model arguments
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
        # `--model` is not registered when both "serve" and "--help" appear on
        # the command line (presumably because `serve` documents the model
        # positionally — TODO confirm).
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
            model_group.add_argument("--model", **model_kwargs["model"])
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
        model_group.add_argument("--revision", **model_kwargs["revision"])
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
        model_group.add_argument(
            "--allow-deprecated-quantization",
            **model_kwargs["allow_deprecated_quantization"],
        )
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
        model_group.add_argument(
            "--enable-return-routed-experts",
            **model_kwargs["enable_return_routed_experts"],
        )
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
        # This one is a special case because it can be bool or str: a bare
        # `--hf-token` stores `const` (True), while `--hf-token <value>` stores
        # the string. TODO: Handle this in get_kwargs
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
        )

        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )

        # Attention arguments
        attention_kwargs = get_kwargs(AttentionConfig)
        attention_group = parser.add_argument_group(
            title="AttentionConfig",
            description=AttentionConfig.__doc__,
        )
        attention_group.add_argument(
            "--attention-backend", **attention_kwargs["backend"]
        )

        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
        )
        structured_outputs_group.add_argument(
            "--reasoning-parser",
            # Choices need to be validated after parsing to include plugins
            **structured_outputs_kwargs["reasoning_parser"],
        )
        structured_outputs_group.add_argument(
            "--reasoning-parser-plugin",
            **structured_outputs_kwargs["reasoning_parser_plugin"],
        )

        # Parallel arguments
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
            "--distributed-executor-backend",
            **parallel_kwargs["distributed_executor_backend"],
        )
        parallel_group.add_argument(
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
        parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"])
        parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
        parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
        parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
        parallel_group.add_argument(
            "--distributed-timeout-seconds",
            **parallel_kwargs["distributed_timeout_seconds"],
        )
        parallel_group.add_argument(
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
        parallel_group.add_argument(
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
        parallel_group.add_argument(
            "--dcp-comm-backend",
            **parallel_kwargs["dcp_comm_backend"],
        )
        parallel_group.add_argument(
            "--dcp-kv-cache-interleave-size",
            **parallel_kwargs["dcp_kv_cache_interleave_size"],
        )
        parallel_group.add_argument(
            "--cp-kv-cache-interleave-size",
            **parallel_kwargs["cp_kv_cache_interleave_size"],
        )
        parallel_group.add_argument(
            "--prefill-context-parallel-size",
            "-pcp",
            **parallel_kwargs["prefill_context_parallel_size"],
        )
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
        # The data-parallel placement options below are declared by hand
        # rather than derived from ParallelConfig via get_kwargs.
        parallel_group.add_argument(
            "--data-parallel-rank",
            "-dpn",
            type=int,
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
        parallel_group.add_argument(
            "--data-parallel-start-rank",
            "-dpr",
            type=int,
            help="Starting data parallel rank for secondary nodes.",
        )
        parallel_group.add_argument(
            "--data-parallel-size-local",
            "-dpl",
            type=int,
            help="Number of data parallel replicas to run on this node.",
        )
        parallel_group.add_argument(
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
            type=int,
            help="Port for data parallel RPC communication.",
        )
        parallel_group.add_argument(
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
        parallel_group.add_argument(
            "--data-parallel-hybrid-lb",
            "-dph",
            **parallel_kwargs["data_parallel_hybrid_lb"],
        )
        parallel_group.add_argument(
            "--data-parallel-external-lb",
            "-dpe",
            **parallel_kwargs["data_parallel_external_lb"],
        )
        parallel_group.add_argument(
            "--enable-expert-parallel",
            "-ep",
            **parallel_kwargs["enable_expert_parallel"],
        )
        parallel_group.add_argument(
            "--all2all-backend", **parallel_kwargs["all2all_backend"]
        )
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
        parallel_group.add_argument(
            "--ubatch-size",
            **parallel_kwargs["ubatch_size"],
        )
        parallel_group.add_argument(
            "--enable-elastic-ep", **parallel_kwargs["enable_elastic_ep"]
        )
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
        parallel_group.add_argument(
            "--disable-nccl-for-dp-synchronization",
            **parallel_kwargs["disable_nccl_for_dp_synchronization"],
        )
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
        parallel_group.add_argument(
            "--expert-placement-strategy",
            **parallel_kwargs["expert_placement_strategy"],
        )
        parallel_group.add_argument(
            "--max-parallel-loading-workers",
            **parallel_kwargs["max_parallel_loading_workers"],
        )
        parallel_group.add_argument(
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
        parallel_group.add_argument(
            "--disable-custom-all-reduce",
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )

        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
        )
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        cache_group.add_argument(
            "--enable-prefix-caching",
            # NOTE(review): the default is forced to None instead of the
            # config default, presumably so "flag not passed" can be told
            # apart from an explicit value later on — confirm.
            **{
                **cache_kwargs["enable_prefix_caching"],
                "default": None,
            },
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-block-size", **cache_kwargs["mamba_block_size"]
        )
        cache_group.add_argument(
            "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
        )
        cache_group.add_argument(
            "--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
        )
        cache_group.add_argument(
            "--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
        )

        # Model weight offload related configs. One argparse group collects
        # options from three config classes.
        offload_kwargs = get_kwargs(OffloadConfig)
        uva_kwargs = get_kwargs(UVAOffloadConfig)
        prefetch_kwargs = get_kwargs(PrefetchOffloadConfig)
        offload_group = parser.add_argument_group(
            title="OffloadConfig",
            description=OffloadConfig.__doc__,
        )
        offload_group.add_argument(
            "--offload-backend", **offload_kwargs["offload_backend"]
        )
        offload_group.add_argument("--cpu-offload-gb", **uva_kwargs["cpu_offload_gb"])
        offload_group.add_argument(
            "--cpu-offload-params", **uva_kwargs["cpu_offload_params"]
        )
        offload_group.add_argument(
            "--offload-group-size",
            **prefetch_kwargs["offload_group_size"],
        )
        offload_group.add_argument(
            "--offload-num-in-group",
            **prefetch_kwargs["offload_num_in_group"],
        )
        offload_group.add_argument(
            "--offload-prefetch-step",
            **prefetch_kwargs["offload_prefetch_step"],
        )
        offload_group.add_argument(
            "--offload-params", **prefetch_kwargs["offload_params"]
        )

        # Multimodal related configs
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
        multimodal_group.add_argument(
            "--language-model-only", **multimodal_kwargs["language_model_only"]
        )
        multimodal_group.add_argument(
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
        multimodal_group.add_argument(
            "--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
        )
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
        multimodal_group.add_argument(
            "--mm-encoder-only", **multimodal_kwargs["mm_encoder_only"]
        )
        multimodal_group.add_argument(
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
        multimodal_group.add_argument(
            "--mm-encoder-attn-backend",
            **multimodal_kwargs["mm_encoder_attn_backend"],
        )
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
        multimodal_group.add_argument(
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
        multimodal_group.add_argument(
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )

        # LoRA related configs
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
            "--enable-lora",
            action=argparse.BooleanOptionalAction,
            help="If True, enable handling of LoRA adapters.",
        )
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
        lora_group.add_argument(
            "--lora-dtype",
            **lora_kwargs["lora_dtype"],
        )
        lora_group.add_argument(
            "--enable-tower-connector-lora",
            **lora_kwargs["enable_tower_connector_lora"],
        )
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
        lora_group.add_argument(
            "--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
        )

        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
        observability_group.add_argument(
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
        # TODO: generalise this special case
        # The metavar lists only the base module names, while the accepted
        # choices are extended with every ordered pair of distinct modules
        # joined by a comma (e.g. "a,b" and "b,a").
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
            **observability_kwargs["collect_detailed_traces"],
        )
        observability_group.add_argument(
            "--kv-cache-metrics", **observability_kwargs["kv_cache_metrics"]
        )
        observability_group.add_argument(
            "--kv-cache-metrics-sample",
            **observability_kwargs["kv_cache_metrics_sample"],
        )
        observability_group.add_argument(
            "--cudagraph-metrics",
            **observability_kwargs["cudagraph_metrics"],
        )
        observability_group.add_argument(
            "--enable-layerwise-nvtx-tracing",
            **observability_kwargs["enable_layerwise_nvtx_tracing"],
        )
        observability_group.add_argument(
            "--enable-mfu-metrics",
            **observability_kwargs["enable_mfu_metrics"],
        )
        observability_group.add_argument(
            "--enable-logging-iteration-details",
            **observability_kwargs["enable_logging_iteration_details"],
        )

        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        # NOTE(review): the next two options and --enable-chunked-prefill force
        # default=None, presumably so an unset flag can be resolved later
        # rather than taking the config default here — confirm.
        scheduler_group.add_argument(
            "--max-num-batched-tokens",
            **{
                **scheduler_kwargs["max_num_batched_tokens"],
                "default": None,
            },
        )
        scheduler_group.add_argument(
            "--max-num-seqs",
            **{
                **scheduler_kwargs["max_num_seqs"],
                "default": None,
            },
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
            **scheduler_kwargs["max_long_partial_prefills"],
        )
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
        scheduler_group.add_argument(
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
        scheduler_group.add_argument(
            "--enable-chunked-prefill",
            **{
                **scheduler_kwargs["enable_chunked_prefill"],
                "default": None,
            },
        )
        scheduler_group.add_argument(
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
        scheduler_group.add_argument(
            "--stream-interval", **scheduler_kwargs["stream_interval"]
        )

        # Compilation arguments
        compilation_kwargs = get_kwargs(CompilationConfig)
        compilation_group = parser.add_argument_group(
            title="CompilationConfig",
            description=CompilationConfig.__doc__,
        )
        compilation_group.add_argument(
            "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
        )
        compilation_group.add_argument(
            "--max-cudagraph-capture-size",
            **compilation_kwargs["max_cudagraph_capture_size"],
        )

        # Kernel arguments
        kernel_kwargs = get_kwargs(KernelConfig)
        kernel_group = parser.add_argument_group(
            title="KernelConfig",
            description=KernelConfig.__doc__,
        )
        kernel_group.add_argument(
            "--enable-flashinfer-autotune",
            **kernel_kwargs["enable_flashinfer_autotune"],
        )
        # Normalise user input (e.g. "Foo-Bar" -> "foo_bar"); argparse applies
        # `type` before it checks any `choices`.
        moe_backend_kwargs = kernel_kwargs["moe_backend"]
        moe_backend_kwargs["type"] = lambda s: s.lower().replace("-", "_")
        kernel_group.add_argument("--moe-backend", **moe_backend_kwargs)

        # vLLM arguments
        vllm_kwargs = get_kwargs(VllmConfig)
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
        vllm_group.add_argument(
            "--speculative-config", **vllm_kwargs["speculative_config"]
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument(
            "--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
        )
        vllm_group.add_argument(
            "--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
        )
        vllm_group.add_argument(
            "--attention-config", "-ac", **vllm_kwargs["attention_config"]
        )
        vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"])
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
        vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
        vllm_group.add_argument(
            "--optimization-level", **vllm_kwargs["optimization_level"]
        )
        vllm_group.add_argument("--performance-mode", **vllm_kwargs["performance_mode"])
        vllm_group.add_argument(
            "--weight-transfer-config", **vllm_kwargs["weight_transfer_config"]
        )

        # Other arguments (registered on the parser itself, outside any group)
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
        parser.add_argument(
            "--aggregate-engine-logging",
            action="store_true",
            help="Log aggregate rather than per-engine statistics "
            "when using data parallelism.",
        )
        parser.add_argument(
            "--fail-on-environ-validation",
            help="If set, the engine will raise an error if "
            "environment validation fails.",
            default=False,
            action=argparse.BooleanOptionalAction,
        )
        parser.add_argument(
            "--shutdown-timeout",
            type=int,
            default=0,
            help="Shutdown timeout in seconds. 0 = abort, >0 = wait.",
        )

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Create an instance from a parsed ``argparse.Namespace``.

        Each dataclass field of ``cls`` that also exists as an attribute on
        ``args`` is passed to the constructor under the same name. Fields
        missing from ``args`` are not passed, so the constructor's own
        defaults apply to them.

        Args:
            args: The namespace produced by parsing the CLI.

        Returns:
            A new ``cls`` instance.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(
            **{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)}
        )
        return engine_args

    def create_model_config(self) -> ModelConfig:
        """Build a ``ModelConfig`` from the model-related engine arguments.

        Side effects:
            When ``is_gguf(self.model)`` is true, both ``self.quantization``
            and ``self.load_format`` are overwritten with ``"gguf"`` before
            the config is built.

        Returns:
            A ``ModelConfig`` whose keyword arguments are taken one-to-one
            from the same-named attributes of ``self``.
        """
        # gguf file needs a specific model loader
        if is_gguf(self.model):
            self.quantization = self.load_format = "gguf"

        # With V1 multiprocessing disabled, warn that setting the seed may
        # affect the random state of the process that launched vLLM.
        if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.warning(
                "The global random seed is set to %d. Since "
                "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
                "affect the random state of the Python process that "
                "launched vLLM.",
                self.seed,
            )

        return ModelConfig(
            model=self.model,
            model_weights=self.model_weights,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            tokenizer=self.tokenizer,  # type: ignore[arg-type]
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            allow_deprecated_quantization=self.allow_deprecated_quantization,
            enforce_eager=self.enforce_eager,
            enable_return_routed_experts=self.enable_return_routed_experts,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            language_model_only=self.language_model_only,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            enable_mm_embeds=self.enable_mm_embeds,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_only=self.mm_encoder_only,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
            pooler_config=self.pooler_config,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            video_pruning_rate=self.video_pruning_rate,
            io_processor_plugin=self.io_processor_plugin,
        )

    def validate_tensorizer_args(self):
        """Copy TensorizerConfig fields into the nested ``tensorizer_config``.

        Every top-level key of ``model_loader_extra_config`` that is also a
        ``TensorizerConfig`` field is mirrored into
        ``model_loader_extra_config["tensorizer_config"]``. The nested dict is
        only looked up when such a key is present; ``create_load_config``
        creates it just before calling this method.
        """
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

        extra_config = self.model_loader_extra_config
        tensorizer_fields = TensorizerConfig._fields
        for option_name in extra_config:
            if option_name not in tensorizer_fields:
                continue
            extra_config["tensorizer_config"][option_name] = extra_config[
                option_name
            ]
1413

1414
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` describing how model weights are loaded.

        Side effects on ``self``:
          * ``load_format`` is forced to ``"bitsandbytes"`` when
            ``quantization`` is ``"bitsandbytes"``.
          * For the ``"tensorizer"`` load format, ``model_loader_extra_config``
            is converted with ``to_serializable()`` when it offers that method.
            It then gets a fresh ``tensorizer_config`` dict whose
            ``tensorizer_dir`` is ``self.model``, and finally
            ``validate_tensorizer_args`` is run.

        Returns:
            A ``LoadConfig`` built from the current engine arguments.
        """
        if self.quantization == "bitsandbytes":
            # bitsandbytes quantization requires the matching loader.
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            extra = self.model_loader_extra_config
            if hasattr(extra, "to_serializable"):
                # Turn a config object into a plain mapping we can index.
                extra = extra.to_serializable()
                self.model_loader_extra_config = extra
            extra["tensorizer_config"] = {"tensorizer_dir": self.model}
            self.validate_tensorizer_args()

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )
1438

1439
1440
1441
1442
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
    ) -> SpeculativeConfig | None:
        """Build a ``SpeculativeConfig`` from ``self.speculative_config``.

        ``self.speculative_config`` is either parsed from the
        ``--speculative-config`` JSON CLI argument or handed over directly as a
        dictionary by the engine.

        Args:
            target_model_config: Config of the model being sped up.
            target_parallel_config: Parallel config of the target model.

        Returns:
            ``None`` when no speculative config was supplied, otherwise a new
            ``SpeculativeConfig``. Note that the target configs are written
            into ``self.speculative_config`` in place before construction.
        """
        spec_args = self.speculative_config
        if spec_args is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        spec_args.update(
            target_model_config=target_model_config,
            target_parallel_config=target_parallel_config,
        )
        return SpeculativeConfig(**spec_args)
1465

1466
1467
    def create_engine_config(
        self,
        usage_context: UsageContext | None = None,
        headless: bool = False,
    ) -> VllmConfig:
        """Create the VllmConfig from these engine arguments.

        The model config is built first, optionally after swapping in a
        speculators target model. Model-dependent defaults (chunked prefill,
        prefix caching, batching limits) are then resolved. Next the
        data-parallel topology is derived. Finally every sub-config is
        assembled into a single ``VllmConfig``.

        Args:
            usage_context: Forwarded to
                ``_set_default_max_num_seqs_and_batched_tokens_args`` to pick
                context-dependent batching defaults.
            headless: Headless mode flag. Hybrid data-parallel load balancing
                is rejected, and is never inferred, when this is set.

        Note:
            Several fields of ``self`` are updated while resolving (e.g.
            ``model``, ``tokenizer``, ``model_weights``, ``data_parallel_rank``,
            ``data_parallel_size_local``, ``data_parallel_hybrid_lb``,
            ``quantization``, ``load_format``).

        Raises:
            ValueError: If mutually exclusive options are combined (e.g.
                ``attention_backend`` together with
                ``attention_config.backend``), if default multimodal LoRAs are
                given for a non-multimodal model, or if
                ``max_num_batched_tokens`` is too small for LoRA combined with
                speculative decoding.
            AssertionError: If the data-parallel options are inconsistent.
        """
        current_platform.pre_register_and_update()

        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))

        envs.validate_environ(self.fail_on_environ_validation)

        # Check if the model is a speculator and override model/tokenizer/config
        # BEFORE creating ModelConfig, so the config is created with the target model
        # Skip speculator detection for cloud storage models (eg: S3, GCS) since
        # HuggingFace cannot load configs directly from S3 URLs. S3 models can still
        # use speculators with explicit --speculative-config.
        if not is_cloud_storage(self.model):
            (self.model, self.tokenizer, self.speculative_config) = (
                maybe_override_with_speculators(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    revision=self.revision,
                    trust_remote_code=self.trust_remote_code,
                    vllm_speculative_config=self.speculative_config,
                )
            )

        model_config = self.create_model_config()
        self.model = model_config.model
        self.model_weights = model_config.model_weights
        self.tokenizer = model_config.tokenizer

        self._check_feature_supported()
        self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
        self._set_default_max_num_seqs_and_batched_tokens_args(
            usage_context, model_config
        )

        sliding_window: int | None = None
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

        # Resolve "auto" kv_cache_dtype to actual value from model config
        resolved_cache_dtype = resolve_kv_cache_dtype_string(
            self.kv_cache_dtype, model_config
        )

        assert self.enable_prefix_caching is not None, (
            "enable_prefix_caching must be set by this point"
        )

        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
            swap_space=self.swap_space,
            cache_dtype=resolved_cache_dtype,  # type: ignore[arg-type]
            is_attention_free=model_config.is_attention_free,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=sliding_window,
            enable_prefix_caching=self.enable_prefix_caching,
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
            calculate_kv_scales=self.calculate_kv_scales,
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
            mamba_block_size=self.mamba_block_size,
            mamba_cache_mode=self.mamba_cache_mode,
            kv_offloading_size=self.kv_offloading_size,
            kv_offloading_backend=self.kv_offloading_backend,
        )

        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray

            ray_runtime_env = ray.get_runtime_context().runtime_env
            # Avoid logging sensitive environment variables
            sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {}
            if "env_vars" in sanitized_env:
                sanitized_env["env_vars"] = {
                    k: "***" for k in sanitized_env["env_vars"]
                }
            logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env)

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

        assert not headless or not self.data_parallel_hybrid_lb, (
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
        assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
            "data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
        )
        assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
            "nnodes > 1 is only supported with data_parallel_backend=mp"
        )
        inferred_data_parallel_rank = 0
        if self.nnodes > 1:
            world_size = (
                self.data_parallel_size
                * self.pipeline_parallel_size
                * self.tensor_parallel_size
            )
            world_size_within_dp = (
                self.pipeline_parallel_size * self.tensor_parallel_size
            )
            local_world_size = world_size // self.nnodes
            assert world_size % self.nnodes == 0, (
                f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
            )
            assert self.node_rank < self.nnodes, (
                f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
            )
            # First DP rank hosted on this node, assuming ranks are laid out
            # contiguously across nodes.
            inferred_data_parallel_rank = (
                self.node_rank * local_world_size
            ) // world_size_within_dp
            if self.data_parallel_size > 1 and self.data_parallel_external_lb:
                self.data_parallel_rank = inferred_data_parallel_rank
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d for external lb",
                    self.data_parallel_rank,
                    self.node_rank,
                )
            elif self.data_parallel_size_local is None:
                # Infer data parallel size local for internal dplb:
                self.data_parallel_size_local = max(
                    local_world_size // world_size_within_dp, 1
                )
        data_parallel_external_lb = (
            self.data_parallel_external_lb or self.data_parallel_rank is not None
        )
        # Local DP size = 1, use pure-external LB.
        if data_parallel_external_lb:
            assert self.data_parallel_rank is not None, (
                "data_parallel_rank or node_rank must be specified if "
                "data_parallel_external_lb is enable."
            )
            assert self.data_parallel_size_local in (1, None), (
                "data_parallel_size_local must be 1 or None when data_parallel_rank "
                "is set"
            )
            data_parallel_size_local = 1
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
                logger.warning(
                    "data_parallel_hybrid_lb is not eligible when "
                    "data_parallel_size_local = 1, autoswitch to "
                    "data_parallel_external_lb."
                )
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

            self.data_parallel_rank = (
                self.data_parallel_start_rank or inferred_data_parallel_rank
            )
            if self.nnodes > 1:
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d",
                    self.data_parallel_rank,
                    self.node_rank,
                )
        else:
            assert not self.data_parallel_hybrid_lb, (
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )

            if self.data_parallel_backend == "ray" and (
                envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
            ):
                # Local data parallel size defaults to 1 if DP ranks are
                # spanning multiple nodes
                data_parallel_size_local = 1
            else:
                # Otherwise local DP size defaults to global DP size if not set
                data_parallel_size_local = self.data_parallel_size

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
                data_parallel_address = host_ip
            else:
                # The message used to be a (format, arg) tuple, which
                # assert prints verbatim instead of formatting.
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got "
                    f"{self.data_parallel_backend}"
                )
                data_parallel_address = (
                    self.master_addr or ParallelConfig.data_parallel_master_ip
                )
        else:
            data_parallel_address = self.data_parallel_address

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
        data_parallel_rpc_port = (
            self.data_parallel_rpc_port
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )

        if self.tokens_only and not model_config.skip_tokenizer_init:
            model_config.skip_tokenizer_init = True
            logger.info("Skipping tokenizer initialization for tokens-only mode.")

        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            prefill_context_parallel_size=self.prefill_context_parallel_size,
            data_parallel_size=self.data_parallel_size,
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
            data_parallel_size_local=data_parallel_size_local,
            master_addr=self.master_addr,
            master_port=self.master_port,
            nnodes=self.nnodes,
            node_rank=self.node_rank,
            distributed_timeout_seconds=self.distributed_timeout_seconds,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
            data_parallel_backend=self.data_parallel_backend,
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
            is_moe_model=model_config.is_moe,
            enable_expert_parallel=self.enable_expert_parallel,
            all2all_backend=self.all2all_backend,
            enable_elastic_ep=self.enable_elastic_ep,
            enable_dbo=self.enable_dbo,
            ubatch_size=self.ubatch_size,
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
            enable_eplb=self.enable_eplb,
            eplb_config=self.eplb_config,
            expert_placement_strategy=self.expert_placement_strategy,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            ray_runtime_env=ray_runtime_env,
            placement_group=placement_group,
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
            worker_extension_cls=self.worker_extension_cls,
            decode_context_parallel_size=self.decode_context_parallel_size,
            dcp_comm_backend=self.dcp_comm_backend,
            dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
        )

        speculative_config = self.create_speculative_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
        )

        assert self.max_num_batched_tokens is not None, (
            "max_num_batched_tokens must be set by this point"
        )
        assert self.max_num_seqs is not None, "max_num_seqs must be set by this point"
        assert self.enable_chunked_prefill is not None, (
            "enable_chunked_prefill must be set by this point"
        )
        assert model_config.max_model_len is not None, (
            "max_model_len must be set by this point"
        )
        scheduler_config = SchedulerConfig(
            runner_type=model_config.runner_type,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_chunked_mm_input=self.disable_chunked_mm_input,
            is_multimodal_model=model_config.is_multimodal_model,
            is_encoder_decoder=model_config.is_encoder_decoder,
            policy=self.scheduling_policy,
            scheduler_cls=self.scheduler_cls,
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
            async_scheduling=self.async_scheduling,
            stream_interval=self.stream_interval,
        )

        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_dtype=self.lora_dtype,
                enable_tower_connector_lora=self.enable_tower_connector_lora,
                specialize_active_lora=self.specialize_active_lora,
                # Non-positive values mean "no CPU LoRA cache limit override".
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )

        # With LoRA + speculative decoding, each sequence may need
        # (num_speculative_tokens + 1) token slots per step.
        if (
            lora_config is not None
            and speculative_config is not None
            and scheduler_config.max_num_batched_tokens
            < (
                scheduler_config.max_num_seqs
                * (speculative_config.num_speculative_tokens + 1)
            )
        ):
            raise ValueError(
                "Consider increasing max_num_batched_tokens or "
                "decreasing num_speculative_tokens"
            )

        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

        # Attention config overrides
        attention_config = copy.deepcopy(self.attention_config)
        if self.attention_backend is not None:
            if attention_config.backend is not None:
                raise ValueError(
                    "attention_backend and attention_config.backend "
                    "are mutually exclusive"
                )
            # Reuse the validator to handle "auto" and string-to-enum conversion
            attention_config.backend = AttentionConfig.validate_backend_before(
                self.attention_backend
            )

        # Kernel config overrides
        kernel_config = copy.deepcopy(self.kernel_config)
        if self.enable_flashinfer_autotune is not None:
            if kernel_config.enable_flashinfer_autotune is not None:
                raise ValueError(
                    "enable_flashinfer_autotune and "
                    "kernel_config.enable_flashinfer_autotune "
                    "are mutually exclusive"
                )
            kernel_config.enable_flashinfer_autotune = self.enable_flashinfer_autotune
        if self.moe_backend != "auto":
            kernel_config.moe_backend = self.moe_backend

        load_config = self.create_load_config()

        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser

        if self.reasoning_parser_plugin:
            self.structured_outputs_config.reasoning_parser_plugin = (
                self.reasoning_parser_plugin
            )

        observability_config = ObservabilityConfig(
            show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_detailed_traces=self.collect_detailed_traces,
            kv_cache_metrics=self.kv_cache_metrics,
            kv_cache_metrics_sample=self.kv_cache_metrics_sample,
            cudagraph_metrics=self.cudagraph_metrics,
            enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
            enable_mfu_metrics=self.enable_mfu_metrics,
            enable_mm_processor_stats=self.enable_mm_processor_stats,
            enable_logging_iteration_details=self.enable_logging_iteration_details,
        )

        # Compilation config overrides
        compilation_config = copy.deepcopy(self.compilation_config)
        if self.cudagraph_capture_sizes is not None:
            if compilation_config.cudagraph_capture_sizes is not None:
                raise ValueError(
                    "cudagraph_capture_sizes and compilation_config."
                    "cudagraph_capture_sizes are mutually exclusive"
                )
            compilation_config.cudagraph_capture_sizes = self.cudagraph_capture_sizes
        if self.max_cudagraph_capture_size is not None:
            if compilation_config.max_cudagraph_capture_size is not None:
                raise ValueError(
                    "max_cudagraph_capture_size and compilation_config."
                    "max_cudagraph_capture_size are mutually exclusive"
                )
            compilation_config.max_cudagraph_capture_size = (
                self.max_cudagraph_capture_size
            )

        offload_config = OffloadConfig(
            offload_backend=self.offload_backend,
            uva=UVAOffloadConfig(
                cpu_offload_gb=self.cpu_offload_gb,
                cpu_offload_params=self.cpu_offload_params,
            ),
            prefetch=PrefetchOffloadConfig(
                offload_group_size=self.offload_group_size,
                offload_num_in_group=self.offload_num_in_group,
                offload_prefetch_step=self.offload_prefetch_step,
                offload_params=self.offload_params,
            ),
        )

        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            load_config=load_config,
            offload_config=offload_config,
            attention_config=attention_config,
            kernel_config=kernel_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            structured_outputs_config=self.structured_outputs_config,
            observability_config=observability_config,
            compilation_config=compilation_config,
            kv_transfer_config=self.kv_transfer_config,
            kv_events_config=self.kv_events_config,
            ec_transfer_config=self.ec_transfer_config,
            profiler_config=self.profiler_config,
            additional_config=self.additional_config,
            optimization_level=self.optimization_level,
            performance_mode=self.performance_mode,
            weight_transfer_config=self.weight_transfer_config,
            shutdown_timeout=self.shutdown_timeout,
        )

        return config

1938
    def _check_feature_supported(self):
        """Reject argument combinations that vLLM cannot run yet.

        Calls ``_raise_unsupported_error`` when either partial-prefill limit
        differs from the ``SchedulerConfig`` class default. It is also called
        when pipeline parallelism is requested with an executor backend that
        neither declares ``supports_pp`` nor is one of the known PP-capable
        backends.
        """
        # Concurrent partial prefills are not supported so far.
        partial_prefills_requested = (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        )
        if partial_prefills_requested:
            _raise_unsupported_error(feature_name="Concurrent Partial Prefill")

        if self.pipeline_parallel_size <= 1:
            return

        backend = self.distributed_executor_backend
        pp_capable_backends = (
            ParallelConfig.distributed_executor_backend,
            "ray",
            "mp",
            "external_launcher",
        )
        if getattr(backend, "supports_pp", False) or backend in pp_capable_backends:
            return

        _raise_unsupported_error(
            feature_name=(
                "Pipeline Parallelism without Ray distributed "
                "executor or multiprocessing executor or external "
                "launcher"
            )
        )
1964

1965
1966
1967
1968
1969
1970
    @classmethod
    def get_batch_defaults(
        cls,
        world_size: int,
    ) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]:
        """Return platform-dependent default batching limits.

        Args:
            world_size: Scales the CPU-platform defaults linearly.

        Returns:
            A ``(max_num_batched_tokens, max_num_seqs)`` pair of dicts, each
            keyed by ``UsageContext.LLM_CLASS`` and
            ``UsageContext.OPENAI_API_SERVER``. The values are chosen from
            device memory and name, and are overridden for TPU chips and for
            CPU platforms.
        """
        from vllm.usage.usage_lib import UsageContext

        llm_ctx = UsageContext.LLM_CLASS
        api_ctx = UsageContext.OPENAI_API_SERVER

        # Probe the local device. This can fail when the importing platform
        # differs from the one vLLM runs on (e.g. scaling vLLM with Ray from a
        # GPU-less driver); fall back to the non-H100/H200 GPU defaults then.
        try:
            total_memory = current_platform.get_device_total_memory()
            lowered_name = current_platform.get_device_name().lower()
        except Exception:
            # Only feeds the large-device heuristic below.
            total_memory = 0
            lowered_name = ""

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        is_large_device = (
            total_memory >= 70 * GiB_bytes and "a100" not in lowered_name
        )

        batched_tokens: dict[UsageContext | None, int]
        num_seqs: dict[UsageContext | None, int]
        if is_large_device:
            # For GPUs like H100 and MI300x, use larger default values.
            batched_tokens = {llm_ctx: 16384, api_ctx: 8192}
            num_seqs = {llm_ctx: 1024, api_ctx: 1024}
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            batched_tokens = {llm_ctx: 8192, api_ctx: 2048}
            num_seqs = {llm_ctx: 256, api_ctx: 256}

        if current_platform.is_tpu():
            # Known TPU generations get their own token budgets; max_num_seqs
            # keeps the value chosen above.
            chip_name = current_platform.get_device_name()
            tpu_token_budgets = (
                ("V6E", 2048, 1024),
                ("V5E", 1024, 512),
                ("V5P", 512, 256),
            )
            for tpu_chip, llm_budget, api_budget in tpu_token_budgets:
                if chip_name == tpu_chip:
                    batched_tokens = {llm_ctx: llm_budget, api_ctx: api_budget}
                    break

        if current_platform.is_cpu():
            # CPU defaults scale with the number of workers.
            batched_tokens = {
                llm_ctx: 4096 * world_size,
                api_ctx: 2048 * world_size,
            }
            num_seqs = {
                llm_ctx: 256 * world_size,
                api_ctx: 128 * world_size,
            }

        return batched_tokens, num_seqs

2049
2050
    def _set_default_chunked_prefill_and_prefix_caching_args(
        self, model_config: ModelConfig
    ) -> None:
        """Resolve ``enable_chunked_prefill`` and ``enable_prefix_caching``.

        A flag that is ``None`` takes the model's default
        (``model_config.is_chunked_prefill_supported`` /
        ``model_config.is_prefix_caching_supported``). An explicit user value
        is kept, but a warning is logged when it contradicts the model's
        default:

        - generate runner: chunked prefill disabled although supported;
        - pooling runner: chunked prefill or prefix caching enabled although
          not supported.

        Finally, on RISC-V CPUs both features are forced off.
        """
        default_chunked_prefill = model_config.is_chunked_prefill_supported
        default_prefix_caching = model_config.is_prefix_caching_supported

        if self.enable_chunked_prefill is None:
            self.enable_chunked_prefill = default_chunked_prefill

            logger.debug(
                "%s chunked prefill by default",
                "Enabling" if default_chunked_prefill else "Disabling",
            )
        elif (
            model_config.runner_type == "generate"
            and not self.enable_chunked_prefill
            and default_chunked_prefill
        ):
            # User explicitly turned off a feature the model expects.
            logger.warning_once(
                "This model does not officially support disabling chunked prefill. "
                "Disabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )
        elif (
            model_config.runner_type == "pooling"
            and self.enable_chunked_prefill
            and not default_chunked_prefill
        ):
            # User explicitly turned on a feature the model does not support.
            logger.warning_once(
                "This model does not officially support chunked prefill. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )

        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = default_prefix_caching

            logger.debug(
                "%s prefix caching by default",
                "Enabling" if default_prefix_caching else "Disabling",
            )
        elif (
            model_config.runner_type == "pooling"
            and self.enable_prefix_caching
            and not default_prefix_caching
        ):
            logger.warning_once(
                "This model does not officially support prefix caching. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )

        # Disable chunked prefill and prefix caching for:
        # RISCV CPUs in V1
        # This override runs last, so it also overrides explicit user values.
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.RISCV,
        ):
            logger.info(
                "Chunked prefill is not supported for "
                "RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_chunked_prefill = False
            logger.info(
                "Prefix caching is not supported for "
                "RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_prefix_caching = False

    def _set_default_max_num_seqs_and_batched_tokens_args(
        self,
        usage_context: UsageContext | None,
        model_config: ModelConfig,
    ):
        """Fill in ``max_num_batched_tokens`` and ``max_num_seqs`` when unset.

        Defaults are looked up in ``get_batch_defaults`` by ``usage_context``.
        When the context is missing there, the ``SchedulerConfig`` class
        defaults are used. Only values that were ``None`` on entry are
        modified:

        - they are doubled in ``"throughput"`` performance mode;
        - the token budget is raised to at least ``max_model_len`` when
          chunked prefill is off;
        - the token budget is capped at ``max_num_seqs * max_model_len``;
        - ``max_num_seqs`` is capped at the token budget.

        Values set explicitly by the user are left as-is.
        """
        token_defaults, seq_defaults = self.get_batch_defaults(
            self.pipeline_parallel_size * self.tensor_parallel_size
        )

        # Remember which knobs the user left unset before overwriting them;
        # only those are adjusted below.
        tokens_defaulted = self.max_num_batched_tokens is None
        seqs_defaulted = self.max_num_seqs is None

        if tokens_defaulted:
            self.max_num_batched_tokens = token_defaults.get(
                usage_context, SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS
            )
        if seqs_defaulted:
            self.max_num_seqs = seq_defaults.get(
                usage_context, SchedulerConfig.DEFAULT_MAX_NUM_SEQS
            )

        # Throughput mode doubles whichever values were just defaulted.
        if self.performance_mode == "throughput":
            if tokens_defaulted:
                self.max_num_batched_tokens *= 2
            if seqs_defaulted:
                self.max_num_seqs *= 2

        if tokens_defaulted:
            max_model_len = model_config.max_model_len
            assert max_model_len is not None, (
                "max_model_len must be set by this point"
            )
            budget = self.max_num_batched_tokens
            if not self.enable_chunked_prefill:
                # Without chunked prefill, never budget fewer tokens than
                # max_model_len.
                budget = max(max_model_len, budget)
            # Never exceed max_num_seqs full-length sequences; some models
            # (e.g., Whisper) have embeddings tied to max length.
            self.max_num_batched_tokens = min(
                self.max_num_seqs * max_model_len, budget
            )
            logger.debug(
                "Defaulting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                usage_context.value if usage_context else None,
            )

        if seqs_defaulted:
            assert self.max_num_batched_tokens is not None  # For type checking
            # A batch cannot hold more sequences than it has tokens.
            self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
            logger.debug(
                "Defaulting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                usage_context.value if usage_context else None,
            )
2189

2190

2191
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""

    # Whether to log per-request information. On the CLI this is
    # --enable-log-requests; argparse.BooleanOptionalAction also adds the
    # --no-enable-log-requests form.
    enable_log_requests: bool = False

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
        """Register the async-engine CLI arguments on ``parser``.

        Args:
            parser: The argument parser to extend.
            async_args_only: If True, skip ``EngineArgs.add_cli_args`` and
                register only the async-specific arguments.

        Returns:
            The parser with the arguments registered.
        """
        # Initialize plugins first so they can update the parser. For
        # example, a plugin may add a new quantization method to the
        # --quantization argument or a new device to the --device argument.
        load_general_plugins()
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument(
            "--enable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=AsyncEngineArgs.enable_log_requests,
            help="Enable logging request information, dependent on log level:\n"
            "- INFO: Request ID, parameters and LoRA request.\n"
            "- DEBUG: Prompt inputs (e.g. text, token IDs).\n"
            "You can set the minimum log level via `VLLM_LOGGING_LEVEL`.",
        )
        current_platform.pre_register_and_update(parser)
        return parser
2218
2219


2220
2221
2222
2223
2224
2225
def _raise_unsupported_error(feature_name: str):
    msg = (
        f"{feature_name} is not supported. We recommend to "
        f"remove {feature_name} from your config."
    )
    raise NotImplementedError(msg)
2226
2227


2228
def human_readable_int(value: str) -> int:
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.

    Lowercase suffixes (k, m, g, t) are decimal multipliers (powers of 1000)
    and accept a fractional number; any remaining fractional part of the
    result is truncated toward zero. Uppercase suffixes (K, M, G, T) are
    binary multipliers (powers of 1024) and require a whole number. Input
    without a suffix is parsed with ``int()``. Surrounding whitespace is
    ignored.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600

    Raises:
        argparse.ArgumentTypeError: A decimal number is combined with a
            binary suffix (e.g. '1.5K').
        ValueError: The value is neither a suffixed number nor a plain int.
    """
    from fractions import Fraction

    value = value.strip()

    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
    if match:
        decimal_multiplier = {
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
            "t": 10**12,
        }
        binary_multiplier = {
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
            "T": 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            # Use exact rational arithmetic: with floats the product can land
            # just below the integer (float("1.005") * 1000 ==
            # 1004.9999999999999), which int() would truncate to 1004.
            return int(Fraction(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?"
                ) from e

    # Regular plain number.
    return int(value)
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290


def human_readable_int_or_auto(value: str) -> int:
    """Parse a human-readable integer, also accepting an "auto" sentinel.

    After stripping surrounding whitespace, the string '-1' or 'auto'
    (case-insensitive) maps to -1, the special value for auto-detection.
    Everything else is delegated to ``human_readable_int``.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '-1' or 'auto' -> -1 (special value for auto-detection)
    """
    auto_sentinel = -1
    stripped = value.strip()
    wants_auto = stripped == "-1" or stripped.lower() == "auto"
    return auto_sentinel if wants_auto else human_readable_int(stripped)