arg_utils.py 83.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import copy
6
import dataclasses
7
import functools
8
import json
9
import sys
10
from collections.abc import Callable
11
from dataclasses import MISSING, dataclass, fields, is_dataclass
12
from itertools import permutations
13
from types import UnionType
14
15
16
17
18
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Literal,
19
    TypeAlias,
20
21
22
23
24
25
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
26

27
import huggingface_hub
28
import regex as re
29
import torch
30
from pydantic import TypeAdapter, ValidationError
31
from pydantic.fields import FieldInfo
32
from typing_extensions import TypeIs, deprecated
33

34
import vllm.envs as envs
35
from vllm.attention.backends.registry import _Backend
36
37
38
39
40
41
42
43
44
45
46
from vllm.config import (
    CacheConfig,
    CompilationConfig,
    ConfigType,
    DeviceConfig,
    EPLBConfig,
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LoRAConfig,
    ModelConfig,
47
    MultiModalConfig,
48
49
50
51
52
53
54
55
56
    ObservabilityConfig,
    ParallelConfig,
    PoolerConfig,
    SchedulerConfig,
    SpeculativeConfig,
    StructuredOutputsConfig,
    VllmConfig,
    get_attr_docs,
)
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from vllm.config.cache import BlockSize, CacheDType, MambaDType, PrefixCachingHashAlgo
from vllm.config.device import Device
from vllm.config.model import (
    ConvertOption,
    HfOverrides,
    LogprobsMode,
    ModelDType,
    RunnerOption,
    TaskOption,
    TokenizerMode,
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.scheduler import SchedulerPolicy
72
from vllm.config.utils import get_field
73
from vllm.logger import init_logger
74
from vllm.platforms import CpuArchEnum, current_platform
75
from vllm.plugins import load_general_plugins
76
from vllm.ray.lazy_utils import is_ray_initialized
77
from vllm.reasoning import ReasoningParserManager
78
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
79
80
81
82
83
from vllm.transformers_utils.config import (
    get_model_path,
    is_interleaved,
    maybe_override_with_speculators,
)
84
from vllm.transformers_utils.utils import check_gguf_file
85
86
from vllm.utils import is_in_ray_actor
from vllm.utils.argparse_utils import FlexibleArgumentParser
87
from vllm.utils.mem_constants import GiB_bytes
88
from vllm.utils.network_utils import get_ip
89
from vllm.v1.sample.logits_processor import LogitsProcessor
90

91
92
if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
93
    from vllm.model_executor.model_loader import LoadFormats
94
    from vllm.usage.usage_lib import UsageContext
95
    from vllm.v1.executor import Executor
96
else:
97
    Executor = Any
98
    QuantizationMethods = Any
99
    LoadFormats = Any
100
101
    UsageContext = Any

102
103
logger = init_logger(__name__)

# `object` is part of these aliases so they also cover special typing forms
# (e.g. `Literal[...]`, `Annotated[...]`) that are not `type` instances.
T = TypeVar("T")

TypeHint: TypeAlias = type[Any] | object
TypeHintT: TypeAlias = type[T] | object
108

109

110
111
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap a string converter for use as an argparse ``type=`` callable.

    Args:
        return_type: Callable that converts the raw CLI string (for example
            ``int`` or ``json.loads``).

    Returns:
        A converter that calls ``return_type(val)`` and re-raises any
        ``ValueError`` as ``argparse.ArgumentTypeError`` (chained to the
        original exception).
    """

    def _parse_type(val: str) -> T:
        try:
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}."
            ) from e

    return _parse_type


122
123
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], T | None]:
    """Like `parse_type`, but map the strings ``""`` and ``"None"`` to None.

    Args:
        return_type: Callable that converts a non-empty, non-"None" string.

    Returns:
        A converter returning None for ``""``/``"None"``; any other value is
        converted via ``parse_type(return_type)``, so conversion ``ValueError``s
        surface as ``argparse.ArgumentTypeError``.
    """
    # Build the wrapped converter once instead of on every call.
    parse = parse_type(return_type)

    def _optional_type(val: str) -> T | None:
        if val == "" or val == "None":
            return None
        return parse(val)

    return _optional_type
129
130


131
def union_dict_and_str(val: str) -> str | dict[str, str] | None:
    """Parse a CLI value that may be either a plain string or a JSON object.

    A value that (ignoring surrounding whitespace) starts with ``{`` and ends
    with ``}`` is decoded with ``json.loads`` through `optional_type`, so a
    decoding ``ValueError`` is reported as ``argparse.ArgumentTypeError``.
    Anything else is returned unchanged as a string.
    """
    # (?s) lets `.` match newlines so multi-line JSON objects are detected.
    if not re.match(r"(?s)^\s*{.*}\s*$", val):
        return str(val)
    return optional_type(json.loads)(val)
135
136


137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return True if `type_hint` is `type` itself or a parameterization of it.

    For example, both ``list`` and ``list[int]`` match ``list``, and
    ``Literal["a"]`` matches ``Literal``.
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return True if any hint in `type_hints` matches `type` per `is_type`."""
    for hint in type_hints:
        if is_type(hint, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT | None:
    """Return the hint in `type_hints` that matches `type`, or None if absent.

    Matching follows `is_type`, so e.g. ``list[int]`` is returned when ``list``
    is requested. Because `type_hints` is a set, which hint is returned is
    unspecified if more than one matches.
    """
    return next((hint for hint in type_hints if is_type(hint, type)), None)


152
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Get the `type` and `choices` from a `Literal` type hint in `type_hints`.

    If `type_hints` also contains `str`, we use `metavar` instead of `choices`,
    so arbitrary strings remain accepted while the Literal options are still
    listed in the help output.

    Returns:
        ``{"type": <type of the options>, "choices"|"metavar": sorted options}``.

    Raises:
        ValueError: If the Literal's options are not all of the same type.
    """
    type_hint = get_type(type_hints, Literal)
    options = get_args(type_hint)
    option_type = type(options[0])
    if not all(isinstance(option, option_type) for option in options):
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}"
        )
    kwarg = "metavar" if contains_type(type_hints, str) else "choices"
    return {"type": option_type, kwarg: sorted(options)}
167
168


169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def collection_to_kwargs(type_hints: set[TypeHint], type: TypeHint) -> dict[str, Any]:
    """Build argparse ``type``/``nargs`` kwargs for a list, tuple or set hint.

    Fixed-length tuples (no ``...``) take exactly as many values as they have
    element types; every other collection takes one or more values. An element
    type that is a union must include ``str`` and is then parsed as ``str``.
    """
    collection_hint = get_type(type_hints, type)
    elem_types = get_args(collection_hint)
    element = elem_types[0]

    # Every element type other than `...` must be identical to the first.
    assert all(t is element for t in elem_types if t is not Ellipsis), (
        f"All non-Ellipsis elements must be of the same type. Got {elem_types}."
    )

    # Union covers Union[X, Y]; UnionType covers X | Y.
    if get_origin(element) in {Union, UnionType}:
        assert str in get_args(element), (
            "If element can have multiple types, one must be 'str' "
            f"(i.e. 'list[int | str]'). Got {element}."
        )
        element = str

    if type is tuple and Ellipsis not in elem_types:
        nargs = len(elem_types)
    else:
        nargs = "+"
    return {"type": element, "nargs": nargs}


194
195
196
197
198
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True if `type_hint` is defined outside the ``builtins`` module."""
    module_name = type_hint.__module__
    return module_name != "builtins"


199
200
201
202
203
204
205
206
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Flatten `Annotated` and `Union` hints into a set of member hints.

    - ``Annotated[X, ...]`` contributes the hints of ``X`` (metadata dropped).
    - ``Union[X, Y]`` / ``X | Y`` contributes the hints of each member,
      recursively; so ``X | None`` yields ``{X, NoneType}``.
    - Any other hint is returned as a one-element set.
    """
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
    elif origin in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


217
218
219
220
def is_online_quantization(quantization: Any) -> bool:
    """Return True if `quantization` is treated as an online method.

    Currently only ``"inc"`` is in that set.
    """
    online_methods = ("inc",)
    return quantization in online_methods


221
# True when this process is producing help text: a `--help` flag is present or
# we are running under mkdocs. `_compute_kwargs` uses it to skip collecting
# attribute docs otherwise.
NEEDS_HELP = (
    any("--help" in arg for arg in sys.argv)  # vllm SUBCOMMAND --help
    # mkdocs SUBCOMMAND; `sys.argv` may have been emptied before import, so
    # guard the index instead of raising IndexError at import time.
    or (argv0 := sys.argv[0] if sys.argv else "").endswith("mkdocs")
    or argv0.endswith("mkdocs/__main__.py")  # python -m mkdocs SUBCOMMAND
)


228
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Compute argparse ``add_argument`` kwargs for each field of `cls`.

    The result maps every dataclass field name to a kwargs dict that always
    contains ``default`` and ``help`` and, depending on the field's type hints,
    ``type``, ``action``, ``choices``, ``metavar`` and/or ``nargs``.

    The result is memoized with `functools.lru_cache` and therefore shared
    between calls; use `get_kwargs`, which returns a deep copy, instead of
    mutating it.

    Raises:
        ValueError: If a field's type hints have no supported argparse mapping.
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}

    # Loop-invariant constants, built once rather than per field.
    json_tip = (
        "Should either be a valid JSON string or JSON keys passed individually."
    )
    # Integer fields parsed with `human_readable_int` instead of plain `int`.
    human_readable_ints = {
        "max_model_len",
        "max_num_batched_tokens",
        "kv_cache_memory_bytes",
    }

    kwargs: dict[str, dict[str, Any]] = {}
    for field in fields(cls):
        name = field.name

        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, parse the CLI value as JSON with pydantic
        dataclass_cls = next((th for th in type_hints if is_dataclass(th)), None)

        # Get the default value of the field
        if field.default is not MISSING:
            default = field.default
            # Handle pydantic.Field defaults
            if isinstance(default, FieldInfo):
                default = (
                    default.default
                    if default.default_factory is None
                    else default.default_factory()
                )
        elif field.default_factory is not MISSING:
            default = field.default_factory()
        else:
            # No default at all: use None explicitly so this field neither
            # hits UnboundLocalError (first field) nor inherits the previous
            # field's default from the loop variable.
            default = None

        # Escape % because argparse %-formats help strings
        help_text = cls_docs.get(name, "").strip().replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        field_kwargs: dict[str, Any] = {"default": default, "help": help_text}
        kwargs[name] = field_kwargs

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:

            # `cls=dataclass_cls` binds this iteration's class at definition
            # time, avoiding late binding of the loop variable.
            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            field_kwargs["type"] = parse_dataclass
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            field_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            field_kwargs.update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            field_kwargs.update(collection_to_kwargs(type_hints, tuple))
        elif contains_type(type_hints, list):
            field_kwargs.update(collection_to_kwargs(type_hints, list))
        elif contains_type(type_hints, set):
            field_kwargs.update(collection_to_kwargs(type_hints, set))
        elif contains_type(type_hints, int):
            field_kwargs["type"] = int
            # Special case for large integers
            if name in human_readable_ints:
                field_kwargs["type"] = human_readable_int
                field_kwargs["help"] += f"\n\n{human_readable_int.__doc__}"
        elif contains_type(type_hints, float):
            field_kwargs["type"] = float
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
            field_kwargs["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            field_kwargs["type"] = parse_type(json.loads)
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
            field_kwargs["type"] = str
        else:
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(field_kwargs.get("type")) is Literal:
            field_kwargs.update(literal_to_kwargs({field_kwargs["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            field_kwargs["type"] = optional_type(field_kwargs["type"])
            if field_kwargs.get("choices"):
                # NOTE(review): argparse checks `choices` against the value
                # returned by `type`, and `optional_type` maps "None" to None,
                # so this string entry may never match; confirm that passing
                # "None" on the CLI is accepted for these fields.
                field_kwargs["choices"].append("None")
    return kwargs
328
329


330
def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Return argparse kwargs for the given Config dataclass.

    If neither `--help` nor `mkdocs` is present on the command line, the
    attribute documentation is not collected, so help strings are empty.

    The heavy computation is cached via `functools.lru_cache` in
    `_compute_kwargs`; a deep copy is returned so callers can mutate the
    dictionary without affecting the cached version.
    """
    return copy.deepcopy(_compute_kwargs(cls))


343
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
344
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
345
    """Arguments for vLLM engine."""
346

347
    model: str = ModelConfig.model
348
349
350
    served_model_name: str | list[str] | None = ModelConfig.served_model_name
    tokenizer: str | None = ModelConfig.tokenizer
    hf_config_path: str | None = ModelConfig.hf_config_path
351
352
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
353
    task: TaskOption | None = ModelConfig.task
354
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
355
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
356
357
358
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
359
360
    allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains
    download_dir: str | None = LoadConfig.download_dir
361
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
362
    load_format: str | LoadFormats = LoadConfig.load_format
363
364
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
365
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
366
367
    seed: int | None = ModelConfig.seed
    max_model_len: int | None = ModelConfig.max_model_len
368
369
370
371
372
373
374
    cuda_graph_sizes: list[int] | None = CompilationConfig.cudagraph_capture_sizes
    cudagraph_capture_sizes: list[int] | None = (
        CompilationConfig.cudagraph_capture_sizes
    )
    max_cudagraph_capture_size: int | None = get_field(
        CompilationConfig, "max_cudagraph_capture_size"
    )
375
376
377
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
378
    distributed_executor_backend: (
379
        str | DistributedExecutorBackend | type[Executor] | None
380
    ) = ParallelConfig.distributed_executor_backend
381
    # number of P/D disaggregation (or other disaggregation) workers
382
383
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
384
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
385
    data_parallel_size: int = ParallelConfig.data_parallel_size
386
387
388
389
390
    data_parallel_rank: int | None = None
    data_parallel_start_rank: int | None = None
    data_parallel_size_local: int | None = None
    data_parallel_address: str | None = None
    data_parallel_rpc_port: int | None = None
391
    data_parallel_hybrid_lb: bool = False
Rui Qiao's avatar
Rui Qiao committed
392
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
393
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
394
    all2all_backend: str | None = ParallelConfig.all2all_backend
395
    enable_dbo: bool = ParallelConfig.enable_dbo
396
397
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
398
399
400
    disable_nccl_for_dp_synchronization: bool = (
        ParallelConfig.disable_nccl_for_dp_synchronization
    )
401
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
402
    enable_eplb: bool = ParallelConfig.enable_eplb
403
    expert_placement_strategy: ExpertPlacementStrategy = (
404
        ParallelConfig.expert_placement_strategy
405
    )
406
407
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
408
409
410
411
    num_redundant_experts: int = EPLBConfig.num_redundant_experts
    eplb_window_size: int = EPLBConfig.window_size
    eplb_step_interval: int = EPLBConfig.step_interval
    eplb_log_balancedness: bool = EPLBConfig.log_balancedness
412
    max_parallel_loading_workers: int | None = (
413
414
        ParallelConfig.max_parallel_loading_workers
    )
415
416
    block_size: BlockSize | None = CacheConfig.block_size
    enable_prefix_caching: bool | None = CacheConfig.enable_prefix_caching
417
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
418
        CacheConfig.prefix_caching_hash_algo
419
    )
420
421
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
422
423
424
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
425
426
    kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
    max_num_batched_tokens: int | None = SchedulerConfig.max_num_batched_tokens
427
428
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
429
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
430
    max_num_seqs: int | None = SchedulerConfig.max_num_seqs
431
    max_logprobs: int = ModelConfig.max_logprobs
432
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
433
    disable_log_stats: bool = False
434
    aggregate_engine_logging: bool = False
435
436
    revision: str | None = ModelConfig.revision
    code_revision: str | None = ModelConfig.code_revision
437
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
438
439
    rope_theta: float | None = ModelConfig.rope_theta
    hf_token: bool | str | None = ModelConfig.hf_token
440
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
441
442
    tokenizer_revision: str | None = ModelConfig.tokenizer_revision
    quantization: QuantizationMethods | None = ModelConfig.quantization
443
    enforce_eager: bool = ModelConfig.enforce_eager
444
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
445
    limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
446
447
        MultiModalConfig, "limit_per_prompt"
    )
