arg_utils.py 89.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import copy
6
import dataclasses
7
import functools
8
import json
9
import sys
10
from collections.abc import Callable
11
from dataclasses import MISSING, dataclass, fields, is_dataclass
12
from itertools import permutations
13
from types import UnionType
14
15
16
17
18
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Literal,
19
    TypeAlias,
20
21
22
23
24
25
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
26

27
import huggingface_hub
28
import regex as re
29
import torch
30
from pydantic import TypeAdapter, ValidationError
31
from pydantic.fields import FieldInfo
32
from typing_extensions import TypeIs
33

34
import vllm.envs as envs
35
from vllm.config import (
36
    AttentionConfig,
37
38
39
40
    CacheConfig,
    CompilationConfig,
    ConfigType,
    DeviceConfig,
41
    ECTransferConfig,
42
    EPLBConfig,
43
    KernelConfig,
44
45
46
47
48
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LoRAConfig,
    ModelConfig,
49
    MultiModalConfig,
50
51
52
    ObservabilityConfig,
    ParallelConfig,
    PoolerConfig,
53
    ProfilerConfig,
54
55
56
57
    SchedulerConfig,
    SpeculativeConfig,
    StructuredOutputsConfig,
    VllmConfig,
58
    WeightTransferConfig,
59
60
    get_attr_docs,
)
61
62
63
from vllm.config.cache import (
    CacheDType,
    KVOffloadingBackend,
64
    MambaCacheMode,
65
66
67
    MambaDType,
    PrefixCachingHashAlgo,
)
68
from vllm.config.device import Device
69
from vllm.config.lora import MaxLoRARanks
70
71
72
73
74
75
from vllm.config.model import (
    ConvertOption,
    HfOverrides,
    LogprobsMode,
    ModelDType,
    RunnerOption,
76
    TokenizerMode,
77
78
79
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
from vllm.config.observability import DetailedTraceModules
80
81
82
83
84
85
from vllm.config.parallel import (
    All2AllBackend,
    DataParallelBackend,
    DistributedExecutorBackend,
    ExpertPlacementStrategy,
)
86
from vllm.config.scheduler import SchedulerPolicy
87
from vllm.config.utils import get_field
88
from vllm.config.vllm import OptimizationLevel
89
from vllm.logger import init_logger, suppress_logging
90
from vllm.platforms import CpuArchEnum, current_platform
91
from vllm.plugins import load_general_plugins
92
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
93
94
95
96
from vllm.transformers_utils.config import (
    is_interleaved,
    maybe_override_with_speculators,
)
97
from vllm.transformers_utils.gguf_utils import is_gguf
98
from vllm.transformers_utils.repo_utils import get_model_path
99
from vllm.transformers_utils.utils import is_cloud_storage
100
from vllm.utils.argparse_utils import FlexibleArgumentParser
101
from vllm.utils.mem_constants import GiB_bytes
102
from vllm.utils.network_utils import get_ip
103
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
104
from vllm.v1.attention.backends.registry import AttentionBackendEnum
105
from vllm.v1.sample.logits_processor import LogitsProcessor
106

107
108
if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
109
    from vllm.model_executor.model_loader import LoadFormats
110
    from vllm.usage.usage_lib import UsageContext
111
    from vllm.v1.executor import Executor
112
else:
113
    Executor = Any
114
    QuantizationMethods = Any
115
    LoadFormats = Any
116
117
    UsageContext = Any

118

119
120
logger = init_logger(__name__)

121
122
# object is used to allow for special typing forms
T = TypeVar("T")
123
124
TypeHint: TypeAlias = type[Any] | object
TypeHintT: TypeAlias = type[T] | object
125

126

127
128
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap a string converter so argparse reports conversion failures cleanly.

    Args:
        return_type: Callable that turns a raw CLI string into a value.

    Returns:
        A converter that behaves like `return_type`, except that a
        `ValueError` raised during conversion is re-raised as
        `argparse.ArgumentTypeError` (chained to the original error).
    """

    def _parse_type(val: str) -> T:
        try:
            converted = return_type(val)
        except ValueError as err:
            message = f"Value {val} cannot be converted to {return_type}."
            raise argparse.ArgumentTypeError(message) from err
        return converted

    return _parse_type


139
140
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], T | None]:
    """Like `parse_type`, but treat `""` and `"None"` as an explicit `None`.

    Args:
        return_type: Callable that turns a raw CLI string into a value.

    Returns:
        A converter returning `None` for the empty string or the literal
        string `"None"`, and otherwise delegating to `parse_type(return_type)`.
    """
    # CLI spellings that mean "no value".
    none_spellings = ("", "None")

    def _optional_type(val: str) -> T | None:
        if val in none_spellings:
            return None
        converter = parse_type(return_type)
        return converter(val)

    return _optional_type
146
147


148
def union_dict_and_str(val: str) -> str | dict[str, str] | None:
    """Parse a CLI value that may be either a plain string or a JSON object.

    If `val` looks like a JSON object (a `{...}` body, optionally surrounded
    by whitespace and possibly spanning several lines), it is decoded with
    `json.loads`. Any other value is returned unchanged as a string.
    """
    looks_like_json_object = re.match(r"(?s)^\s*{.*}\s*$", val) is not None
    if looks_like_json_object:
        return optional_type(json.loads)(val)
    return str(val)
152
153


154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return True if `type_hint` is `type` itself or a parameterisation of it.

    For example, both `list` and `list[int]` match `type=list`, and
    `Literal["a"]` matches `type=Literal`.
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return True if any hint in `type_hints` is (or parameterises) `type`."""
    for type_hint in type_hints:
        if is_type(type_hint, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return a hint from `type_hints` that is (or parameterises) `type`.

    Returns `None` when nothing matches. Because `type_hints` is a set, which
    matching hint is returned depends on set iteration order when several
    hints match.
    """
    for candidate in type_hints:
        if is_type(candidate, type):
            return candidate
    return None


169
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Build argparse `type` and `choices`/`metavar` kwargs for a `Literal`.

    The `Literal` hint is looked up in `type_hints`. All of its values must be
    instances of the type of the first value, which becomes the argparse
    `type`. The sorted values are returned as `choices`, unless `type_hints`
    also contains `str`, in which case they are returned as the `metavar`
    instead.

    Raises:
        ValueError: If the literal values are not all of the same type.
    """
    literal_hint = get_type(type_hints, Literal)
    options = get_args(literal_hint)
    option_type = type(options[0])

    if any(not isinstance(opt, option_type) for opt in options):
        option_types = [type(opt) for opt in options]
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {option_types}"
        )

    accepts_free_str = contains_type(type_hints, str)
    key = "metavar" if accepts_free_str else "choices"
    return {"type": option_type, key: sorted(options)}
184
185


186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def collection_to_kwargs(type_hints: set[TypeHint], type: TypeHint) -> dict[str, Any]:
    """Build argparse `type` and `nargs` kwargs for a collection type hint.

    The hint in `type_hints` whose origin is `type` (e.g. `list`, `set` or
    `tuple`) is looked up, and the element type is derived from its args:

    - All non-`Ellipsis` element types must be equal, e.g. `tuple[int, int]`
      or `tuple[int, ...]`, but not `tuple[int, str]`.
    - A union element type (e.g. `list[int | str]`) must include `str` and is
      parsed as `str`.
    - `nargs` is the exact element count for fixed-length tuples and `"+"`
      otherwise.

    Raises:
        ValueError: If the hint has no element type, its element types
            differ, or a union element type does not include `str`.
    """
    type_hint = get_type(type_hints, type)
    types = get_args(type_hint)
    if not types:
        raise ValueError(
            f"Collection type hint must specify its element type. Got {type_hint}."
        )
    elem_type = types[0]

    # Handle Ellipsis (`tuple[X, ...]`). Compare with `==` rather than `is`:
    # equal aliases such as two separately written `int | str` are distinct
    # objects.
    if not all(t == elem_type for t in types if t is not Ellipsis):
        raise ValueError(
            f"All non-Ellipsis elements must be of the same type. Got {types}."
        )

    # Handle Union types: Union for Union[X, Y] and UnionType for X | Y
    if get_origin(elem_type) in {Union, UnionType}:
        if str not in get_args(elem_type):
            raise ValueError(
                "If element can have multiple types, one must be 'str' "
                f"(i.e. 'list[int | str]'). Got {elem_type}."
            )
        elem_type = str

    is_fixed_length_tuple = type is tuple and Ellipsis not in types
    return {
        "type": elem_type,
        "nargs": len(types) if is_fixed_length_tuple else "+",
    }


211
212
213
214
215
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True if `type_hint` is defined outside the `builtins` module.

    `_compute_kwargs` uses this to decide when a non-builtin hint should be
    parsed like a string (or a string-or-dict).
    """
    defining_module = type_hint.__module__
    return defining_module != "builtins"


216
217
218
219
220
221
222
223
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints."""
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
224
225
    elif origin in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
226
227
228
229
230
231
232
233
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


234
# Attribute docs (via `get_attr_docs`) are only needed when help text is being
# rendered, so detect those invocations up front:
#   vllm SUBCOMMAND --help   (or the argparse short flag -h)
#   mkdocs SUBCOMMAND
#   python -m mkdocs SUBCOMMAND
# `sys.argv` is guarded because it can be empty (e.g. if reassigned by an
# embedding application), in which case `sys.argv[0]` would raise IndexError.
NEEDS_HELP = any(arg == "-h" or "--help" in arg for arg in sys.argv) or (
    bool(sys.argv) and sys.argv[0].endswith(("mkdocs", "mkdocs/__main__.py"))
)


241
def _get_field_default(field: dataclasses.Field) -> Any:
    """Return the value to use as the argparse default for a dataclass field.

    Supports plain defaults, `default_factory`, and pydantic `Field(...)`
    defaults (which may carry their own `default_factory`). A field with no
    default at all yields `None`, argparse's own default, so it never inherits
    a value computed for a different field.
    """
    if field.default is not MISSING:
        default = field.default
        # Handle pydantic.Field defaults
        if not isinstance(default, FieldInfo):
            return default
        if default.default_factory is None:
            return default.default
        # VllmConfig's Fields have default_factory set to config classes.
        # These could emit logs on init, which would be confusing.
        with suppress_logging():
            return default.default_factory()  # type: ignore[call-arg]
    if field.default_factory is not MISSING:
        return field.default_factory()
    return None


@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Compute argparse `add_argument` kwargs for every field of `cls`.

    Results are cached per class by `lru_cache`. Use `get_kwargs`, which
    returns a deep copy, so the cached dictionaries are never mutated.

    Raises:
        ValueError: If a field's type hint cannot be mapped to an argparse
            type (or a helper rejects the hint).
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}
    json_tip = (
        "Should either be a valid JSON string or JSON keys passed individually."
    )

    kwargs = {}
    for field in fields(cls):
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, parse it from JSON via pydantic
        generator = (th for th in type_hints if is_dataclass(th))
        dataclass_cls = next(generator, None)

        default = _get_field_default(field)

        # Get the help text for the field, escaping % for argparse
        name = field.name
        help_text = cls_docs.get(name, "").strip().replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        field_kwargs: dict[str, Any] = {"default": default, "help": help_text}
        kwargs[name] = field_kwargs

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:
            # The class is bound as a default argument so each closure keeps
            # its own class rather than the loop's last value.
            def parse_dataclass(val: str, dataclass_type=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(dataclass_type).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            field_kwargs["type"] = parse_dataclass
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            field_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            field_kwargs.update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            field_kwargs.update(collection_to_kwargs(type_hints, tuple))
        elif contains_type(type_hints, list):
            field_kwargs.update(collection_to_kwargs(type_hints, list))
        elif contains_type(type_hints, set):
            field_kwargs.update(collection_to_kwargs(type_hints, set))
        elif contains_type(type_hints, int):
            if name == "max_model_len":
                field_kwargs["type"] = human_readable_int_or_auto
                field_kwargs["help"] += f"\n\n{human_readable_int_or_auto.__doc__}"
            elif name in ("max_num_batched_tokens", "kv_cache_memory_bytes"):
                field_kwargs["type"] = human_readable_int
                field_kwargs["help"] += f"\n\n{human_readable_int.__doc__}"
            else:
                field_kwargs["type"] = int
        elif contains_type(type_hints, float):
            field_kwargs["type"] = float
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
            field_kwargs["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            field_kwargs["type"] = parse_type(json.loads)
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
            field_kwargs["type"] = str
        else:
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(field_kwargs.get("type")) is Literal:
            field_kwargs.update(literal_to_kwargs({field_kwargs["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            field_kwargs["type"] = optional_type(field_kwargs["type"])
            if field_kwargs.get("choices"):
                # Presumably lists "None" as a valid choice in the help text.
                # NOTE(review): argparse checks `choices` against the value
                # *after* `type` conversion, and `optional_type` turns "None"
                # into the `None` object, not this string. Confirm the
                # parser accepts "None" on the CLI.
                field_kwargs["choices"].append("None")
    return kwargs
341
342


343
def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Return argparse kwargs, keyed by field name, for the Config dataclass `cls`.

    Attribute documentation is only included in the help text when help output
    is being generated (i.e. when `NEEDS_HELP` is true).

    The expensive work is memoised in `_compute_kwargs`. A deep copy of the
    cached result is returned, so callers may freely mutate it.
    """
    cached = _compute_kwargs(cls)
    return copy.deepcopy(cached)


356
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
357
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
358
    """Arguments for vLLM engine."""
359

360
    model: str = ModelConfig.model
361
    enable_return_routed_experts: bool = ModelConfig.enable_return_routed_experts
362
    model_weights: str = ModelConfig.model_weights
363
    served_model_name: str | list[str] | None = ModelConfig.served_model_name
364
    tokenizer: str | None = ModelConfig.tokenizer
365
    hf_config_path: str | None = ModelConfig.hf_config_path
366
367
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
368
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
369
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
370
    tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
371
    trust_remote_code: bool = ModelConfig.trust_remote_code
372
373
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
    allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains
374
    download_dir: str | None = LoadConfig.download_dir
375
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
376
    load_format: str | LoadFormats = LoadConfig.load_format
377
378
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
379
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
380
    seed: int = ModelConfig.seed
381
    max_model_len: int = ModelConfig.max_model_len
382
383
384
385
386
387
    cudagraph_capture_sizes: list[int] | None = (
        CompilationConfig.cudagraph_capture_sizes
    )
    max_cudagraph_capture_size: int | None = get_field(
        CompilationConfig, "max_cudagraph_capture_size"
    )
388
389
390
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
391
    distributed_executor_backend: (
392
        str | DistributedExecutorBackend | type[Executor] | None
393
    ) = ParallelConfig.distributed_executor_backend
394
    # number of P/D disaggregation (or other disaggregation) workers
395
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
396
397
398
399
    master_addr: str = ParallelConfig.master_addr
    master_port: int = ParallelConfig.master_port
    nnodes: int = ParallelConfig.nnodes
    node_rank: int = ParallelConfig.node_rank
400
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
401
    prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
402
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
403
    dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
404
    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
405
    data_parallel_size: int = ParallelConfig.data_parallel_size
406
407
408
409
410
    data_parallel_rank: int | None = None
    data_parallel_start_rank: int | None = None
    data_parallel_size_local: int | None = None
    data_parallel_address: str | None = None
    data_parallel_rpc_port: int | None = None
411
    data_parallel_hybrid_lb: bool = False
412
    data_parallel_external_lb: bool = False
413
    data_parallel_backend: DataParallelBackend = ParallelConfig.data_parallel_backend
414
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
415
    all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
416
    enable_dbo: bool = ParallelConfig.enable_dbo
417
    ubatch_size: int = ParallelConfig.ubatch_size
418
419
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
420
    disable_nccl_for_dp_synchronization: bool | None = (
421
422
        ParallelConfig.disable_nccl_for_dp_synchronization
    )
423
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
424
    enable_eplb: bool = ParallelConfig.enable_eplb
425
    expert_placement_strategy: ExpertPlacementStrategy = (
426
        ParallelConfig.expert_placement_strategy
427
    )
428
429
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
430
    max_parallel_loading_workers: int | None = (
431
432
        ParallelConfig.max_parallel_loading_workers
    )
433
    block_size: int = None  # type: ignore[assignment]
434
    enable_prefix_caching: bool | None = None
435
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
436
        CacheConfig.prefix_caching_hash_algo
437
    )
438
439
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
440
441
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
442
    cpu_offload_params: set[str] = get_field(CacheConfig, "cpu_offload_params")
443
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
444
    kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
445
    max_num_batched_tokens: int | None = None
446
447
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
448
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
449
    max_num_seqs: int | None = None
450
    max_logprobs: int = ModelConfig.max_logprobs
451
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
452
    disable_log_stats: bool = False
453
    aggregate_engine_logging: bool = False
454
455
456
    revision: str | None = ModelConfig.revision
    code_revision: str | None = ModelConfig.code_revision
    hf_token: bool | str | None = ModelConfig.hf_token
457
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
458
    tokenizer_revision: str | None = ModelConfig.tokenizer_revision
459
    quantization: QuantizationMethods | str | None = ModelConfig.quantization
460
    allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
461
    enforce_eager: bool = ModelConfig.enforce_eager
462
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
463
    language_model_only: bool = MultiModalConfig.language_model_only
464
    limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
465
466
        MultiModalConfig, "limit_per_prompt"
    )
467
    enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
468
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
469
470
471
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
472
    mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
473
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
474
    mm_processor_cache_type: MMCacheType | None = (
475
        MultiModalConfig.mm_processor_cache_type
476
477
    )
    mm_shm_cache_max_object_size_mb: int = (
478
        MultiModalConfig.mm_shm_cache_max_object_size_mb
479
    )
480
    mm_encoder_only: bool = MultiModalConfig.mm_encoder_only
481
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
482
    mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
483
484
        MultiModalConfig.mm_encoder_attn_backend
    )
485
    io_processor_plugin: str | None = None
486
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
487
    video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate
488
    # LoRA fields
489
    enable_lora: bool = False
490
    max_loras: int = LoRAConfig.max_loras
491
    max_lora_rank: MaxLoRARanks = LoRAConfig.max_lora_rank
492
    default_mm_loras: dict[str, str] | None = LoRAConfig.default_mm_loras
493
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
494
495
    max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
    lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
496
    enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
497
    specialize_active_lora: bool = LoRAConfig.specialize_active_lora
498

499
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
500
    num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
501
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
502
    ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
503

504
    enable_chunked_prefill: bool | None = None
505
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
506

507
    disable_hybrid_kv_cache_manager: bool | None = (
508
509
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
510

511
    structured_outputs_config: StructuredOutputsConfig = get_field(
512
513
        VllmConfig, "structured_outputs_config"
    )
514
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
515
    reasoning_parser_plugin: str | None = None
516

517
    speculative_config: dict[str, Any] | None = None
518

519
    show_hidden_metrics_for_version: str | None = (
520
        ObservabilityConfig.show_hidden_metrics_for_version
521
    )
522
523
    otlp_traces_endpoint: str | None = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: list[DetailedTraceModules] | None = (
524
        ObservabilityConfig.collect_detailed_traces
525
    )
526
527
528
529
    kv_cache_metrics: bool = ObservabilityConfig.kv_cache_metrics
    kv_cache_metrics_sample: float = get_field(
        ObservabilityConfig, "kv_cache_metrics_sample"
    )
530
    cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
531
532
533
    enable_layerwise_nvtx_tracing: bool = (
        ObservabilityConfig.enable_layerwise_nvtx_tracing
    )
534
    enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
535
536
537
    enable_logging_iteration_details: bool = (
        ObservabilityConfig.enable_logging_iteration_details
    )
538
    enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
539
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
540
    scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
541

542
    pooler_config: PoolerConfig | None = ModelConfig.pooler_config
543
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
544
    attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
545
546
547
548
    kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config")
    enable_flashinfer_autotune: bool = get_field(
        KernelConfig, "enable_flashinfer_autotune"
    )
549
550
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
551

552
553
    profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")

554
555
    kv_transfer_config: KVTransferConfig | None = None
    kv_events_config: KVEventsConfig | None = None
556

557
558
    ec_transfer_config: ECTransferConfig | None = None

559
560
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
561
562
563
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
564
    model_impl: str = ModelConfig.model_impl
565
    override_attention_dtype: str | None = ModelConfig.override_attention_dtype
566
    attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
567

568
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
569
570
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
571
    mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
572
    mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
573

574
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
575

576
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
577
    pt_load_map_location: str | dict[str, str] = LoadConfig.pt_load_map_location
578

579
    logits_processors: list[str | type[LogitsProcessor]] | None = (
580
581
        ModelConfig.logits_processors
    )
582
583
    """Custom logitproc types"""

584
    async_scheduling: bool | None = SchedulerConfig.async_scheduling
585

586
587
    stream_interval: int = SchedulerConfig.stream_interval

588
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
589
    optimization_level: OptimizationLevel = VllmConfig.optimization_level
590

591
    kv_offloading_size: float | None = CacheConfig.kv_offloading_size
592
    kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
593
    tokens_only: bool = False
594

595
596
597
598
    weight_transfer_config: WeightTransferConfig | None = get_field(
        VllmConfig,
        "weight_transfer_config",
    )
599

600
601
    fail_on_environ_validation: bool = False

602
    def __post_init__(self):
        """Normalise config fields and apply environment-dependent overrides.

        - Config fields passed as plain dicts (e.g.
          `EngineArgs(compilation_config={...})`) are converted into their
          config classes.
        - General plugins are loaded.
        - When `HF_HUB_OFFLINE` is set, `model` (and `tokenizer`, if given) are
          resolved via `get_model_path`, and any change is logged.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(**self.compilation_config)
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)
        if isinstance(self.kernel_config, dict):
            self.kernel_config = KernelConfig(**self.kernel_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)
        if isinstance(self.weight_transfer_config, dict):
            self.weight_transfer_config = WeightTransferConfig(
                **self.weight_transfer_config
            )
        # Setup plugins
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        # When using HF offline mode, replace the model and tokenizer IDs with
        # local model paths.
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
            # Compare by value, not identity, so an equal string returned as a
            # new object is not reported as a replacement.
            if model_id != self.model:
                logger.info(
                    "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
                    model_id,
                    self.model,
                )
            if self.tokenizer is not None:
                tokenizer_id = self.tokenizer
                self.tokenizer = get_model_path(self.tokenizer, self.tokenizer_revision)
                if tokenizer_id != self.tokenizer:
                    logger.info(
                        "HF_HUB_OFFLINE is True, replace tokenizer_id [%s] "
                        "to tokenizer_path [%s]",
                        tokenizer_id,
                        self.tokenizer,
                    )
642
643

    @staticmethod
644
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
645
        """Shared CLI arguments for vLLM engine."""
646

647
        # Model arguments
648
649
650
651
652
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
653
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
654
            model_group.add_argument("--model", **model_kwargs["model"])
655
656
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
657
658
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
659
660
661
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
662
663
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
664
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
665
666
667
668
669
670
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
671
        model_group.add_argument("--revision", **model_kwargs["revision"])
672
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
673
674
675
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
676
677
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
678
679
680
681
        model_group.add_argument(
            "--allow-deprecated-quantization",
            **model_kwargs["allow_deprecated_quantization"],
        )
682
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
683
684
685
686
        model_group.add_argument(
            "--enable-return-routed-experts",
            **model_kwargs["enable_return_routed_experts"],
        )
687
688
689
690
691
692
693
694
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
695
696
697
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
698
699
700
701
702
703
704
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
705
706
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
726
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
727
728
729
730
731
732
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
733
734
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
735
        )
736

737
738
739
740
741
742
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
743
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
744
745
746
747
748
749
750
751
752
753
754
755
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )
756

757
758
759
760
761
762
763
764
765
766
        # Attention arguments
        attention_kwargs = get_kwargs(AttentionConfig)
        attention_group = parser.add_argument_group(
            title="AttentionConfig",
            description=AttentionConfig.__doc__,
        )
        attention_group.add_argument(
            "--attention-backend", **attention_kwargs["backend"]
        )

767
768
769
770
771
        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
772
        )
773
        structured_outputs_group.add_argument(
774
            "--reasoning-parser",
775
            # Choices need to be validated after parsing to include plugins
776
777
            **structured_outputs_kwargs["reasoning_parser"],
        )
778
779
780
781
        structured_outputs_group.add_argument(
            "--reasoning-parser-plugin",
            **structured_outputs_kwargs["reasoning_parser_plugin"],
        )
782

783
        # Parallel arguments
784
785
786
787
788
789
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
790
            "--distributed-executor-backend",
791
792
            **parallel_kwargs["distributed_executor_backend"],
        )
793
        parallel_group.add_argument(
794
795
796
797
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
798
799
800
801
        parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"])
        parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
        parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
        parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
802
        parallel_group.add_argument(
803
804
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
805
        parallel_group.add_argument(
806
807
808
809
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
810
811
812
813
        parallel_group.add_argument(
            "--dcp-kv-cache-interleave-size",
            **parallel_kwargs["dcp_kv_cache_interleave_size"],
        )
814
815
816
817
818
819
820
821
822
        parallel_group.add_argument(
            "--cp-kv-cache-interleave-size",
            **parallel_kwargs["cp_kv_cache_interleave_size"],
        )
        parallel_group.add_argument(
            "--prefill-context-parallel-size",
            "-pcp",
            **parallel_kwargs["prefill_context_parallel_size"],
        )
823
824
825
826
827
828
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
        parallel_group.add_argument(
            "--data-parallel-rank",
            "-dpn",
829
            type=int,
830
831
832
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
833
        parallel_group.add_argument(
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
            "--data-parallel-start-rank",
            "-dpr",
            type=int,
            help="Starting data parallel rank for secondary nodes.",
        )
        parallel_group.add_argument(
            "--data-parallel-size-local",
            "-dpl",
            type=int,
            help="Number of data parallel replicas to run on this node.",
        )
        parallel_group.add_argument(
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
            type=int,
            help="Port for data parallel RPC communication.",
        )
        parallel_group.add_argument(
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
864
        parallel_group.add_argument(
865
866
867
868
869
870
871
872
            "--data-parallel-hybrid-lb",
            "-dph",
            **parallel_kwargs["data_parallel_hybrid_lb"],
        )
        parallel_group.add_argument(
            "--data-parallel-external-lb",
            "-dpe",
            **parallel_kwargs["data_parallel_external_lb"],
873
874
        )
        parallel_group.add_argument(
875
876
877
            "--enable-expert-parallel",
            "-ep",
            **parallel_kwargs["enable_expert_parallel"],
878
        )
879
880
881
        parallel_group.add_argument(
            "--all2all-backend", **parallel_kwargs["all2all_backend"]
        )
882
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
883
884
885
886
        parallel_group.add_argument(
            "--ubatch-size",
            **parallel_kwargs["ubatch_size"],
        )
887
888
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
889
890
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
891
892
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
893
894
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
895
896
897
898
        parallel_group.add_argument(
            "--disable-nccl-for-dp-synchronization",
            **parallel_kwargs["disable_nccl_for_dp_synchronization"],
        )
899
900
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
901
902
        parallel_group.add_argument(
            "--expert-placement-strategy",
903
904
            **parallel_kwargs["expert_placement_strategy"],
        )
905

906
        parallel_group.add_argument(
907
            "--max-parallel-loading-workers",
908
909
            **parallel_kwargs["max_parallel_loading_workers"],
        )
910
        parallel_group.add_argument(
911
912
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
913
        parallel_group.add_argument(
914
            "--disable-custom-all-reduce",
915
916
917
918
919
920
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )
921

922
923
924
925
926
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
927
        )
928
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
929
930
931
932
933
934
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
935
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
936
937
938
939
940
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        cache_group.add_argument(
941
942
943
944
945
            "--enable-prefix-caching",
            **{
                **cache_kwargs["enable_prefix_caching"],
                "default": None,
            },
946
947
948
949
950
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
951
952
953
        cache_group.add_argument(
            "--cpu-offload-params", **cache_kwargs["cpu_offload_params"]
        )
954
955
956
957
958
959
960
961
962
963
964
965
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
966
967
968
        cache_group.add_argument(
            "--mamba-block-size", **cache_kwargs["mamba_block_size"]
        )
969
970
971
        cache_group.add_argument(
            "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
        )
972
973
974
975
976
977
        cache_group.add_argument(
            "--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
        )
        cache_group.add_argument(
            "--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
        )
978

979
        # Multimodal related configs
980
981
982
983
984
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
985
986
987
        multimodal_group.add_argument(
            "--language-model-only", **multimodal_kwargs["language_model_only"]
        )
988
        multimodal_group.add_argument(
989
990
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
991
992
993
        multimodal_group.add_argument(
            "--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
        )
994
995
996
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
997
998
999
1000
1001
1002
        multimodal_group.add_argument(
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
1003
        multimodal_group.add_argument(
1004
1005
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
1006
1007
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
1008
1009
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
1010
1011
1012
        multimodal_group.add_argument(
            "--mm-encoder-only", **multimodal_kwargs["mm_encoder_only"]
        )
1013
        multimodal_group.add_argument(
1014
1015
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
1016
1017
1018
1019
        multimodal_group.add_argument(
            "--mm-encoder-attn-backend",
            **multimodal_kwargs["mm_encoder_attn_backend"],
        )
1020
1021
1022
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
1023
        multimodal_group.add_argument(
1024
1025
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
1026

1027
        multimodal_group.add_argument(
1028
1029
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )
1030

1031
        # LoRA related configs
1032
1033
1034
1035
1036
1037
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
1038
            "--enable-lora",
1039
            action=argparse.BooleanOptionalAction,
1040
1041
            help="If True, enable handling of LoRA adapters.",
        )
1042
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
1043
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
1044
        lora_group.add_argument(
1045
            "--lora-dtype",
1046
1047
            **lora_kwargs["lora_dtype"],
        )
1048
1049
1050
1051
        lora_group.add_argument(
            "--enable-tower-connector-lora",
            **lora_kwargs["enable_tower_connector_lora"],
        )
1052
1053
1054
1055
1056
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
1057
1058
1059
        lora_group.add_argument(
            "--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
        )
1060

1061
1062
1063
1064
1065
1066
1067
1068
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
1069
1070
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
1071
        observability_group.add_argument(
1072
1073
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
1074
1075
1076
1077
1078
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
1079
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
1080
1081
1082
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
1083
1084
            **observability_kwargs["collect_detailed_traces"],
        )
1085
1086
1087
1088
1089
1090
1091
        observability_group.add_argument(
            "--kv-cache-metrics", **observability_kwargs["kv_cache_metrics"]
        )
        observability_group.add_argument(
            "--kv-cache-metrics-sample",
            **observability_kwargs["kv_cache_metrics_sample"],
        )
1092
1093
1094
1095
        observability_group.add_argument(
            "--cudagraph-metrics",
            **observability_kwargs["cudagraph_metrics"],
        )
1096
1097
1098
1099
        observability_group.add_argument(
            "--enable-layerwise-nvtx-tracing",
            **observability_kwargs["enable_layerwise_nvtx_tracing"],
        )
1100
1101
1102
1103
        observability_group.add_argument(
            "--enable-mfu-metrics",
            **observability_kwargs["enable_mfu_metrics"],
        )
1104
1105
1106
1107
        observability_group.add_argument(
            "--enable-logging-iteration-details",
            **observability_kwargs["enable_logging_iteration_details"],
        )
1108

1109
1110
1111
1112
1113
1114
1115
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
1116
1117
1118
1119
1120
            "--max-num-batched-tokens",
            **{
                **scheduler_kwargs["max_num_batched_tokens"],
                "default": None,
            },
1121
        )
1122
        scheduler_group.add_argument(
1123
1124
1125
1126
1127
            "--max-num-seqs",
            **{
                **scheduler_kwargs["max_num_seqs"],
                "default": None,
            },
1128
1129
1130
1131
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
1132
1133
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
1134
1135
            **scheduler_kwargs["max_long_partial_prefills"],
        )
1136
1137
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
1138
1139
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
1140
1141
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
1142
        scheduler_group.add_argument(
1143
1144
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
1145
        scheduler_group.add_argument(
1146
1147
1148
1149
1150
            "--enable-chunked-prefill",
            **{
                **scheduler_kwargs["enable_chunked_prefill"],
                "default": None,
            },
1151
1152
1153
1154
1155
1156
1157
        )
        scheduler_group.add_argument(
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
1158
1159
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
1160
1161
1162
1163
1164
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
1165
1166
1167
        scheduler_group.add_argument(
            "--stream-interval", **scheduler_kwargs["stream_interval"]
        )
1168

1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
        # Compilation arguments
        compilation_kwargs = get_kwargs(CompilationConfig)
        compilation_group = parser.add_argument_group(
            title="CompilationConfig",
            description=CompilationConfig.__doc__,
        )
        compilation_group.add_argument(
            "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
        )
        compilation_group.add_argument(
            "--max-cudagraph-capture-size",
            **compilation_kwargs["max_cudagraph_capture_size"],
        )

1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
        # Kernel arguments
        kernel_kwargs = get_kwargs(KernelConfig)
        kernel_group = parser.add_argument_group(
            title="KernelConfig",
            description=KernelConfig.__doc__,
        )
        kernel_group.add_argument(
            "--enable-flashinfer-autotune",
            **kernel_kwargs["enable_flashinfer_autotune"],
        )

1194
        # vLLM arguments
1195
        vllm_kwargs = get_kwargs(VllmConfig)
1196
1197
1198
1199
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
1200
1201
1202
1203
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
1204
1205
1206
1207
1208
1209
1210
        vllm_group.add_argument(
            "--speculative-config", **vllm_kwargs["speculative_config"]
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
1211
1212
1213
        vllm_group.add_argument(
            "--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
        )
1214
        vllm_group.add_argument(
1215
            "--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
1216
        )
1217
1218
1219
        vllm_group.add_argument(
            "--attention-config", "-ac", **vllm_kwargs["attention_config"]
        )
1220
        vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"])
1221
1222
1223
1224
1225
1226
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
1227
        vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
1228
1229
1230
        vllm_group.add_argument(
            "--optimization-level", **vllm_kwargs["optimization_level"]
        )
1231
1232
1233
        vllm_group.add_argument(
            "--weight-transfer-config", **vllm_kwargs["weight_transfer_config"]
        )
1234

1235
        # Other arguments
1236
1237
1238
1239
1240
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
1241

1242
1243
1244
1245
1246
1247
        parser.add_argument(
            "--aggregate-engine-logging",
            action="store_true",
            help="Log aggregate rather than per-engine statistics "
            "when using data parallelism.",
        )
1248
1249
1250
1251
1252
1253
1254
1255

        parser.add_argument(
            "--fail-on-environ-validation",
            help="If set, the engine will raise an error if "
            "environment validation fails.",
            default=False,
            action=argparse.BooleanOptionalAction,
        )
1256
        return parser
1257
1258

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of ``cls`` from a parsed ``argparse.Namespace``.

        Only dataclass fields of ``cls`` that also exist as attributes on
        ``args`` are forwarded to the constructor. Fields absent from ``args``
        are not passed, so the dataclass defaults apply. Extra attributes on
        ``args`` are ignored.

        Args:
            args: The namespace produced by ``ArgumentParser.parse_args``.

        Returns:
            A new instance of ``cls``.
        """
        init_kwargs = {}
        for field_def in dataclasses.fields(cls):
            field_name = field_def.name
            if hasattr(args, field_name):
                init_kwargs[field_name] = getattr(args, field_name)
        return cls(**init_kwargs)
1267

1268
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` described by these engine arguments.

        Side effects on ``self``: when ``is_gguf(self.model)`` is true, both
        ``self.quantization`` and ``self.load_format`` are overwritten with
        ``"gguf"`` before the config is built.

        A warning is logged when ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` is
        false, because (per the message) the seed may then affect the random
        state of the launching Python process.

        Returns:
            A ``ModelConfig`` whose fields are copied one-to-one from the
            same-named attributes of ``self``.
        """
        if is_gguf(self.model):
            # GGUF checkpoints need a specific quantization method and loader.
            self.quantization = "gguf"
            self.load_format = "gguf"

        if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.warning(
                "The global random seed is set to %d. Since "
                "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
                "affect the random state of the Python process that "
                "launched vLLM.",
                self.seed,
            )

        model_config = ModelConfig(
            # Model source and tokenizer.
            model=self.model,
            model_weights=self.model_weights,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            tokenizer=self.tokenizer,  # type: ignore[arg-type]
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
            # Numerics, revisions and Hugging Face access.
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            # Runtime / execution behaviour.
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            allow_deprecated_quantization=self.allow_deprecated_quantization,
            enforce_eager=self.enforce_eager,
            enable_return_routed_experts=self.enable_return_routed_experts,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            # Multimodal options.
            language_model_only=self.language_model_only,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            enable_mm_embeds=self.enable_mm_embeds,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_only=self.mm_encoder_only,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
            # Pooling, generation defaults and plugins.
            pooler_config=self.pooler_config,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            video_pruning_rate=self.video_pruning_rate,
            io_processor_plugin=self.io_processor_plugin,
        )
        return model_config
1336

1337
    def validate_tensorizer_args(self):
        """Copy tensorizer-specific keys into the nested tensorizer config.

        Every top-level key of ``self.model_loader_extra_config`` that appears
        in ``TensorizerConfig._fields`` is copied into
        ``self.model_loader_extra_config["tensorizer_config"]``. The top-level
        entries are left in place.

        The nested ``"tensorizer_config"`` entry is only looked up when at
        least one key matches, so it is only required in that case. A missing
        entry then raises ``KeyError``.
        """
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

        extra_config = self.model_loader_extra_config
        tensorizer_fields = TensorizerConfig._fields
        matched = {
            key: extra_config[key] for key in extra_config if key in tensorizer_fields
        }
        if matched:
            extra_config["tensorizer_config"].update(matched)
1345

1346
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` described by these engine arguments.

        Side effects on ``self``:
          * ``quantization == "bitsandbytes"`` forces
            ``load_format = "bitsandbytes"``.
          * When ``load_format == "tensorizer"``:
            - ``model_loader_extra_config`` is replaced by its
              ``to_serializable()`` result if it has that method;
            - its ``"tensorizer_config"`` entry is reset to
              ``{"tensorizer_dir": self.model}``;
            - ``validate_tensorizer_args()`` is then called.

        Returns:
            A ``LoadConfig`` built from the (possibly updated) attributes.
        """
        if self.quantization == "bitsandbytes":
            # bitsandbytes checkpoints are loaded by their dedicated loader.
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            extra_config = self.model_loader_extra_config
            if hasattr(extra_config, "to_serializable"):
                # NOTE(review): presumably a TensorizerConfig object was passed
                # directly; it is flattened so it can be indexed below.
                extra_config = extra_config.to_serializable()
                self.model_loader_extra_config = extra_config
            extra_config["tensorizer_config"] = {"tensorizer_dir": self.model}
            self.validate_tensorizer_args()

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )
1370

1371
1372
1373
1374
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
    ) -> SpeculativeConfig | None:
        """Build a ``SpeculativeConfig`` from ``self.speculative_config``.

        ``self.speculative_config`` holds the ``SpeculativeConfig`` keyword
        arguments. It comes either from the JSON given to
        ``--speculative-config`` on the CLI (parsed with ``json.loads``) or
        directly from the engine caller as a dictionary.

        Args:
            target_model_config: Model config of the target model. It is not
                available from the CLI argument.
            target_parallel_config: Parallel config of the target model. It
                is not available from the CLI argument.

        Returns:
            ``None`` when no speculative config was given, otherwise the
            constructed ``SpeculativeConfig``.

        Note:
            Both target configs are written into ``self.speculative_config``
            in place before construction, so that dict still contains them
            after this call returns.
        """
        spec_kwargs = self.speculative_config
        if spec_kwargs is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        spec_kwargs.update(
            target_model_config=target_model_config,
            target_parallel_config=target_parallel_config,
        )
        return SpeculativeConfig(**spec_kwargs)
1397

1398
1399
    def create_engine_config(
        self,
        usage_context: UsageContext | None = None,
        headless: bool = False,
    ) -> VllmConfig:
        """Create the VllmConfig from these engine arguments.

        Resolves model- and platform-dependent defaults (chunked prefill,
        prefix caching, batch limits, data-parallel layout) and assembles the
        cache, parallel, scheduler, LoRA, speculative, attention, kernel,
        load, observability and compilation sub-configs into one VllmConfig.

        As part of resolving values, this method updates several fields of
        `self` in place (e.g. `model`, `tokenizer`, `speculative_config`,
        `data_parallel_rank`, `data_parallel_size_local`,
        `data_parallel_hybrid_lb`, and `quantization`/`load_format` for
        bitsandbytes models).

        Args:
            usage_context: Usage context used to pick default
                `max_num_batched_tokens` / `max_num_seqs`.
            headless: Headless launch mode. Hybrid DP load balancing is
                rejected, and never inferred, when this is set.

        Returns:
            The assembled VllmConfig.

        Raises:
            ValueError: If mutually exclusive options are both set, default
                multimodal LoRAs are given for a non-multimodal model, or the
                token budget is too small for LoRA + speculative decoding.
            AssertionError: If the data-parallel options are inconsistent.

        NOTE: If VllmConfig is incompatible, we raise an error.
        """
        current_platform.pre_register_and_update()

        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))

        envs.validate_environ(self.fail_on_environ_validation)

        # Check if the model is a speculator and override model/tokenizer/config
        # BEFORE creating ModelConfig, so the config is created with the target model
        # Skip speculator detection for cloud storage models (eg: S3, GCS) since
        # HuggingFace cannot load configs directly from S3 URLs. S3 models can still
        # use speculators with explicit --speculative-config.
        if not is_cloud_storage(self.model):
            (self.model, self.tokenizer, self.speculative_config) = (
                maybe_override_with_speculators(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    revision=self.revision,
                    trust_remote_code=self.trust_remote_code,
                    vllm_speculative_config=self.speculative_config,
                )
            )

        model_config = self.create_model_config()
        self.model = model_config.model
        self.model_weights = model_config.model_weights
        self.tokenizer = model_config.tokenizer

        self._check_feature_supported()
        self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
        self._set_default_max_num_seqs_and_batched_tokens_args(
            usage_context, model_config
        )

        sliding_window: int | None = None
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

        # Resolve "auto" kv_cache_dtype to actual value from model config
        resolved_cache_dtype = resolve_kv_cache_dtype_string(
            self.kv_cache_dtype, model_config
        )

        assert self.enable_prefix_caching is not None, (
            "enable_prefix_caching must be set by this point"
        )

        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
            swap_space=self.swap_space,
            cache_dtype=resolved_cache_dtype,  # type: ignore[arg-type]
            is_attention_free=model_config.is_attention_free,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=sliding_window,
            enable_prefix_caching=self.enable_prefix_caching,
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
            cpu_offload_gb=self.cpu_offload_gb,
            cpu_offload_params=self.cpu_offload_params,
            calculate_kv_scales=self.calculate_kv_scales,
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
            mamba_block_size=self.mamba_block_size,
            mamba_cache_mode=self.mamba_cache_mode,
            kv_offloading_size=self.kv_offloading_size,
            kv_offloading_backend=self.kv_offloading_backend,
        )

        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray

            ray_runtime_env = ray.get_runtime_context().runtime_env
            # Avoid logging sensitive environment variables
            sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {}
            if "env_vars" in sanitized_env:
                sanitized_env["env_vars"] = {
                    k: "***" for k in sanitized_env["env_vars"]
                }
            logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env)

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

        assert not headless or not self.data_parallel_hybrid_lb, (
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
        assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
            "data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
        )
        assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
            "nnodes > 1 is only supported with data_parallel_backend=mp"
        )
        inferred_data_parallel_rank = 0
        if self.nnodes > 1:
            world_size = (
                self.data_parallel_size
                * self.pipeline_parallel_size
                * self.tensor_parallel_size
            )
            world_size_within_dp = (
                self.pipeline_parallel_size * self.tensor_parallel_size
            )
            local_world_size = world_size // self.nnodes
            assert world_size % self.nnodes == 0, (
                f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
            )
            assert self.node_rank < self.nnodes, (
                f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
            )
            # First DP rank hosted on this node, assuming ranks are laid out
            # contiguously across nodes.
            inferred_data_parallel_rank = (
                self.node_rank * local_world_size
            ) // world_size_within_dp
            if self.data_parallel_size > 1 and self.data_parallel_external_lb:
                self.data_parallel_rank = inferred_data_parallel_rank
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d for external lb",
                    self.data_parallel_rank,
                    self.node_rank,
                )
            elif self.data_parallel_size_local is None:
                # Infer data parallel size local for internal dplb:
                self.data_parallel_size_local = max(
                    local_world_size // world_size_within_dp, 1
                )
        data_parallel_external_lb = (
            self.data_parallel_external_lb or self.data_parallel_rank is not None
        )
        # External LB: each engine process hosts exactly one DP rank, so the
        # local DP size is 1.
        if data_parallel_external_lb:
            assert self.data_parallel_rank is not None, (
                "data_parallel_rank or node_rank must be specified if "
                "data_parallel_external_lb is enabled."
            )
            assert self.data_parallel_size_local in (1, None), (
                "data_parallel_size_local must be 1 or None when data_parallel_rank "
                "is set"
            )
            data_parallel_size_local = 1
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
                logger.warning(
                    "data_parallel_hybrid_lb is not eligible when "
                    "data_parallel_size_local = 1, autoswitch to "
                    "data_parallel_external_lb."
                )
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

            self.data_parallel_rank = (
                self.data_parallel_start_rank or inferred_data_parallel_rank
            )
            if self.nnodes > 1:
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d",
                    self.data_parallel_rank,
                    self.node_rank,
                )
        else:
            assert not self.data_parallel_hybrid_lb, (
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )

            if self.data_parallel_backend == "ray" and (
                envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
            ):
                # Local data parallel size defaults to 1 when DP ranks span
                # multiple nodes ("span" pack strategy).
                data_parallel_size_local = 1
            else:
                # Otherwise local DP size defaults to global DP size if not set
                data_parallel_size_local = self.data_parallel_size

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
                data_parallel_address = host_ip
            else:
                # The message must be a single formatted string: a tuple here
                # would be shown verbatim instead of being interpolated.
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got "
                    f"{self.data_parallel_backend}"
                )
                data_parallel_address = (
                    self.master_addr or ParallelConfig.data_parallel_master_ip
                )
        else:
            data_parallel_address = self.data_parallel_address

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
        data_parallel_rpc_port = (
            self.data_parallel_rpc_port
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )

        if self.tokens_only and not model_config.skip_tokenizer_init:
            model_config.skip_tokenizer_init = True
            logger.info("Skipping tokenizer initialization for tokens-only mode.")

        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            prefill_context_parallel_size=self.prefill_context_parallel_size,
            data_parallel_size=self.data_parallel_size,
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
            data_parallel_size_local=data_parallel_size_local,
            master_addr=self.master_addr,
            master_port=self.master_port,
            nnodes=self.nnodes,
            node_rank=self.node_rank,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
            data_parallel_backend=self.data_parallel_backend,
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
            is_moe_model=model_config.is_moe,
            enable_expert_parallel=self.enable_expert_parallel,
            all2all_backend=self.all2all_backend,
            enable_dbo=self.enable_dbo,
            ubatch_size=self.ubatch_size,
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
            enable_eplb=self.enable_eplb,
            eplb_config=self.eplb_config,
            expert_placement_strategy=self.expert_placement_strategy,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            ray_runtime_env=ray_runtime_env,
            placement_group=placement_group,
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
            worker_extension_cls=self.worker_extension_cls,
            decode_context_parallel_size=self.decode_context_parallel_size,
            dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
        )

        speculative_config = self.create_speculative_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
        )

        assert self.max_num_batched_tokens is not None, (
            "max_num_batched_tokens must be set by this point"
        )
        assert self.max_num_seqs is not None, "max_num_seqs must be set by this point"
        assert self.enable_chunked_prefill is not None, (
            "enable_chunked_prefill must be set by this point"
        )
        assert model_config.max_model_len is not None, (
            "max_model_len must be set by this point"
        )
        scheduler_config = SchedulerConfig(
            runner_type=model_config.runner_type,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_chunked_mm_input=self.disable_chunked_mm_input,
            is_multimodal_model=model_config.is_multimodal_model,
            is_encoder_decoder=model_config.is_encoder_decoder,
            policy=self.scheduling_policy,
            scheduler_cls=self.scheduler_cls,
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
            async_scheduling=self.async_scheduling,
            stream_interval=self.stream_interval,
        )

        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_dtype=self.lora_dtype,
                enable_tower_connector_lora=self.enable_tower_connector_lora,
                specialize_active_lora=self.specialize_active_lora,
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )

        # With LoRA and speculative decoding together, one scheduler step must
        # be able to hold num_speculative_tokens + 1 tokens for every sequence.
        if lora_config is not None and speculative_config is not None:
            required_tokens = scheduler_config.max_num_seqs * (
                speculative_config.num_speculative_tokens + 1
            )
            if scheduler_config.max_num_batched_tokens < required_tokens:
                raise ValueError(
                    "With LoRA and speculative decoding enabled, "
                    "max_num_batched_tokens "
                    f"({scheduler_config.max_num_batched_tokens}) must be at "
                    "least max_num_seqs * (num_speculative_tokens + 1) "
                    f"({required_tokens}). Consider increasing "
                    "max_num_batched_tokens or decreasing num_speculative_tokens"
                )

        # bitsandbytes pre-quantized model need a specific model loader
        # (set before create_load_config() below reads self.load_format).
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

        # Attention config overrides
        attention_config = copy.deepcopy(self.attention_config)
        if self.attention_backend is not None:
            if attention_config.backend is not None:
                raise ValueError(
                    "attention_backend and attention_config.backend "
                    "are mutually exclusive"
                )
            # Convert string to enum if needed (CLI parsing returns a string)
            if isinstance(self.attention_backend, str):
                attention_config.backend = AttentionBackendEnum[
                    self.attention_backend.upper()
                ]
            else:
                attention_config.backend = self.attention_backend

        # Kernel config overrides
        kernel_config = copy.deepcopy(self.kernel_config)
        if self.enable_flashinfer_autotune is not None:
            if kernel_config.enable_flashinfer_autotune is not None:
                raise ValueError(
                    "enable_flashinfer_autotune and "
                    "kernel_config.enable_flashinfer_autotune "
                    "are mutually exclusive"
                )
            kernel_config.enable_flashinfer_autotune = self.enable_flashinfer_autotune

        load_config = self.create_load_config()

        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser

        if self.reasoning_parser_plugin:
            self.structured_outputs_config.reasoning_parser_plugin = (
                self.reasoning_parser_plugin
            )

        observability_config = ObservabilityConfig(
            show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_detailed_traces=self.collect_detailed_traces,
            kv_cache_metrics=self.kv_cache_metrics,
            kv_cache_metrics_sample=self.kv_cache_metrics_sample,
            cudagraph_metrics=self.cudagraph_metrics,
            enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
            enable_mfu_metrics=self.enable_mfu_metrics,
            enable_mm_processor_stats=self.enable_mm_processor_stats,
            enable_logging_iteration_details=self.enable_logging_iteration_details,
        )

        # Compilation config overrides
        compilation_config = copy.deepcopy(self.compilation_config)
        if self.cudagraph_capture_sizes is not None:
            if compilation_config.cudagraph_capture_sizes is not None:
                raise ValueError(
                    "cudagraph_capture_sizes and compilation_config."
                    "cudagraph_capture_sizes are mutually exclusive"
                )
            compilation_config.cudagraph_capture_sizes = self.cudagraph_capture_sizes
        if self.max_cudagraph_capture_size is not None:
            if compilation_config.max_cudagraph_capture_size is not None:
                raise ValueError(
                    "max_cudagraph_capture_size and compilation_config."
                    "max_cudagraph_capture_size are mutually exclusive"
                )
            compilation_config.max_cudagraph_capture_size = (
                self.max_cudagraph_capture_size
            )
        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            load_config=load_config,
            attention_config=attention_config,
            kernel_config=kernel_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            structured_outputs_config=self.structured_outputs_config,
            observability_config=observability_config,
            compilation_config=compilation_config,
            kv_transfer_config=self.kv_transfer_config,
            kv_events_config=self.kv_events_config,
            ec_transfer_config=self.ec_transfer_config,
            profiler_config=self.profiler_config,
            additional_config=self.additional_config,
            optimization_level=self.optimization_level,
            weight_transfer_config=self.weight_transfer_config,
        )

        return config

1852
    def _check_feature_supported(self):
        """Reject engine arguments that request unsupported features.

        Raises:
            NotImplementedError: (via `_raise_unsupported_error`) when the
                partial-prefill limits differ from the SchedulerConfig
                defaults, or when pipeline parallelism is requested with an
                executor backend that neither advertises `supports_pp` nor is
                one of the known PP-capable backends.
        """
        # No Concurrent Partial Prefills so far: any deviation from the
        # SchedulerConfig defaults counts as requesting them.
        partial_prefill_requested = (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        )
        if partial_prefill_requested:
            _raise_unsupported_error(feature_name="Concurrent Partial Prefill")

        if self.pipeline_parallel_size <= 1:
            return

        backend = self.distributed_executor_backend
        # A backend object may opt in explicitly through `supports_pp`.
        if getattr(backend, "supports_pp", False):
            return

        pp_capable_backends = (
            ParallelConfig.distributed_executor_backend,
            "ray",
            "mp",
            "external_launcher",
        )
        if backend in pp_capable_backends:
            return

        _raise_unsupported_error(
            feature_name=(
                "Pipeline Parallelism without Ray distributed "
                "executor or multiprocessing executor or external "
                "launcher"
            )
        )
1878

1879
1880
1881
1882
1883
1884
    @classmethod
    def get_batch_defaults(
        cls,
        world_size: int,
    ) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]:
        """Return platform-dependent default batch limits per usage context.

        Args:
            world_size: Number of workers; only used to scale the CPU
                defaults.

        Returns:
            `(default_max_num_batched_tokens, default_max_num_seqs)`, each a
            dict keyed by `UsageContext.LLM_CLASS` and
            `UsageContext.OPENAI_API_SERVER`.
        """
        from vllm.usage.usage_lib import UsageContext

        llm = UsageContext.LLM_CLASS
        api = UsageContext.OPENAI_API_SERVER

        # Querying the device can fail when the process importing vLLM is not
        # the one that runs it (e.g. scaling vLLM with Ray from a GPU-less
        # driver). The fallback values select the non-H100/H200 defaults.
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # Only used to choose default_max_num_batched_tokens below.
            device_memory = 0
            device_name = ""

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        large_gpu = device_memory >= 70 * GiB_bytes and "a100" not in device_name
        if large_gpu:
            # For GPUs like H100 and MI300x, use larger default values.
            batched_tokens = {llm: 16384, api: 8192}
            num_seqs = {llm: 1024, api: 1024}
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            batched_tokens = {llm: 8192, api: 2048}
            num_seqs = {llm: 256, api: 256}

        # TPU: only the token budget depends on the chip generation; unknown
        # chips keep the values chosen above.
        if current_platform.is_tpu():
            tpu_batched_tokens = {
                "V6E": {llm: 2048, api: 1024},
                "V5E": {llm: 1024, api: 512},
                "V5P": {llm: 512, api: 256},
            }
            chip = current_platform.get_device_name()
            batched_tokens = tpu_batched_tokens.get(chip, batched_tokens)

        # CPU: both limits scale with the number of workers.
        if current_platform.is_cpu():
            batched_tokens = {llm: 4096 * world_size, api: 2048 * world_size}
            num_seqs = {llm: 256 * world_size, api: 128 * world_size}

        return batched_tokens, num_seqs

1963
1964
    def _set_default_chunked_prefill_and_prefix_caching_args(
        self, model_config: ModelConfig
    ) -> None:
        """Resolve `enable_chunked_prefill` and `enable_prefix_caching`.

        A flag left as None takes the value `model_config` reports as
        supported. An explicitly set flag is kept, with a warning when it
        contradicts that support: disabling chunked prefill for a "generate"
        model that supports it, or enabling chunked prefill / prefix caching
        for a "pooling" model that does not. Finally, on POWER, s390x and
        RISC-V CPUs both flags are forced to False.
        """
        chunked_ok = model_config.is_chunked_prefill_supported
        prefix_ok = model_config.is_prefix_caching_supported
        runner = model_config.runner_type

        # Chunked prefill.
        if self.enable_chunked_prefill is None:
            self.enable_chunked_prefill = chunked_ok
            logger.debug(
                "%s chunked prefill by default",
                "Enabling" if chunked_ok else "Disabling",
            )
        elif runner == "generate" and not self.enable_chunked_prefill and chunked_ok:
            logger.warning_once(
                "This model does not officially support disabling chunked prefill. "
                "Disabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )
        elif runner == "pooling" and self.enable_chunked_prefill and not chunked_ok:
            logger.warning_once(
                "This model does not officially support chunked prefill. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )

        # Prefix caching.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = prefix_ok
            logger.debug(
                "%s prefix caching by default",
                "Enabling" if prefix_ok else "Disabling",
            )
        elif runner == "pooling" and self.enable_prefix_caching and not prefix_ok:
            logger.warning_once(
                "This model does not officially support prefix caching. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )

        # Disable chunked prefill and prefix caching for:
        # POWER (ppc64le)/s390x/RISCV CPUs in V1
        if not current_platform.is_cpu():
            return
        arch = current_platform.get_cpu_architecture()
        if arch not in (CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.RISCV):
            return

        logger.info(
            "Chunked prefill is not supported for POWER, "
            "S390X and RISC-V CPUs; "
            "disabling it for V1 backend."
        )
        self.enable_chunked_prefill = False
        logger.info(
            "Prefix caching is not supported for POWER, "
            "S390X and RISC-V CPUs; "
            "disabling it for V1 backend."
        )
        self.enable_prefix_caching = False

    def _set_default_max_num_seqs_and_batched_tokens_args(
        self,
        usage_context: UsageContext | None,
        model_config: ModelConfig,
    ):
        """Fill in `max_num_batched_tokens` / `max_num_seqs` when unset.

        Defaults come from `get_batch_defaults` for `usage_context`, falling
        back to the SchedulerConfig defaults. Only values the user left unset
        are adjusted afterwards:

        * `max_num_batched_tokens` is raised to at least `max_model_len` when
          chunked prefill is off, then capped at
          `max_num_seqs * max_model_len`.
        * `max_num_seqs` is capped at `max_num_batched_tokens`.
        """
        world_size = self.pipeline_parallel_size * self.tensor_parallel_size
        tokens_defaults, seqs_defaults = self.get_batch_defaults(world_size)

        # Record which limits the user did not provide, before filling them in.
        tokens_was_unset = self.max_num_batched_tokens is None
        seqs_was_unset = self.max_num_seqs is None

        if tokens_was_unset:
            self.max_num_batched_tokens = tokens_defaults.get(
                usage_context, SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS
            )
        if seqs_was_unset:
            self.max_num_seqs = seqs_defaults.get(
                usage_context, SchedulerConfig.DEFAULT_MAX_NUM_SEQS
            )

        if tokens_was_unset:
            max_model_len = model_config.max_model_len
            assert max_model_len is not None, (
                "max_model_len must be set by this point"
            )
            budget = self.max_num_batched_tokens
            if not self.enable_chunked_prefill:
                # If max_model_len is too short, use the default for higher throughput.
                budget = max(max_model_len, budget)
            # When using default settings,
            # Ensure max_num_batched_tokens does not exceed model limit.
            # Some models (e.g., Whisper) have embeddings tied to max length.
            self.max_num_batched_tokens = min(self.max_num_seqs * max_model_len, budget)
            logger.debug(
                "Defaulting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                usage_context.value if usage_context else None,
            )

        if seqs_was_unset:
            assert self.max_num_batched_tokens is not None  # For type checking
            self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
            logger.debug(
                "Defaulting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                usage_context.value if usage_context else None,
            )
2098

2099

2100
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments for the asynchronous vLLM engine.

    Adds async-engine-only options on top of everything in ``EngineArgs``.
    """

    # Default value of the ``--enable-log-requests`` CLI flag.
    enable_log_requests: bool = False

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
        """Register the async engine's CLI arguments on ``parser``.

        Args:
            parser: Parser to extend.
            async_args_only: If truthy, only the async-specific arguments are
                added; otherwise ``EngineArgs.add_cli_args`` is applied first.

        Returns:
            The parser with the arguments registered. When
            ``async_args_only`` is falsy this is the object returned by
            ``EngineArgs.add_cli_args``.
        """
        # Plugins are loaded first because a plugin may extend the parser,
        # e.g. by adding a new quantization method to --quantization or a new
        # device to --device.
        load_general_plugins()

        parser = parser if async_args_only else EngineArgs.add_cli_args(parser)

        log_requests_default = AsyncEngineArgs.enable_log_requests
        parser.add_argument(
            "--enable-log-requests",
            help="Enable logging requests.",
            action=argparse.BooleanOptionalAction,
            default=log_requests_default,
        )
        # Deprecated inverse of --enable-log-requests; its default is the
        # negation of that flag's default.
        # NOTE(review): `deprecated` is a native argparse keyword only from
        # Python 3.13 -- presumably FlexibleArgumentParser accepts it on older
        # versions; confirm.
        parser.add_argument(
            "--disable-log-requests",
            help="[DEPRECATED] Disable logging requests.",
            action=argparse.BooleanOptionalAction,
            default=not log_requests_default,
            deprecated=True,
        )

        # Hand the populated parser to the current platform's hook.
        current_platform.pre_register_and_update(parser)
        return parser
2131
2132


2133
2134
2135
2136
2137
2138
def _raise_unsupported_error(feature_name: str):
    """Raise ``NotImplementedError`` for an unsupported feature.

    Args:
        feature_name: Name of the feature; it appears twice in the message.

    Raises:
        NotImplementedError: Always, with a hint to drop the feature from the
            config.
    """
    hint = f"We recommend to remove {feature_name} from your config."
    raise NotImplementedError(f"{feature_name} is not supported. {hint}")
2139
2140


2141
def human_readable_int(value: str) -> int:
    """Parse human-readable integers like '1k', '2M', etc.

    Lowercase suffixes (``k``, ``m``, ``g``, ``t``) are decimal multipliers
    (powers of 1000) and accept a decimal mantissa. Uppercase suffixes
    (``K``, ``M``, ``G``, ``T``) are binary multipliers (powers of 1024) and
    accept only whole-number mantissas. A value without a suffix is parsed
    with ``int()``.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '1.005k' -> 1,005

    Args:
        value: The string to parse; surrounding whitespace is ignored.

    Returns:
        The parsed integer. A decimal-suffixed value whose product is not a
        whole number is truncated toward zero (e.g. '1.0005k' -> 1000).

    Raises:
        argparse.ArgumentTypeError: If a decimal mantissa is combined with a
            binary suffix (e.g. '1.5K').
        ValueError: If a value without a recognized suffix is not accepted
            by ``int()``.
    """
    # Local import: only needed for exact decimal-suffix arithmetic.
    from fractions import Fraction

    value = value.strip()

    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
    if match:
        decimal_multiplier = {
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
            "t": 10**12,
        }
        binary_multiplier = {
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
            "T": 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            # Exact rational arithmetic: float(number) * mult can land just
            # below the true integer (1.005 * 1000 == 1004.9999999999999),
            # which int() would then truncate to 1004.
            return int(Fraction(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?"
                ) from e

    # Regular plain number.
    return int(value)
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203


def human_readable_int_or_auto(value: str) -> int:
    """Parse a human-readable integer, also accepting an auto-detect marker.

    ``'auto'`` (any letter case) and ``'-1'`` both map to the sentinel ``-1``.
    Every other value is delegated to ``human_readable_int``, so its suffix
    rules apply ('1k' -> 1,000, '1K' -> 1,024, '25.6k' -> 25,600).

    Args:
        value: The string to parse; surrounding whitespace is ignored.

    Returns:
        ``-1`` for the auto-detect markers, otherwise the parsed integer.
    """
    text = value.strip()
    wants_auto = text.lower() == "auto" or text == "-1"
    return -1 if wants_auto else human_readable_int(text)