448
    enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
449
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
450
451
452
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
453
    mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
454
    disable_mm_preprocessor_cache: bool = False  # DEPRECATED
455
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
456
    mm_processor_cache_type: MMCacheType | None = (
457
        MultiModalConfig.mm_processor_cache_type
458
459
    )
    mm_shm_cache_max_object_size_mb: int = (
460
        MultiModalConfig.mm_shm_cache_max_object_size_mb
461
    )
462
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
463
464
465
    mm_encoder_attn_backend: _Backend | str | None = (
        MultiModalConfig.mm_encoder_attn_backend
    )
466
    io_processor_plugin: str | None = None
467
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
468
    video_pruning_rate: float = MultiModalConfig.video_pruning_rate
469
    # LoRA fields
470
    enable_lora: bool = False
471
472
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
473
    default_mm_loras: dict[str, str] | None = LoRAConfig.default_mm_loras
474
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
475
476
    max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
    lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
477
478
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size

479
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
480
    num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
481
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
482
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
483
    ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
484

485
    enable_chunked_prefill: bool | None = SchedulerConfig.enable_chunked_prefill
486
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
487

488
    disable_hybrid_kv_cache_manager: bool = (
489
490
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
491

492
    structured_outputs_config: StructuredOutputsConfig = get_field(
493
494
        VllmConfig, "structured_outputs_config"
    )
495
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
496

497
    # Deprecated guided decoding fields
498
499
500
501
    guided_decoding_backend: str | None = None
    guided_decoding_disable_fallback: bool | None = None
    guided_decoding_disable_any_whitespace: bool | None = None
    guided_decoding_disable_additional_properties: bool | None = None
502

503
    logits_processor_pattern: str | None = ModelConfig.logits_processor_pattern
504

505
    speculative_config: dict[str, Any] | None = None
506

507
    show_hidden_metrics_for_version: str | None = (
508
        ObservabilityConfig.show_hidden_metrics_for_version
509
    )
510
511
    otlp_traces_endpoint: str | None = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: list[DetailedTraceModules] | None = (
512
        ObservabilityConfig.collect_detailed_traces
513
    )
514
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
515
    scheduler_cls: str | type[object] = SchedulerConfig.scheduler_cls
516

517
518
    pooler_config: PoolerConfig | None = ModelConfig.pooler_config
    override_pooler_config: dict | PoolerConfig | None = (
519
        ModelConfig.override_pooler_config
520
521
    )
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
522
523
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
524

525
526
    kv_transfer_config: KVTransferConfig | None = None
    kv_events_config: KVEventsConfig | None = None
527

528
529
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
530
531
532
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
533
    model_impl: str = ModelConfig.model_impl
534
    override_attention_dtype: str = ModelConfig.override_attention_dtype
535

536
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
537
538
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
539

540
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
541

542
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
543
    pt_load_map_location: str = LoadConfig.pt_load_map_location
544

545
546
    # DEPRECATED
    enable_multimodal_encoder_data_parallel: bool = False
547

548
    logits_processors: list[str | type[LogitsProcessor]] | None = (
549
550
        ModelConfig.logits_processors
    )
551
552
    """Custom logitproc types"""

553
554
    async_scheduling: bool = SchedulerConfig.async_scheduling

555
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
556

557
    def __post_init__(self):
        """Normalize dict-valued configs, load plugins and resolve offline paths.

        - ``compilation_config`` and ``eplb_config`` given as dicts are turned
          into `CompilationConfig` / `EPLBConfig` instances.
        - General plugins are loaded via `load_general_plugins`.
        - When ``huggingface_hub.constants.HF_HUB_OFFLINE`` is set,
          ``self.model`` is replaced with
          ``get_model_path(self.model, self.revision)``.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(**self.compilation_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)
        # Setup plugins
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        # In Hugging Face offline mode, replace the model id with the local
        # model path.
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
            logger.info(
                "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
                model_id,
                self.model,
            )
578
579

    @staticmethod
580
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
581
        """Shared CLI arguments for vLLM engine."""
582

583
        # Model arguments
584
585
586
587
588
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
589
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
590
            model_group.add_argument("--model", **model_kwargs["model"])
591
592
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
593
        model_group.add_argument("--task", **model_kwargs["task"], deprecated=True)
594
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
595
596
597
598
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
599
600
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
601
602
603
604
605
606
607
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
608
        model_group.add_argument("--revision", **model_kwargs["revision"])
609
610
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling", **model_kwargs["rope_scaling"])
611
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
636
637
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--override-pooler-config",
            **model_kwargs["override_pooler_config"],
            deprecated=True,
        )
        model_group.add_argument(
            "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
        )
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
665
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
666
667
668
669
670
671
672
673
674
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
        )
675

676
677
678
679
680
681
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
682
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
683
684
685
686
687
688
689
690
691
692
693
694
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )
695

696
697
698
699
700
        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
701
        )
702
        structured_outputs_group.add_argument(
703
            "--reasoning-parser",
704
            # This choice is a special case because it's not static
705
            choices=list(ReasoningParserManager.reasoning_parsers),
706
707
            **structured_outputs_kwargs["reasoning_parser"],
        )
708
709
710
711
712
713
714
715
716
717
718
        # Deprecated guided decoding arguments
        for arg, type in [
            ("--guided-decoding-backend", str),
            ("--guided-decoding-disable-fallback", bool),
            ("--guided-decoding-disable-any-whitespace", bool),
            ("--guided-decoding-disable-additional-properties", bool),
        ]:
            structured_outputs_group.add_argument(
                arg,
                type=type,
                help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
719
720
                deprecated=True,
            )
721

722
        # Parallel arguments
723
724
725
726
727
728
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
729
            "--distributed-executor-backend",
730
731
            **parallel_kwargs["distributed_executor_backend"],
        )
732
        parallel_group.add_argument(
733
734
735
736
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
737
        parallel_group.add_argument(
738
739
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
740
        parallel_group.add_argument(
741
742
743
744
745
746
747
748
749
750
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
        parallel_group.add_argument(
            "--data-parallel-rank",
            "-dpn",
751
            type=int,
752
753
754
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
755
        parallel_group.add_argument(
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
            "--data-parallel-start-rank",
            "-dpr",
            type=int,
            help="Starting data parallel rank for secondary nodes.",
        )
        parallel_group.add_argument(
            "--data-parallel-size-local",
            "-dpl",
            type=int,
            help="Number of data parallel replicas to run on this node.",
        )
        parallel_group.add_argument(
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
            type=int,
            help="Port for data parallel RPC communication.",
        )
        parallel_group.add_argument(
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
786
        parallel_group.add_argument(
787
788
789
790
791
            "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"]
        )
        parallel_group.add_argument(
            "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
        )
792
793
794
        parallel_group.add_argument(
            "--all2all-backend", **parallel_kwargs["all2all_backend"]
        )
795
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
796
797
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
798
799
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
800
801
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
802
803
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
804
805
806
807
        parallel_group.add_argument(
            "--disable-nccl-for-dp-synchronization",
            **parallel_kwargs["disable_nccl_for_dp_synchronization"],
        )
808
809
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
810
811
        parallel_group.add_argument(
            "--expert-placement-strategy",
812
813
            **parallel_kwargs["expert_placement_strategy"],
        )
814
815
816
        parallel_group.add_argument(
            "--num-redundant-experts",
            type=int,
817
818
819
            help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
            deprecated=True,
        )
820
821
822
823
        parallel_group.add_argument(
            "--eplb-window-size",
            type=int,
            help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
824
825
            deprecated=True,
        )
826
827
828
        parallel_group.add_argument(
            "--eplb-step-interval",
            type=int,
829
830
831
            help="[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
            deprecated=True,
        )
832
833
834
        parallel_group.add_argument(
            "--eplb-log-balancedness",
            action=argparse.BooleanOptionalAction,
835
836
837
            help="[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
            deprecated=True,
        )
838

839
        parallel_group.add_argument(
840
            "--max-parallel-loading-workers",
841
842
            **parallel_kwargs["max_parallel_loading_workers"],
        )
843
        parallel_group.add_argument(
844
845
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
846
        parallel_group.add_argument(
847
            "--disable-custom-all-reduce",
848
849
850
851
852
853
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )
854
855
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
856
            action="store_true",
857
858
            deprecated=True,
        )
859

860
861
862
863
864
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
865
        )
866
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
867
868
869
870
871
872
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
873
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        cache_group.add_argument(
            "--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"]
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
897

898
        # Multimodal related configs
899
900
901
902
903
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
904
        multimodal_group.add_argument(
905
906
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
907
908
909
        multimodal_group.add_argument(
            "--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
        )
910
911
912
913
914
915
916
917
918
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
919
        multimodal_group.add_argument(
920
921
            "--disable-mm-preprocessor-cache", action="store_true", deprecated=True
        )
922
        multimodal_group.add_argument(
923
924
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
925
926
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
927
928
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
929
        multimodal_group.add_argument(
930
931
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
932
933
934
935
        multimodal_group.add_argument(
            "--mm-encoder-attn-backend",
            **multimodal_kwargs["mm_encoder_attn_backend"],
        )
936
937
938
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
939
        multimodal_group.add_argument(
940
941
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
942

943
        multimodal_group.add_argument(
944
945
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )
946

947
        # LoRA related configs
948
949
950
951
952
953
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
954
            "--enable-lora",
955
            action=argparse.BooleanOptionalAction,
956
957
            help="If True, enable handling of LoRA adapters.",
        )
958
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
959
960
961
962
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
        lora_group.add_argument(
            "--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"]
        )
963
        lora_group.add_argument(
964
            "--lora-dtype",
965
966
            **lora_kwargs["lora_dtype"],
        )
967
968
969
970
971
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
972

973
974
975
976
977
978
979
980
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
981
982
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
983
        observability_group.add_argument(
984
985
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
986
987
988
989
990
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
991
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
992
993
994
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
995
996
            **observability_kwargs["collect_detailed_traces"],
        )
997

998
999
1000
1001
1002
1003
1004
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
1005
1006
            "--max-num-batched-tokens", **scheduler_kwargs["max_num_batched_tokens"]
        )
1007
        scheduler_group.add_argument(
1008
1009
1010
1011
1012
            "--max-num-seqs", **scheduler_kwargs["max_num_seqs"]
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
1013
1014
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
1015
1016
            **scheduler_kwargs["max_long_partial_prefills"],
        )
1017
1018
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
1019
1020
1021
1022
1023
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
        scheduler_group.add_argument(
            "--num-lookahead-slots", **scheduler_kwargs["num_lookahead_slots"]
        )
1024
1025
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
1026
        scheduler_group.add_argument(
1027
1028
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
1029
        scheduler_group.add_argument(
1030
1031
1032
1033
1034
1035
1036
1037
            "--enable-chunked-prefill", **scheduler_kwargs["enable_chunked_prefill"]
        )
        scheduler_group.add_argument(
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
1038
1039
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
1040
1041
1042
1043
1044
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
1045

1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
        # Compilation arguments
        compilation_kwargs = get_kwargs(CompilationConfig)
        compilation_group = parser.add_argument_group(
            title="CompilationConfig",
            description=CompilationConfig.__doc__,
        )
        compilation_group.add_argument(
            "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
        )
        compilation_kwargs["cudagraph_capture_sizes"]["help"] = (
            "--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or v1.0.0,"
            " whichever is soonest. Please use --cudagraph-capture-sizes instead."
        )
        compilation_group.add_argument(
            "--cuda-graph-sizes",
            **compilation_kwargs["cudagraph_capture_sizes"],
            deprecated=True,
        )
        compilation_group.add_argument(
            "--max-cudagraph-capture-size",
            **compilation_kwargs["max_cudagraph_capture_size"],
        )

1069
        # vLLM arguments
1070
        vllm_kwargs = get_kwargs(VllmConfig)
1071
1072
1073
1074
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
1075
1076
1077
1078
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
        vllm_group.add_argument(
            "--speculative-config", **vllm_kwargs["speculative_config"]
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument(
            "--compilation-config", "-O", **vllm_kwargs["compilation_config"]
        )
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
1095

1096
        # Other arguments
1097
1098
1099
1100
1101
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
1102

1103
1104
1105
1106
1107
1108
        parser.add_argument(
            "--aggregate-engine-logging",
            action="store_true",
            help="Log aggregate rather than per-engine statistics "
            "when using data parallelism.",
        )
1109
        return parser
1110
1111

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from a parsed argparse namespace.

        Every dataclass field of ``cls`` that also exists as an attribute on
        ``args`` is forwarded to the constructor by name. Fields absent from
        ``args`` are not passed, so the constructor's own defaults apply.
        Attributes of ``args`` that are not fields of ``cls`` are ignored.
        """
        init_kwargs = {}
        for field in dataclasses.fields(cls):
            name = field.name
            if not hasattr(args, name):
                continue
            init_kwargs[name] = getattr(args, name)
        return cls(**init_kwargs)

    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` from these engine arguments.

        Before constructing the config, this method may rewrite some fields
        on ``self``:

        * ``quantization`` and ``load_format`` are both set to ``"gguf"``
          when ``self.model`` is a GGUF file.
        * ``model`` is prefixed with ``MODEL_WEIGHTS_S3_BUCKET`` in CI when
          ``VLLM_CI_USE_S3`` is enabled, the model is listed in
          ``MODELS_ON_S3``, ``load_format`` is ``"auto"``, and this is not an
          ``AsyncEngineArgs`` instance.
        * ``mm_processor_cache_gb`` and ``mm_encoder_tp_mode`` are overridden
          by deprecated options (``--disable-mm-preprocessor-cache``, the
          ``VLLM_MM_INPUT_CACHE_GIB`` env var, and
          ``--enable-multimodal-encoder-data-parallel``). Each triggers a
          deprecation warning.

        Returns:
            A ``ModelConfig`` built from the (possibly updated) fields.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (
            not isinstance(self, AsyncEngineArgs)
            and envs.VLLM_CI_USE_S3
            and self.model in MODELS_ON_S3
            and self.load_format == "auto"
        ):
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"

        # Deprecated multimodal cache controls. The CLI flag takes precedence
        # over the env var. Any env value other than 4 is honoured
        # (presumably 4 is the env var's default — confirm in vllm.envs).
        if self.disable_mm_preprocessor_cache:
            logger.warning(
                "`--disable-mm-preprocessor-cache` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb 0` instead.",
            )

            self.mm_processor_cache_gb = 0
        elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
            logger.warning(
                "`VLLM_MM_INPUT_CACHE_GIB` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb %d` instead.",
                envs.VLLM_MM_INPUT_CACHE_GIB,
            )

            self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB

        # Deprecated alias for `--mm-encoder-tp-mode data`.
        if self.enable_multimodal_encoder_data_parallel:
            logger.warning(
                "`--enable-multimodal-encoder-data-parallel` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-encoder-tp-mode data` instead."
            )

            self.mm_encoder_tp_mode = "data"

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            task=self.task,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            enable_mm_embeds=self.enable_mm_embeds,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
            pooler_config=self.pooler_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            video_pruning_rate=self.video_pruning_rate,
            io_processor_plugin=self.io_processor_plugin,
        )

    def validate_tensorizer_args(self):
        """Mirror top-level Tensorizer options into the nested config dict.

        Any key of ``self.model_loader_extra_config`` that is listed in
        ``TensorizerConfig._fields`` is copied into
        ``self.model_loader_extra_config["tensorizer_config"]``. The
        top-level entries themselves are left untouched. The nested
        ``"tensorizer_config"`` dict is only looked up when at least one key
        matches.
        """
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

        extra_config = self.model_loader_extra_config
        matching_keys = [
            name for name in extra_config if name in TensorizerConfig._fields
        ]
        for name in matching_keys:
            extra_config["tensorizer_config"][name] = extra_config[name]

    def create_load_config(self) -> LoadConfig:
        """Assemble the ``LoadConfig`` that controls how weights are loaded.

        Side effects on ``self``:

        * ``load_format`` is forced to ``"bitsandbytes"`` whenever
          ``quantization == "bitsandbytes"``.
        * When ``load_format == "tensorizer"``:
          - ``model_loader_extra_config`` is converted with
            ``to_serializable()`` if it provides that method;
          - its ``"tensorizer_config"`` entry is replaced by a new dict whose
            ``"tensorizer_dir"`` is ``self.model``;
          - ``validate_tensorizer_args`` then copies matching top-level
            Tensorizer options into that dict.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            if hasattr(self.model_loader_extra_config, "to_serializable"):
                serializable = self.model_loader_extra_config.to_serializable()
                self.model_loader_extra_config = serializable
            # NOTE(review): this replaces any user-supplied "tensorizer_config"
            # sub-dict; only top-level keys survive, via validate_tensorizer_args.
            self.model_loader_extra_config["tensorizer_config"] = {
                "tensorizer_dir": self.model,
            }
            self.validate_tensorizer_args()

        # Online quantization requests device="cpu"; otherwise no device is
        # passed (None).
        load_device = "cpu" if is_online_quantization(self.quantization) else None

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            device=load_device,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )

    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> SpeculativeConfig | None:
        """Build a ``SpeculativeConfig`` from ``self.speculative_config``.

        ``self.speculative_config`` is a dict. It is either decoded from the
        ``--speculative-config`` JSON string on the CLI or handed over
        directly by the engine. When it is ``None``, no speculative decoding
        is configured and ``None`` is returned.

        The four target-side settings passed to this method are written into
        ``self.speculative_config`` in place. The merged dict is then
        expanded into the ``SpeculativeConfig`` constructor.
        """
        spec_dict = self.speculative_config
        if spec_dict is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        spec_dict["target_model_config"] = target_model_config
        spec_dict["target_parallel_config"] = target_parallel_config
        spec_dict["enable_chunked_prefill"] = enable_chunked_prefill
        spec_dict["disable_log_stats"] = disable_log_stats

        return SpeculativeConfig(**spec_dict)

    def create_engine_config(
        self,
1285
        usage_context: UsageContext | None = None,
1286
        headless: bool = False,
1287
1288
1289
1290
1291
1292
1293
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1294

1295
1296
1297
1298
1299
1300
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1301
        current_platform.pre_register_and_update()
1302

1303
        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
1304

1305
1306
1307
1308
        model_config = self.create_model_config()
        self.model = model_config.model
        self.tokenizer = model_config.tokenizer

1309
1310
1311
1312
1313
1314
1315
1316
1317
        (self.model, self.tokenizer, self.speculative_config) = (
            maybe_override_with_speculators(
                model=self.model,
                tokenizer=self.tokenizer,
                revision=self.revision,
                trust_remote_code=self.trust_remote_code,
                vllm_speculative_config=self.speculative_config,
            )
        )
1318

1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

1336
1337
        # Set default arguments for V1 Engine.
        self._set_default_args(usage_context, model_config)
1338
1339
        # Disable chunked prefill and prefix caching for:
        # POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.S390X,
            CpuArchEnum.ARM,
            CpuArchEnum.RISCV,
        ):
            logger.info(
                "Chunked prefill is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
1351
            self.enable_chunked_prefill = False
1352
1353
1354
1355
1356
1357
1358
            logger.info(
                "Prefix caching is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_prefix_caching = False

1359
1360
        assert self.enable_chunked_prefill is not None

1361
        sliding_window: int | None = None
1362
1363
1364
1365
1366
1367
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1368
1369
1370
        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
1371
        # reuses the GPUs of TP group, and split one TP group into
1372
        # tp_size//dcp_size DCP groups.
1373
        assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, (
1374
1375
1376
1377
            f"tp_size={self.tensor_parallel_size} must be divisible by"
            f"dcp_size={self.decode_context_parallel_size}."
        )

1378
        cache_config = CacheConfig(
1379
            block_size=self.block_size,
1380
            gpu_memory_utilization=self.gpu_memory_utilization,
1381
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1382
1383
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1384
            is_attention_free=model_config.is_attention_free,
1385
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1386
            sliding_window=sliding_window,
1387
            enable_prefix_caching=self.enable_prefix_caching,
1388
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1389
            cpu_offload_gb=self.cpu_offload_gb,
1390
            calculate_kv_scales=self.calculate_kv_scales,
1391
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1392
1393
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1394
        )
1395

1396
1397
1398
1399
1400
1401
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
1402

1403
            ray_runtime_env = ray.get_runtime_context().runtime_env
1404
1405
1406
1407
1408
1409
1410
            # Avoid logging sensitive environment variables
            sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {}
            if "env_vars" in sanitized_env:
                sanitized_env["env_vars"] = {
                    k: "***" for k in sanitized_env["env_vars"]
                }
            logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env)
1411

1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1423
        assert not headless or not self.data_parallel_hybrid_lb, (
1424
1425
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
1426

1427
        data_parallel_external_lb = self.data_parallel_rank is not None
1428
        # Local DP rank = 1, use pure-external LB.
1429
1430
        if data_parallel_external_lb:
            assert self.data_parallel_size_local in (1, None), (
1431
1432
                "data_parallel_size_local must be 1 when data_parallel_rank is set"
            )
1433
            data_parallel_size_local = 1
1434
1435
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1436
1437
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

            self.data_parallel_rank = self.data_parallel_start_rank or 0
1453
        else:
1454
            assert not self.data_parallel_hybrid_lb, (
1455
1456
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )
1457

1458
1459
1460
1461
1462
1463
1464
1465
1466
            if self.data_parallel_backend == "ray" and (
                envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
            ):
                # Data parallel size defaults to 1 if DP ranks are spanning
                # multiple nodes
                data_parallel_size_local = 1
            else:
                # Otherwise local DP size defaults to global DP size if not set
                data_parallel_size_local = self.data_parallel_size
1467
1468
1469

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1470
1471
1472
1473
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
1474
1475
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1476
1477
1478
1479
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
1480
1481
                    self.data_parallel_backend,
                )
Rui Qiao's avatar
Rui Qiao committed
1482
1483
1484
                data_parallel_address = ParallelConfig.data_parallel_master_ip
        else:
            data_parallel_address = self.data_parallel_address
1485
1486
1487

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
1488
        data_parallel_rpc_port = (
1489
            self.data_parallel_rpc_port
1490
1491
1492
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )
1493

1494
1495
        if self.async_scheduling:
            if self.pipeline_parallel_size > 1:
1496
1497
1498
                raise ValueError(
                    "Async scheduling is not supported with pipeline-parallel-size > 1."
                )
1499
1500
1501
1502
1503
1504

            # Currently, async scheduling does not support speculative decoding.
            # TODO(woosuk): Support it.
            if self.speculative_config is not None:
                raise ValueError(
                    "Currently, speculative decoding is not supported with "
1505
1506
                    "async scheduling."
                )
1507

1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
        # Forward the deprecated CLI args to the EPLB config.
        if self.num_redundant_experts is not None:
            self.eplb_config.num_redundant_experts = self.num_redundant_experts
        if self.eplb_window_size is not None:
            self.eplb_config.window_size = self.eplb_window_size
        if self.eplb_step_interval is not None:
            self.eplb_config.step_interval = self.eplb_step_interval
        if self.eplb_log_balancedness is not None:
            self.eplb_config.log_balancedness = self.eplb_log_balancedness

1518
        parallel_config = ParallelConfig(
1519
1520
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1521
            data_parallel_size=self.data_parallel_size,
1522
1523
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1524
1525
1526
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1527
            data_parallel_backend=self.data_parallel_backend,
1528
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1529
            enable_expert_parallel=self.enable_expert_parallel,
1530
            all2all_backend=self.all2all_backend,
1531
1532
            enable_dbo=self.enable_dbo,
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
1533
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
1534
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
1535
            enable_eplb=self.enable_eplb,
1536
            eplb_config=self.eplb_config,
1537
            expert_placement_strategy=self.expert_placement_strategy,
1538
1539
1540
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1541
            ray_runtime_env=ray_runtime_env,
1542
            placement_group=placement_group,
1543
1544
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1545
            worker_extension_cls=self.worker_extension_cls,
1546
            decode_context_parallel_size=self.decode_context_parallel_size,
1547
1548
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
1549
        )
1550

1551
1552
1553
1554
1555
1556
1557
1558
1559
        if self.async_scheduling and (
            parallel_config.distributed_executor_backend not in ("mp", "uni")
        ):
            raise ValueError(
                "Currently, async scheduling only supports `mp` or `uni` "
                "distributed executor backend, but you choose "
                f"`{parallel_config.distributed_executor_backend}`."
            )

1560
        speculative_config = self.create_speculative_config(
1561
1562
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1563
            enable_chunked_prefill=self.enable_chunked_prefill,
1564
            disable_log_stats=self.disable_log_stats,
1565
1566
        )

1567
1568
1569
1570
1571
        # make sure num_lookahead_slots is set appropriately depending on
        # whether speculative decoding is enabled
        num_lookahead_slots = self.num_lookahead_slots
        if speculative_config is not None:
            num_lookahead_slots = speculative_config.num_lookahead_slots
1572

1573
        scheduler_config = SchedulerConfig(
1574
            runner_type=model_config.runner_type,
1575
1576
1577
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1578
            num_lookahead_slots=num_lookahead_slots,
1579
            enable_chunked_prefill=self.enable_chunked_prefill,
1580
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1581
            is_multimodal_model=model_config.is_multimodal_model,
1582
            is_encoder_decoder=model_config.is_encoder_decoder,
1583
            policy=self.scheduling_policy,
1584
            scheduler_cls=self.scheduler_cls,
1585
1586
1587
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1588
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
1589
            async_scheduling=self.async_scheduling,
1590
        )
1591

1592
1593
1594
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )
1613

1614
1615
1616
1617
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1618
        load_config = self.create_load_config()
1619

1620
1621
        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
1622
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser
1623
1624
1625
1626

        # Forward the deprecated CLI args to the StructuredOutputsConfig
        so_config = self.structured_outputs_config
        if self.guided_decoding_backend is not None:
1627
            so_config.guided_decoding_backend = self.guided_decoding_backend
1628
        if self.guided_decoding_disable_fallback is not None:
1629
            so_config.disable_fallback = self.guided_decoding_disable_fallback
1630
        if self.guided_decoding_disable_any_whitespace is not None:
1631
            so_config.disable_any_whitespace = (
1632
1633
                self.guided_decoding_disable_any_whitespace
            )
1634
        if self.guided_decoding_disable_additional_properties is not None:
1635
            so_config.disable_additional_properties = (
1636
1637
                self.guided_decoding_disable_additional_properties
            )
1638

1639
        observability_config = ObservabilityConfig(
1640
            show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version),
1641
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1642
            collect_detailed_traces=self.collect_detailed_traces,
1643
        )
1644

1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
        # Compilation config overrides
        if self.cuda_graph_sizes is not None:
            logger.warning(
                "--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or "
                "v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes "
                "instead."
            )
            if self.compilation_config.cudagraph_capture_sizes is not None:
                raise ValueError(
                    "cuda_graph_sizes and compilation_config."
                    "cudagraph_capture_sizes are mutually exclusive"
                )
            self.compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes
        if self.cudagraph_capture_sizes is not None:
            if self.compilation_config.cudagraph_capture_sizes is not None:
                raise ValueError(
                    "cudagraph_capture_sizes and compilation_config."
                    "cudagraph_capture_sizes are mutually exclusive"
                )
            self.compilation_config.cudagraph_capture_sizes = (
                self.cudagraph_capture_sizes
            )
        if self.max_cudagraph_capture_size is not None:
            if self.compilation_config.max_cudagraph_capture_size is not None:
                raise ValueError(
                    "max_cudagraph_capture_size and compilation_config."
                    "max_cudagraph_capture_size are mutually exclusive"
                )
            self.compilation_config.max_cudagraph_capture_size = (
                self.max_cudagraph_capture_size
            )

1677
        config = VllmConfig(
1678
1679
1680
1681
1682
1683
1684
1685
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
1686
            structured_outputs_config=self.structured_outputs_config,
1687
            observability_config=observability_config,
1688
            compilation_config=self.compilation_config,
1689
            kv_transfer_config=self.kv_transfer_config,
1690
            kv_events_config=self.kv_events_config,
1691
            additional_config=self.additional_config,
1692
        )
1693

1694
1695
        return config

1696
1697
1698
1699
1700
1701
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        Every V1-incompatible setting found is reported through
        ``_raise_or_fallback`` (which raises when ``VLLM_USE_V1`` is set
        truthy, and otherwise logs a fall-back-to-V0 warning), after which
        this method returns ``False``.

        Args:
            model_config: Resolved model config; used for the CPU
                sliding-window check.

        Returns:
            ``True`` when no V1-incompatible setting was detected.

        Raises:
            NotImplementedError: If draft-model speculative decoding is
                requested.
        """
        # ---- Feature flags that V1 does not support ------------------

        # Any non-default logits processor pattern forces a fallback.
        if self.logits_processor_pattern != EngineArgs.logits_processor_pattern:
            _raise_or_fallback(
                feature_name="--logits-processor-pattern", recommend_to_remove=False
            )
            return False

        # Concurrent partial prefills are not available on V1 so far.
        partial_prefill_customized = (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        )
        if partial_prefill_customized:
            _raise_or_fallback(
                feature_name="Concurrent Partial Prefill", recommend_to_remove=False
            )
            return False

        # V1 handles ngram / medusa / eagle / mtp speculation, but not a
        # separate draft model. The config may still be a raw dict here.
        spec_cfg = self.speculative_config
        if spec_cfg is not None:
            spec_method = (
                spec_cfg.get("method", None)
                if isinstance(spec_cfg, dict)
                else spec_cfg.method
            )
            if spec_method == "draft_model":
                raise NotImplementedError(
                    "Draft model speculative decoding is not supported yet. "
                    "Please consider using other speculative decoding methods "
                    "such as ngram, medusa, eagle, or mtp."
                )

        # Attention backends that V1 accepts when explicitly selected via
        # the VLLM_ATTENTION_BACKEND environment variable.
        supported_backends = (
            "FLASH_ATTN",
            "PALLAS",
            "TRITON_ATTN",
            "TRITON_MLA",
            "CUTLASS_MLA",
            "FLASHMLA",
            "FLASH_ATTN_MLA",
            "FLASHINFER",
            "FLASHINFER_MLA",
            "ROCM_AITER_MLA",
            "TORCH_SDPA",
            "FLEX_ATTENTION",
            "TREE_ATTN",
            "XFORMERS",
            "ROCM_ATTN",
            "ROCM_AITER_UNIFIED_ATTN",
        )
        if envs.is_set("VLLM_ATTENTION_BACKEND"):
            requested_backend = envs.VLLM_ATTENTION_BACKEND
            if requested_backend not in supported_backends:
                _raise_or_fallback(
                    feature_name=f"VLLM_ATTENTION_BACKEND={requested_backend}",
                    recommend_to_remove=True,
                )
                return False

        # ---- Experimental features - allow users to opt in -----------

        if self.pipeline_parallel_size > 1:
            executor = self.distributed_executor_backend
            # Either the executor advertises PP support itself, or it is one
            # of the built-in executors known to handle it.
            pp_capable = getattr(executor, "supports_pp", False) or executor in (
                ParallelConfig.distributed_executor_backend,
                "ray",
                "mp",
                "external_launcher",
            )
            if not pp_capable:
                _raise_or_fallback(
                    feature_name=(
                        "Pipeline Parallelism without Ray distributed "
                        "executor or multiprocessing executor or external "
                        "launcher"
                    ),
                    recommend_to_remove=False,
                )
                return False

        # Sliding-window models are rejected on the CPU backend.
        if current_platform.is_cpu() and model_config.get_sliding_window() is not None:
            _raise_or_fallback(
                feature_name="sliding window (CPU backend)", recommend_to_remove=False
            )
            return False

        return True

1791
1792
1793
    def _set_default_args(
        self, usage_context: UsageContext, model_config: ModelConfig
    ) -> None:
        """Set Default Arguments for V1 Engine.

        Fills in ``enable_chunked_prefill``, ``enable_prefix_caching``,
        ``max_num_batched_tokens`` and ``max_num_seqs`` from the runner type,
        the current platform/device and ``usage_context``. Values that are
        already set (not ``None``) are kept, except that chunked prefill is
        always switched on for non-pooling runners.

        Args:
            usage_context: Launch context; selects the entry in the
                per-context default tables. Contexts without an entry leave
                ``max_num_batched_tokens`` / ``max_num_seqs`` untouched.
            model_config: Resolved model configuration.
        """

        # Non-pooling runners: V1 always uses chunked prefill (overriding any
        # user value) and enables prefix caching by default, except for
        # hybrid models.
        # Pooling runners: both default to whether incremental prefill is
        # supported (causal model with LAST pooling); otherwise False.
        if model_config.runner_type != "pooling":
            self.enable_chunked_prefill = True

            if self.enable_prefix_caching is None:
                # Disable prefix caching default for hybrid models
                # since the feature is still experimental.
                if model_config.is_hybrid:
                    self.enable_prefix_caching = False
                else:
                    self.enable_prefix_caching = True
        else:
            pooling_type = model_config.pooler_config.pooling_type
            # Models whose HF config lacks `is_causal` are treated as causal.
            is_causal = getattr(model_config.hf_config, "is_causal", True)
            incremental_prefill_supported = (
                pooling_type is not None
                and pooling_type.lower() == "last"
                and is_causal
            )

            action = "Enabling" if incremental_prefill_supported else "Disabling"

            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = incremental_prefill_supported
                logger.info("(%s) chunked prefill by default", action)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = incremental_prefill_supported
                logger.info("(%s) prefix caching by default", action)

        # When no user override, set the default values based on the usage
        # context. Use different default values for different hardware.

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # Only used to pick the default tables below. Bind device_name as
            # well so the A100 check never relies on short-circuit evaluation
            # to avoid a NameError.
            device_memory = 0
            device_name = ""

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        from vllm.usage.usage_lib import UsageContext

        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
            # For GPUs like H100 and MI300x, use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 1024,
                UsageContext.OPENAI_API_SERVER: 1024,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256,
                UsageContext.OPENAI_API_SERVER: 256,
            }

        # tpu specific default values, keyed by usage context then chip name.
        if current_platform.is_tpu():
            default_max_num_batched_tokens_tpu = {
                UsageContext.LLM_CLASS: {
                    "V6E": 2048,
                    "V5E": 1024,
                    "V5P": 512,
                },
                UsageContext.OPENAI_API_SERVER: {
                    "V6E": 1024,
                    "V5E": 512,
                    "V5P": 256,
                },
            }

        # cpu specific default values, scaled by the PP x TP world size.
        if current_platform.is_cpu():
            world_size = self.pipeline_parallel_size * self.tensor_parallel_size
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 4096 * world_size,
                UsageContext.OPENAI_API_SERVER: 2048 * world_size,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256 * world_size,
                UsageContext.OPENAI_API_SERVER: 128 * world_size,
            }

        use_context_value = usage_context.value if usage_context else None
        if (
            self.max_num_batched_tokens is None
            and usage_context in default_max_num_batched_tokens
        ):
            if current_platform.is_tpu():
                chip_name = current_platform.get_device_name()
                if chip_name in default_max_num_batched_tokens_tpu[usage_context]:
                    self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[
                        usage_context
                    ][chip_name]
                else:
                    # Unknown TPU chip: fall back to the generic table.
                    self.max_num_batched_tokens = default_max_num_batched_tokens[
                        usage_context
                    ]
            else:
                if not self.enable_chunked_prefill:
                    # Presumably so a full-length prompt fits in one step
                    # when it cannot be chunked.
                    self.max_num_batched_tokens = model_config.max_model_len
                else:
                    self.max_num_batched_tokens = default_max_num_batched_tokens[
                        usage_context
                    ]
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                use_context_value,
            )

        if self.max_num_seqs is None and usage_context in default_max_num_seqs:
            # Never allow more sequences than batched tokens.
            self.max_num_seqs = min(
                default_max_num_seqs[usage_context],
                self.max_num_batched_tokens or sys.maxsize,
            )

            logger.debug(
                "Setting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                use_context_value,
            )
1935

1936

1937
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments for the asynchronous vLLM engine.

    Adds async-only options, such as request logging, on top of
    :class:`EngineArgs`.
    """

    # Whether incoming requests are logged.
    enable_log_requests: bool = False

    @property
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
        "`enable_log_requests` instead."
    )
    def disable_log_requests(self) -> bool:
        # Backward-compatible inverse view of `enable_log_requests`.
        logging_enabled = self.enable_log_requests
        return not logging_enabled

    @disable_log_requests.setter
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
        "`enable_log_requests` instead."
    )
    def disable_log_requests(self, value: bool):
        # Store the inverted flag on the non-deprecated field.
        self.enable_log_requests = not value

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
        """Register the async engine's CLI arguments on ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: When True, skip the base ``EngineArgs``
                arguments and add only the async-specific flags.

        Returns:
            The parser with the arguments registered.
        """
        # Plugins may extend the parser (for example a new quantization
        # method for --quantization or a new device for --device), so they
        # are loaded before any argument is added.
        load_general_plugins()

        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        logging_default = AsyncEngineArgs.enable_log_requests
        parser.add_argument(
            "--enable-log-requests",
            help="Enable logging requests.",
            default=logging_default,
            action=argparse.BooleanOptionalAction,
        )
        # Deprecated inverse of --enable-log-requests.
        parser.add_argument(
            "--disable-log-requests",
            help="[DEPRECATED] Disable logging requests.",
            default=not logging_default,
            action=argparse.BooleanOptionalAction,
            deprecated=True,
        )

        # Hand the fully populated parser to the current platform's hook.
        current_platform.pre_register_and_update(parser)
        return parser
1986
1987


1988
1989
1990
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Report a feature that the V1 engine cannot handle.

    If ``VLLM_USE_V1`` is explicitly set to a truthy value, the feature is a
    hard error. Otherwise a warning announcing the fallback to V0 is logged.

    Args:
        feature_name: Human-readable name of the unsupported feature.
        recommend_to_remove: Whether the warning should also advise removing
            the feature in favor of V1.

    Raises:
        NotImplementedError: If V1 was explicitly requested.
    """
    v1_explicitly_requested = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_explicitly_requested:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}."
        )

    message_parts = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        message_parts += [
            f"We recommend to remove {feature_name} from your config ",
            "in favor of the V1 Engine.",
        ]
    logger.warning("".join(message_parts))


2001
2002
2003
def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.

    Lowercase suffixes are decimal (SI) multipliers and accept a fractional
    mantissa. Uppercase suffixes are binary multipliers and require an
    integer mantissa. Strings without a suffix are parsed with ``int()``.

    Supported suffixes:
    - decimal: k (10**3), m (10**6), g (10**9), t (10**12)
    - binary:  K (2**10), M (2**20), G (2**30), T (2**40)

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '1.005k' -> 1,005

    Args:
        value: String to parse; surrounding whitespace is ignored.

    Returns:
        The parsed integer. A fractional result from a decimal suffix is
        truncated toward zero.

    Raises:
        argparse.ArgumentTypeError: If a binary suffix is combined with a
            decimal mantissa (e.g. '1.5K').
        ValueError: If ``value`` is neither a suffixed number nor a plain
            integer.
    """
    # Exact rational arithmetic: with floats, 1.005 * 1000 evaluates to
    # 1004.9999999999999 and would truncate to 1004.
    from fractions import Fraction

    value = value.strip()
    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
    if match:
        decimal_multiplier = {
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
            "t": 10**12,
        }
        binary_multiplier = {
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
            "T": 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            return int(Fraction(number) * mult)

        # The regex only admits suffixes present in one of the two tables.
        mult = binary_multiplier[suffix]
        # Do not allow decimals with binary multipliers
        try:
            return int(number) * mult
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                "Decimals are not allowed "
                f"with binary suffixes like {suffix}. Did you mean to use "
                f"{number}{suffix.lower()} instead?"
            ) from e

    # Regular plain number.
    return int(value)