arg_utils.py 86.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import copy
6
import dataclasses
7
import functools
8
import json
9
import sys
10
from collections.abc import Callable
11
from dataclasses import MISSING, dataclass, fields, is_dataclass
12
from itertools import permutations
13
from types import UnionType
14
15
16
17
18
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Literal,
19
    TypeAlias,
20
21
22
23
24
25
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
26

27
import huggingface_hub
28
import regex as re
29
import torch
30
from pydantic import TypeAdapter, ValidationError
31
from pydantic.fields import FieldInfo
32
from typing_extensions import TypeIs
33

34
import vllm.envs as envs
35
from vllm.config import (
36
    AttentionConfig,
37
38
39
40
    CacheConfig,
    CompilationConfig,
    ConfigType,
    DeviceConfig,
41
    ECTransferConfig,
42
43
44
45
46
47
    EPLBConfig,
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LoRAConfig,
    ModelConfig,
48
    MultiModalConfig,
49
50
51
    ObservabilityConfig,
    ParallelConfig,
    PoolerConfig,
52
    ProfilerConfig,
53
54
55
56
57
58
    SchedulerConfig,
    SpeculativeConfig,
    StructuredOutputsConfig,
    VllmConfig,
    get_attr_docs,
)
59
60
61
62
63
64
65
from vllm.config.cache import (
    BlockSize,
    CacheDType,
    KVOffloadingBackend,
    MambaDType,
    PrefixCachingHashAlgo,
)
66
67
68
69
70
71
72
from vllm.config.device import Device
from vllm.config.model import (
    ConvertOption,
    HfOverrides,
    LogprobsMode,
    ModelDType,
    RunnerOption,
73
    TokenizerMode,
74
75
76
77
78
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.scheduler import SchedulerPolicy
79
from vllm.config.utils import get_field
80
from vllm.config.vllm import OptimizationLevel
81
from vllm.logger import init_logger, suppress_logging
82
from vllm.platforms import CpuArchEnum, current_platform
83
from vllm.plugins import load_general_plugins
84
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
85
86
87
88
from vllm.transformers_utils.config import (
    is_interleaved,
    maybe_override_with_speculators,
)
89
from vllm.transformers_utils.gguf_utils import is_gguf
90
from vllm.transformers_utils.repo_utils import get_model_path
91
from vllm.transformers_utils.utils import is_cloud_storage
92
from vllm.utils.argparse_utils import FlexibleArgumentParser
93
from vllm.utils.mem_constants import GiB_bytes
94
from vllm.utils.network_utils import get_ip
95
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
96
from vllm.v1.attention.backends.registry import AttentionBackendEnum
97
from vllm.v1.sample.logits_processor import LogitsProcessor
98

99
100
if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
101
    from vllm.model_executor.model_loader import LoadFormats
102
    from vllm.usage.usage_lib import UsageContext
103
    from vllm.v1.executor import Executor
104
else:
105
    Executor = Any
106
    QuantizationMethods = Any
107
    LoadFormats = Any
108
109
    UsageContext = Any

110

111
112
logger = init_logger(__name__)

# `object` is part of these aliases so they also cover special typing forms
# (e.g. `Literal[...]` or `X | None`), which are not instances of `type`.
T = TypeVar("T")

TypeHint: TypeAlias = type[Any] | object
TypeHintT: TypeAlias = type[T] | object
117

118

119
120
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap a string converter so its `ValueError`s become argparse errors.

    Args:
        return_type: Callable that converts the raw CLI string, e.g. `int`
            or `json.loads`.

    Returns:
        A converter suitable for argparse's `type=`. A `ValueError` raised by
        `return_type` is re-raised as `argparse.ArgumentTypeError` (chained to
        the original error), which argparse reports as a usage error.
    """

    def _parse_type(val: str) -> T:
        try:
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}."
            ) from e

    return _parse_type


131
132
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], T | None]:
    """Wrap a string converter so that `""` and `"None"` parse to `None`.

    Args:
        return_type: Callable that converts any other CLI string.

    Returns:
        A converter that returns `None` for `""`/`"None"` and otherwise
        delegates to `parse_type(return_type)`, so conversion failures surface
        as `argparse.ArgumentTypeError`.
    """

    def _optional_type(val: str) -> T | None:
        if val == "" or val == "None":
            return None
        return parse_type(return_type)(val)

    return _optional_type
138
139


140
def union_dict_and_str(val: str) -> str | dict[str, Any]:
    """Parse `val` as a JSON object if it looks like one, else return it as is.

    A value that is wrapped in `{...}` (ignoring surrounding whitespace) is
    decoded with `json.loads`; decoding failures become
    `argparse.ArgumentTypeError`. Any other value is returned unchanged.
    """
    if not re.match(r"(?s)^\s*{.*}\s*$", val):
        return val
    # The pattern above already rules out "" and "None", so no optional
    # handling is needed here.
    return parse_type(json.loads)(val)
144
145


146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return whether `type_hint` is `type` itself or a parameterization of it.

    For example, both `list` and `list[int]` match `list`.
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return whether any hint in `type_hints` matches `type` via `is_type`."""
    for candidate in type_hints:
        if is_type(candidate, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return the first hint in `type_hints` matching `type`, or `None`.

    A hint matches when `is_type(hint, type)` holds; iteration follows the
    set's order.
    """
    for candidate in type_hints:
        if is_type(candidate, type):
            return candidate
    return None


161
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Get the `type` and `choices` from a `Literal` type hint in `type_hints`.

    If `type_hints` also contains `str`, we use `metavar` instead of `choices`.

    Raises:
        ValueError: if `type_hints` holds no parameterized `Literal`, or if the
            `Literal` options are not all of the same type.
    """
    type_hint = get_type(type_hints, Literal)
    options = get_args(type_hint)
    if not options:
        # `get_type` returns None when no Literal is present (and a bare
        # `Literal` has no args); fail with a clear message, not an IndexError.
        raise ValueError(f"Expected a parameterized Literal in {type_hints}.")
    option_type = type(options[0])
    if not all(isinstance(option, option_type) for option in options):
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}"
        )

    # argparse does not validate against `metavar`, so when `str` is also
    # accepted the options are listed without restricting the value.
    kwarg = "metavar" if contains_type(type_hints, str) else "choices"
    return {"type": option_type, kwarg: sorted(options)}
176
177


178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def collection_to_kwargs(type_hints: set[TypeHint], type: TypeHint) -> dict[str, Any]:
    """Get the `type` and `nargs` argparse kwargs for a collection type hint.

    Looks up the `type` collection (e.g. `list`, `tuple` or `set`) in
    `type_hints` and derives the element type from its type arguments.
    Fixed-length tuples (no `...`) get an exact `nargs`; every other collection
    accepts one or more values.

    Raises:
        AssertionError: if the non-Ellipsis element types differ, or if the
            element type is a union that does not include `str`.
    """
    type_hint = get_type(type_hints, type)
    types = get_args(type_hint)
    if not types:
        # Unparameterized collection (e.g. bare `list` or `tuple`): there is no
        # element type to infer, so accept one or more plain strings instead of
        # failing on `types[0]` (or producing `nargs=0` for a bare tuple).
        return {"type": str, "nargs": "+"}
    elem_type = types[0]

    # Handle Ellipsis
    assert all(t is elem_type for t in types if t is not Ellipsis), (
        f"All non-Ellipsis elements must be of the same type. Got {types}."
    )

    # Handle Union types
    if get_origin(elem_type) in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
        assert str in get_args(elem_type), (
            "If element can have multiple types, one must be 'str' "
            f"(i.e. 'list[int | str]'). Got {elem_type}."
        )
        elem_type = str

    return {
        "type": elem_type,
        "nargs": "+" if type is not tuple or Ellipsis in types else len(types),
    }


203
204
205
206
207
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True when `type_hint` is defined outside the `builtins` module."""
    module_name = type_hint.__module__
    return not module_name == "builtins"


208
209
210
211
212
213
214
215
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints."""
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
216
217
    elif origin in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
218
219
220
221
222
223
224
225
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


226
# True when this process is generating help text (`--help` or mkdocs); config
# attribute docs are only collected in that case.
argv0 = sys.argv[0] if sys.argv else ""
NEEDS_HELP = (
    any("--help" in arg for arg in sys.argv)  # vllm SUBCOMMAND --help
    or argv0.endswith("mkdocs")  # mkdocs SUBCOMMAND
    or argv0.endswith("mkdocs/__main__.py")  # python -m mkdocs SUBCOMMAND
)


233
def _get_field_default(field: dataclasses.Field) -> Any:
    """Return the default value of a config dataclass field.

    Resolves a plain `default` (unwrapping `pydantic.Field(...)` defaults and
    calling their `default_factory` if set), then a dataclass
    `default_factory`. A field with neither yields `None`.
    """
    if field.default is not MISSING:
        default = field.default
        # Handle pydantic.Field defaults
        if isinstance(default, FieldInfo):
            if default.default_factory is None:
                return default.default
            # VllmConfig's Fields have default_factory set to config classes.
            # These could emit logs on init, which would be confusing.
            with suppress_logging():
                return default.default_factory()
        return default
    if field.default_factory is not MISSING:
        return field.default_factory()
    # Field without any default: return None explicitly instead of reusing
    # the default computed for a previous field.
    return None


@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Compute argparse `add_argument` kwargs for each field of `cls`.

    The result is cached by `functools.lru_cache`, so callers must not mutate
    it; `get_kwargs` returns a deep copy.

    Returns:
        A mapping from field name to kwargs containing `default` and `help`,
        plus `type`/`action`/`choices`/`metavar`/`nargs` derived from the
        field's type hint.

    Raises:
        ValueError: if a field's type hint is not supported.
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}
    json_tip = (
        "Should either be a valid JSON string or JSON keys passed individually."
    )

    kwargs: dict[str, dict[str, Any]] = {}
    for field in fields(cls):
        name = field.name
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # Dataclass-typed fields are parsed from JSON with a pydantic TypeAdapter
        dataclass_cls = next((th for th in type_hints if is_dataclass(th)), None)

        default = _get_field_default(field)

        # Escape % for argparse
        help_text = cls_docs.get(name, "").strip().replace("%", "%%")

        field_kwargs: dict[str, Any] = {"default": default, "help": help_text}
        kwargs[name] = field_kwargs

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:

            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            field_kwargs["type"] = parse_dataclass
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            field_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            field_kwargs.update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            field_kwargs.update(collection_to_kwargs(type_hints, tuple))
        elif contains_type(type_hints, list):
            field_kwargs.update(collection_to_kwargs(type_hints, list))
        elif contains_type(type_hints, set):
            field_kwargs.update(collection_to_kwargs(type_hints, set))
        elif contains_type(type_hints, int):
            if name == "max_model_len":
                field_kwargs["type"] = human_readable_int_or_auto
                field_kwargs["help"] += f"\n\n{human_readable_int_or_auto.__doc__}"
            elif name in ("max_num_batched_tokens", "kv_cache_memory_bytes"):
                field_kwargs["type"] = human_readable_int
                field_kwargs["help"] += f"\n\n{human_readable_int.__doc__}"
            else:
                field_kwargs["type"] = int
        elif contains_type(type_hints, float):
            field_kwargs["type"] = float
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
            field_kwargs["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            field_kwargs["type"] = parse_type(json.loads)
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
            field_kwargs["type"] = str
        else:
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(field_kwargs.get("type")) is Literal:
            field_kwargs.update(literal_to_kwargs({field_kwargs["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            field_kwargs["type"] = optional_type(field_kwargs["type"])
            if field_kwargs.get("choices"):
                # NOTE(review): argparse checks `choices` against the value
                # *after* `type` conversion, which is `None` here rather than
                # the string "None"; confirm "None" is accepted on the CLI.
                field_kwargs["choices"].append("None")
    return kwargs
333
334


335
def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Return argparse kwargs for the given Config dataclass.

    If `--help` or `mkdocs` are not present in the command line command, the
    attribute documentation will not be included in the help output.

    The heavy computation is cached via functools.lru_cache, and a deep copy
    is returned so callers can mutate the dictionary without affecting the
    cached version.
    """
    return copy.deepcopy(_compute_kwargs(cls))


348
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
349
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
350
    """Arguments for vLLM engine."""
351

352
    model: str = ModelConfig.model
353
    enable_return_routed_experts: bool = ModelConfig.enable_return_routed_experts
354
    model_weights: str = ModelConfig.model_weights
355
    served_model_name: str | list[str] | None = ModelConfig.served_model_name
356
    tokenizer: str | None = ModelConfig.tokenizer
357
    hf_config_path: str | None = ModelConfig.hf_config_path
358
359
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
360
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
361
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
362
    tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
363
    trust_remote_code: bool = ModelConfig.trust_remote_code
364
365
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
    allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains
366
    download_dir: str | None = LoadConfig.download_dir
367
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
368
    load_format: str | LoadFormats = LoadConfig.load_format
369
370
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
371
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
372
    seed: int = ModelConfig.seed
373
    max_model_len: int | None = ModelConfig.max_model_len
374
375
376
377
378
379
    cudagraph_capture_sizes: list[int] | None = (
        CompilationConfig.cudagraph_capture_sizes
    )
    max_cudagraph_capture_size: int | None = get_field(
        CompilationConfig, "max_cudagraph_capture_size"
    )
380
381
382
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
383
    distributed_executor_backend: (
384
        str | DistributedExecutorBackend | type[Executor] | None
385
    ) = ParallelConfig.distributed_executor_backend
386
    # number of P/D disaggregation (or other disaggregation) workers
387
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
388
389
390
391
    master_addr: str = ParallelConfig.master_addr
    master_port: int = ParallelConfig.master_port
    nnodes: int = ParallelConfig.nnodes
    node_rank: int = ParallelConfig.node_rank
392
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
393
    prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
394
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
395
    dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
396
    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
397
    data_parallel_size: int = ParallelConfig.data_parallel_size
398
399
400
401
402
    data_parallel_rank: int | None = None
    data_parallel_start_rank: int | None = None
    data_parallel_size_local: int | None = None
    data_parallel_address: str | None = None
    data_parallel_rpc_port: int | None = None
403
    data_parallel_hybrid_lb: bool = False
404
    data_parallel_external_lb: bool = False
Rui Qiao's avatar
Rui Qiao committed
405
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
406
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
407
    all2all_backend: str = ParallelConfig.all2all_backend
408
    enable_dbo: bool = ParallelConfig.enable_dbo
409
    ubatch_size: int = ParallelConfig.ubatch_size
410
411
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
412
    disable_nccl_for_dp_synchronization: bool | None = (
413
414
        ParallelConfig.disable_nccl_for_dp_synchronization
    )
415
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
416
    enable_eplb: bool = ParallelConfig.enable_eplb
417
    expert_placement_strategy: ExpertPlacementStrategy = (
418
        ParallelConfig.expert_placement_strategy
419
    )
420
421
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
422
    max_parallel_loading_workers: int | None = (
423
424
        ParallelConfig.max_parallel_loading_workers
    )
425
    block_size: BlockSize | None = CacheConfig.block_size
426
    enable_prefix_caching: bool | None = None
427
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
428
        CacheConfig.prefix_caching_hash_algo
429
    )
430
431
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
432
433
434
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
435
    kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
436
    max_num_batched_tokens: int | None = None
437
438
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
439
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
440
    max_num_seqs: int | None = None
441
    max_logprobs: int = ModelConfig.max_logprobs
442
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
443
    disable_log_stats: bool = False
444
    aggregate_engine_logging: bool = False
445
446
447
    revision: str | None = ModelConfig.revision
    code_revision: str | None = ModelConfig.code_revision
    hf_token: bool | str | None = ModelConfig.hf_token
448
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
449
    tokenizer_revision: str | None = ModelConfig.tokenizer_revision
450
    quantization: QuantizationMethods | None = ModelConfig.quantization
451
    allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
452
    enforce_eager: bool = ModelConfig.enforce_eager
453
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
454
    limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
455
456
        MultiModalConfig, "limit_per_prompt"
    )
457
    enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
458
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
459
460
461
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
462
    mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
463
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
464
    mm_processor_cache_type: MMCacheType | None = (
465
        MultiModalConfig.mm_processor_cache_type
466
467
    )
    mm_shm_cache_max_object_size_mb: int = (
468
        MultiModalConfig.mm_shm_cache_max_object_size_mb
469
    )
470
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
471
    mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
472
473
        MultiModalConfig.mm_encoder_attn_backend
    )
474
    io_processor_plugin: str | None = None
475
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
476
    video_pruning_rate: float = MultiModalConfig.video_pruning_rate
477
    # LoRA fields
478
    enable_lora: bool = False
479
480
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
481
    default_mm_loras: dict[str, str] | None = LoRAConfig.default_mm_loras
482
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
483
484
    max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
    lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
485
    enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
486

487
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
488
    num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
489
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
490
    ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
491

492
    enable_chunked_prefill: bool | None = None
493
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
494

495
    disable_hybrid_kv_cache_manager: bool | None = (
496
497
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
498

499
    structured_outputs_config: StructuredOutputsConfig = get_field(
500
501
        VllmConfig, "structured_outputs_config"
    )
502
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
503
    reasoning_parser_plugin: str | None = None
504

505
    logits_processor_pattern: str | None = ModelConfig.logits_processor_pattern
506

507
    speculative_config: dict[str, Any] | None = None
508

509
    show_hidden_metrics_for_version: str | None = (
510
        ObservabilityConfig.show_hidden_metrics_for_version
511
    )
512
513
    otlp_traces_endpoint: str | None = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: list[DetailedTraceModules] | None = (
514
        ObservabilityConfig.collect_detailed_traces
515
    )
516
517
518
519
    kv_cache_metrics: bool = ObservabilityConfig.kv_cache_metrics
    kv_cache_metrics_sample: float = get_field(
        ObservabilityConfig, "kv_cache_metrics_sample"
    )
520
    cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
521
522
523
    enable_layerwise_nvtx_tracing: bool = (
        ObservabilityConfig.enable_layerwise_nvtx_tracing
    )
524
    enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
525
526
527
    enable_logging_iteration_details: bool = (
        ObservabilityConfig.enable_logging_iteration_details
    )
528
    enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
529
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
530
    scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
531

532
    pooler_config: PoolerConfig | None = ModelConfig.pooler_config
533
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
534
    attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
535
536
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
537

538
539
    profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")

540
541
    kv_transfer_config: KVTransferConfig | None = None
    kv_events_config: KVEventsConfig | None = None
542

543
544
    ec_transfer_config: ECTransferConfig | None = None

545
546
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
547
548
549
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
550
    model_impl: str = ModelConfig.model_impl
551
    override_attention_dtype: str = ModelConfig.override_attention_dtype
552
    attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
553

554
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
555
556
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
557
    mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
558

559
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
560

561
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
562
    pt_load_map_location: str = LoadConfig.pt_load_map_location
563

564
    logits_processors: list[str | type[LogitsProcessor]] | None = (
565
566
        ModelConfig.logits_processors
    )
567
568
    """Custom logitproc types"""

569
    async_scheduling: bool | None = SchedulerConfig.async_scheduling
570

571
572
    stream_interval: int = SchedulerConfig.stream_interval

573
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
574
    optimization_level: OptimizationLevel = VllmConfig.optimization_level
575

576
    kv_offloading_size: float | None = CacheConfig.kv_offloading_size
577
    kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
578
    tokens_only: bool = False
579

580
    def __post_init__(self):
        """Normalize config arguments and resolve offline model paths.

        - Converts dict values of `compilation_config`, `attention_config` and
          `eplb_config` into their config classes, so that e.g.
          `EngineArgs(compilation_config={...})` works without manually
          constructing a `CompilationConfig`.
        - Loads the general plugins.
        - When `huggingface_hub.constants.HF_HUB_OFFLINE` is set, replaces the
          model (and tokenizer, if given) identifiers with the result of
          `get_model_path`, logging any replacement.
        """
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(**self.compilation_config)
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)

        # Setup plugins
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        # In HF offline mode, replace the model and tokenizer IDs with their
        # local model paths.
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
            # Compare by value: an equal path may be a distinct str object.
            if model_id != self.model:
                logger.info(
                    "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
                    model_id,
                    self.model,
                )
            if self.tokenizer is not None:
                tokenizer_id = self.tokenizer
                self.tokenizer = get_model_path(self.tokenizer, self.tokenizer_revision)
                if tokenizer_id != self.tokenizer:
                    logger.info(
                        "HF_HUB_OFFLINE is True, replace tokenizer_id [%s] "
                        "to tokenizer_path [%s]",
                        tokenizer_id,
                        self.tokenizer,
                    )
614
615

    @staticmethod
616
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
617
        """Shared CLI arguments for vLLM engine."""
618

619
        # Model arguments
620
621
622
623
624
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
625
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
626
            model_group.add_argument("--model", **model_kwargs["model"])
627
628
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
629
630
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
631
632
633
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
634
635
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
636
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
637
638
639
640
641
642
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
643
        model_group.add_argument("--revision", **model_kwargs["revision"])
644
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
645
646
647
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
648
649
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
650
651
652
653
        model_group.add_argument(
            "--allow-deprecated-quantization",
            **model_kwargs["allow_deprecated_quantization"],
        )
654
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
655
656
657
658
        model_group.add_argument(
            "--enable-return-routed-experts",
            **model_kwargs["enable_return_routed_experts"],
        )
659
660
661
662
663
664
665
666
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
667
668
669
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
670
671
672
673
674
675
676
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
677
678
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
        )
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
701
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
702
703
704
705
706
707
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
708
709
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
710
        )
711

712
713
714
715
716
717
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
718
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
719
720
721
722
723
724
725
726
727
728
729
730
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )
731

732
733
734
735
736
737
738
739
740
741
        # Attention arguments
        attention_kwargs = get_kwargs(AttentionConfig)
        attention_group = parser.add_argument_group(
            title="AttentionConfig",
            description=AttentionConfig.__doc__,
        )
        attention_group.add_argument(
            "--attention-backend", **attention_kwargs["backend"]
        )

742
743
744
745
746
        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
747
        )
748
        structured_outputs_group.add_argument(
749
            "--reasoning-parser",
750
            # Choices need to be validated after parsing to include plugins
751
752
            **structured_outputs_kwargs["reasoning_parser"],
        )
753
754
755
756
        structured_outputs_group.add_argument(
            "--reasoning-parser-plugin",
            **structured_outputs_kwargs["reasoning_parser_plugin"],
        )
757

758
        # Parallel arguments
759
760
761
762
763
764
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
765
            "--distributed-executor-backend",
766
767
            **parallel_kwargs["distributed_executor_backend"],
        )
768
        parallel_group.add_argument(
769
770
771
772
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
773
774
775
776
        parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"])
        parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
        parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
        parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
777
        parallel_group.add_argument(
778
779
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
780
        parallel_group.add_argument(
781
782
783
784
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
785
786
787
788
        parallel_group.add_argument(
            "--dcp-kv-cache-interleave-size",
            **parallel_kwargs["dcp_kv_cache_interleave_size"],
        )
789
790
791
792
793
794
795
796
797
        parallel_group.add_argument(
            "--cp-kv-cache-interleave-size",
            **parallel_kwargs["cp_kv_cache_interleave_size"],
        )
        parallel_group.add_argument(
            "--prefill-context-parallel-size",
            "-pcp",
            **parallel_kwargs["prefill_context_parallel_size"],
        )
798
799
800
801
802
803
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
        parallel_group.add_argument(
            "--data-parallel-rank",
            "-dpn",
804
            type=int,
805
806
807
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
808
        parallel_group.add_argument(
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
            "--data-parallel-start-rank",
            "-dpr",
            type=int,
            help="Starting data parallel rank for secondary nodes.",
        )
        parallel_group.add_argument(
            "--data-parallel-size-local",
            "-dpl",
            type=int,
            help="Number of data parallel replicas to run on this node.",
        )
        parallel_group.add_argument(
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
            type=int,
            help="Port for data parallel RPC communication.",
        )
        parallel_group.add_argument(
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
839
        parallel_group.add_argument(
840
841
842
843
844
845
846
847
            "--data-parallel-hybrid-lb",
            "-dph",
            **parallel_kwargs["data_parallel_hybrid_lb"],
        )
        parallel_group.add_argument(
            "--data-parallel-external-lb",
            "-dpe",
            **parallel_kwargs["data_parallel_external_lb"],
848
849
        )
        parallel_group.add_argument(
850
851
852
            "--enable-expert-parallel",
            "-ep",
            **parallel_kwargs["enable_expert_parallel"],
853
        )
854
855
856
        parallel_group.add_argument(
            "--all2all-backend", **parallel_kwargs["all2all_backend"]
        )
857
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
858
859
860
861
        parallel_group.add_argument(
            "--ubatch-size",
            **parallel_kwargs["ubatch_size"],
        )
862
863
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
864
865
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
866
867
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
868
869
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
870
871
872
873
        parallel_group.add_argument(
            "--disable-nccl-for-dp-synchronization",
            **parallel_kwargs["disable_nccl_for_dp_synchronization"],
        )
874
875
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
876
877
        parallel_group.add_argument(
            "--expert-placement-strategy",
878
879
            **parallel_kwargs["expert_placement_strategy"],
        )
880

881
        parallel_group.add_argument(
882
            "--max-parallel-loading-workers",
883
884
            **parallel_kwargs["max_parallel_loading_workers"],
        )
885
        parallel_group.add_argument(
886
887
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
888
        parallel_group.add_argument(
889
            "--disable-custom-all-reduce",
890
891
892
893
894
895
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )
896

897
898
899
900
901
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
902
        )
903
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
904
905
906
907
908
909
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
910
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
911
912
913
914
915
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        cache_group.add_argument(
916
917
918
919
920
            "--enable-prefix-caching",
            **{
                **cache_kwargs["enable_prefix_caching"],
                "default": None,
            },
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
938
939
940
        cache_group.add_argument(
            "--mamba-block-size", **cache_kwargs["mamba_block_size"]
        )
941
942
943
944
945
946
        cache_group.add_argument(
            "--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
        )
        cache_group.add_argument(
            "--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
        )
947

948
        # Multimodal related configs
949
950
951
952
953
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
954
        multimodal_group.add_argument(
955
956
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
957
958
959
        multimodal_group.add_argument(
            "--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
        )
960
961
962
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
963
964
965
966
967
968
        multimodal_group.add_argument(
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
969
        multimodal_group.add_argument(
970
971
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
972
973
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
974
975
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
976
        multimodal_group.add_argument(
977
978
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
979
980
981
982
        multimodal_group.add_argument(
            "--mm-encoder-attn-backend",
            **multimodal_kwargs["mm_encoder_attn_backend"],
        )
983
984
985
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
986
        multimodal_group.add_argument(
987
988
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
989

990
        multimodal_group.add_argument(
991
992
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )
993

994
        # LoRA related configs
995
996
997
998
999
1000
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
1001
            "--enable-lora",
1002
            action=argparse.BooleanOptionalAction,
1003
1004
            help="If True, enable handling of LoRA adapters.",
        )
1005
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
1006
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
1007
        lora_group.add_argument(
1008
            "--lora-dtype",
1009
1010
            **lora_kwargs["lora_dtype"],
        )
1011
1012
1013
1014
        lora_group.add_argument(
            "--enable-tower-connector-lora",
            **lora_kwargs["enable_tower_connector_lora"],
        )
1015
1016
1017
1018
1019
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
1020

1021
1022
1023
1024
1025
1026
1027
1028
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
1029
1030
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
1031
        observability_group.add_argument(
1032
1033
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
1034
1035
1036
1037
1038
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
1039
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
1040
1041
1042
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
1043
1044
            **observability_kwargs["collect_detailed_traces"],
        )
1045
1046
1047
1048
1049
1050
1051
        observability_group.add_argument(
            "--kv-cache-metrics", **observability_kwargs["kv_cache_metrics"]
        )
        observability_group.add_argument(
            "--kv-cache-metrics-sample",
            **observability_kwargs["kv_cache_metrics_sample"],
        )
1052
1053
1054
1055
        observability_group.add_argument(
            "--cudagraph-metrics",
            **observability_kwargs["cudagraph_metrics"],
        )
1056
1057
1058
1059
        observability_group.add_argument(
            "--enable-layerwise-nvtx-tracing",
            **observability_kwargs["enable_layerwise_nvtx_tracing"],
        )
1060
1061
1062
1063
        observability_group.add_argument(
            "--enable-mfu-metrics",
            **observability_kwargs["enable_mfu_metrics"],
        )
1064
1065
1066
1067
        observability_group.add_argument(
            "--enable-logging-iteration-details",
            **observability_kwargs["enable_logging_iteration_details"],
        )
1068

1069
1070
1071
1072
1073
1074
1075
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
1076
1077
1078
1079
1080
            "--max-num-batched-tokens",
            **{
                **scheduler_kwargs["max_num_batched_tokens"],
                "default": None,
            },
1081
        )
1082
        scheduler_group.add_argument(
1083
1084
1085
1086
1087
            "--max-num-seqs",
            **{
                **scheduler_kwargs["max_num_seqs"],
                "default": None,
            },
1088
1089
1090
1091
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
1092
1093
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
1094
1095
            **scheduler_kwargs["max_long_partial_prefills"],
        )
1096
1097
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
1098
1099
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
1100
1101
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
1102
        scheduler_group.add_argument(
1103
1104
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
1105
        scheduler_group.add_argument(
1106
1107
1108
1109
1110
            "--enable-chunked-prefill",
            **{
                **scheduler_kwargs["enable_chunked_prefill"],
                "default": None,
            },
1111
1112
1113
1114
1115
1116
1117
        )
        scheduler_group.add_argument(
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
1118
1119
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
1120
1121
1122
1123
1124
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
1125
1126
1127
        scheduler_group.add_argument(
            "--stream-interval", **scheduler_kwargs["stream_interval"]
        )
1128

1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
        # Compilation arguments
        compilation_kwargs = get_kwargs(CompilationConfig)
        compilation_group = parser.add_argument_group(
            title="CompilationConfig",
            description=CompilationConfig.__doc__,
        )
        compilation_group.add_argument(
            "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
        )
        compilation_group.add_argument(
            "--max-cudagraph-capture-size",
            **compilation_kwargs["max_cudagraph_capture_size"],
        )

1143
        # vLLM arguments
1144
        vllm_kwargs = get_kwargs(VllmConfig)
1145
1146
1147
1148
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
1149
1150
1151
1152
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
1153
1154
1155
1156
1157
1158
1159
        vllm_group.add_argument(
            "--speculative-config", **vllm_kwargs["speculative_config"]
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
1160
1161
1162
        vllm_group.add_argument(
            "--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
        )
1163
        vllm_group.add_argument(
1164
            "--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
1165
        )
1166
1167
1168
        vllm_group.add_argument(
            "--attention-config", "-ac", **vllm_kwargs["attention_config"]
        )
1169
1170
1171
1172
1173
1174
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
1175
        vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
1176
1177
1178
1179
        vllm_group.add_argument(
            "--optimization-level", **vllm_kwargs["optimization_level"]
        )

1180
        # Other arguments
1181
1182
1183
1184
1185
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
1186

1187
1188
1189
1190
1191
1192
        parser.add_argument(
            "--aggregate-engine-logging",
            action="store_true",
            help="Log aggregate rather than per-engine statistics "
            "when using data parallelism.",
        )
1193
        return parser
1194
1195

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from a parsed CLI namespace.

        Only dataclass fields that exist as attributes on ``args`` are passed
        to the constructor; any field absent from the namespace is simply
        omitted from the call (so the dataclass default applies).
        """
        wanted_names = (field.name for field in dataclasses.fields(cls))
        ctor_kwargs = {
            name: getattr(args, name) for name in wanted_names if hasattr(args, name)
        }
        return cls(**ctor_kwargs)
1204

1205
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` described by these engine arguments.

        Side effect: when ``self.model`` is a GGUF checkpoint (per
        ``is_gguf``), both ``self.quantization`` and ``self.load_format`` are
        overwritten with ``"gguf"`` before the config is constructed.
        """
        if is_gguf(self.model):
            # GGUF files need the dedicated GGUF model loader.
            self.quantization = self.load_format = "gguf"

        if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            # With V1 multiprocessing disabled, the warning text notes that the
            # seed may affect the random state of the launching process.
            logger.warning(
                "The global random seed is set to %d. Since "
                "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
                "affect the random state of the Python process that "
                "launched vLLM.",
                self.seed,
            )

        return ModelConfig(
            # Model / tokenizer identity and config source.
            model=self.model,
            model_weights=self.model_weights,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            config_format=self.config_format,
            # Revisions and Hugging Face access.
            revision=self.revision,
            code_revision=self.code_revision,
            tokenizer_revision=self.tokenizer_revision,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            # Numerics and execution behaviour.
            dtype=self.dtype,
            seed=self.seed,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            allow_deprecated_quantization=self.allow_deprecated_quantization,
            enforce_eager=self.enforce_eager,
            enable_return_routed_experts=self.enable_return_routed_experts,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            override_attention_dtype=self.override_attention_dtype,
            model_impl=self.model_impl,
            enable_sleep_mode=self.enable_sleep_mode,
            # Tokenization, outputs and generation settings.
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            logits_processor_pattern=self.logits_processor_pattern,
            logits_processors=self.logits_processors,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            pooler_config=self.pooler_config,
            io_processor_plugin=self.io_processor_plugin,
            # Media / multimodal settings.
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            enable_mm_embeds=self.enable_mm_embeds,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
            video_pruning_rate=self.video_pruning_rate,
        )
1272

1273
    def validate_tensorizer_args(self):
        """Copy Tensorizer-recognised options into the nested tensorizer dict.

        Every top-level key of ``self.model_loader_extra_config`` that is also
        listed in ``TensorizerConfig._fields`` is copied (not moved) into
        ``self.model_loader_extra_config["tensorizer_config"]``.
        """
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

        extra_config = self.model_loader_extra_config
        for option in extra_config:
            if option not in TensorizerConfig._fields:
                continue
            # The nested dict is looked up only when there is something to
            # copy into it.
            extra_config["tensorizer_config"][option] = extra_config[option]
1281

1282
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` for these engine arguments.

        Before constructing the config, this may adjust attributes in place:

        * ``quantization == "bitsandbytes"`` forces ``load_format`` to
          ``"bitsandbytes"``.
        * ``load_format == "tensorizer"`` converts ``model_loader_extra_config``
          via ``to_serializable()`` when that method exists. It then resets the
          ``"tensorizer_config"`` entry to ``{"tensorizer_dir": self.model}``
          and calls ``validate_tensorizer_args()``.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            extra_config = self.model_loader_extra_config
            if hasattr(extra_config, "to_serializable"):
                extra_config = extra_config.to_serializable()
                self.model_loader_extra_config = extra_config
            # Any pre-existing nested entry is replaced, not merged.
            extra_config["tensorizer_config"] = {"tensorizer_dir": self.model}
            self.validate_tensorizer_args()

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )
1306

1307
1308
1309
1310
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
    ) -> SpeculativeConfig | None:
        """Initialize a SpeculativeConfig from ``self.speculative_config``.

        ``self.speculative_config`` is a dict of speculative-decoding options.
        On the CLI path it comes from ``json.loads`` applied to
        ``--speculative-config``; otherwise the caller passes it in directly.

        Args:
            target_model_config: Model config of the target model. It cannot
                come from the CLI argument, so it is injected here.
            target_parallel_config: Parallel config of the target model,
                injected for the same reason.

        Returns:
            The constructed ``SpeculativeConfig``, or ``None`` when no
            speculative config was provided.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # Build a fresh kwargs dict instead of calling ``update`` on
        # ``self.speculative_config``. Mutating it in place would leak these
        # config objects into the engine args and into any dict the caller
        # shares. Later keys win, matching ``dict.update`` semantics.
        speculative_kwargs = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
        }
        return SpeculativeConfig(**speculative_kwargs)
1333

1334
1335
    def create_engine_config(
        self,
1336
        usage_context: UsageContext | None = None,
1337
        headless: bool = False,
1338
1339
1340
1341
    ) -> VllmConfig:
        """
        Create the VllmConfig.

1342
        NOTE: If VllmConfig is incompatible, we raise an error.
1343
        """
1344
        current_platform.pre_register_and_update()
1345

1346
        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
1347

1348
1349
        # Check if the model is a speculator and override model/tokenizer/config
        # BEFORE creating ModelConfig, so the config is created with the target model
1350
1351
1352
1353
        # Skip speculator detection for cloud storage models (eg: S3, GCS) since
        # HuggingFace cannot load configs directly from S3 URLs. S3 models can still
        # use speculators with explicit --speculative-config.
        if not is_cloud_storage(self.model):
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
            (self.model, self.tokenizer, self.speculative_config) = (
                maybe_override_with_speculators(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    revision=self.revision,
                    trust_remote_code=self.trust_remote_code,
                    vllm_speculative_config=self.speculative_config,
                )
            )

1364
        model_config = self.create_model_config()
1365
        self.model = model_config.model
1366
        self.model_weights = model_config.model_weights
1367
1368
        self.tokenizer = model_config.tokenizer

1369
        self._check_feature_supported(model_config)
1370
1371
1372
1373
        self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
        self._set_default_max_num_seqs_and_batched_tokens_args(
            usage_context, model_config
        )
1374

1375
        sliding_window: int | None = None
1376
1377
1378
1379
1380
1381
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1382
1383
1384
        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
1385
        # reuses the GPUs of TP group, and split one TP group into
1386
        # tp_size//dcp_size DCP groups.
1387
        assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, (
1388
1389
1390
1391
            f"tp_size={self.tensor_parallel_size} must be divisible by"
            f"dcp_size={self.decode_context_parallel_size}."
        )

1392
1393
1394
1395
1396
        # Resolve "auto" kv_cache_dtype to actual value from model config
        resolved_cache_dtype = resolve_kv_cache_dtype_string(
            self.kv_cache_dtype, model_config
        )

1397
        cache_config = CacheConfig(
1398
            block_size=self.block_size,
1399
            gpu_memory_utilization=self.gpu_memory_utilization,
1400
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1401
            swap_space=self.swap_space,
1402
            cache_dtype=resolved_cache_dtype,
1403
            is_attention_free=model_config.is_attention_free,
1404
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1405
            sliding_window=sliding_window,
1406
            enable_prefix_caching=self.enable_prefix_caching,
1407
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1408
            cpu_offload_gb=self.cpu_offload_gb,
1409
            calculate_kv_scales=self.calculate_kv_scales,
1410
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1411
1412
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1413
            mamba_block_size=self.mamba_block_size,
1414
1415
            kv_offloading_size=self.kv_offloading_size,
            kv_offloading_backend=self.kv_offloading_backend,
1416
        )
1417

1418
1419
1420
1421
1422
1423
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
1424

1425
            ray_runtime_env = ray.get_runtime_context().runtime_env
1426
1427
1428
1429
1430
1431
1432
            # Avoid logging sensitive environment variables
            sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {}
            if "env_vars" in sanitized_env:
                sanitized_env["env_vars"] = {
                    k: "***" for k in sanitized_env["env_vars"]
                }
            logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env)
1433

1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1445
        assert not headless or not self.data_parallel_hybrid_lb, (
1446
1447
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
        assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
            "data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
        )
        assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
            "nnodes > 1 is only supported with data_parallel_backend=mp"
        )
        inferred_data_parallel_rank = 0
        if self.nnodes > 1:
            world_size = (
                self.data_parallel_size
                * self.pipeline_parallel_size
                * self.tensor_parallel_size
            )
            world_size_within_dp = (
                self.pipeline_parallel_size * self.tensor_parallel_size
            )
            local_world_size = world_size // self.nnodes
            assert world_size % self.nnodes == 0, (
                f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
            )
            assert self.node_rank < self.nnodes, (
                f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
            )
            inferred_data_parallel_rank = (
                self.node_rank * local_world_size
            ) // world_size_within_dp
            if self.data_parallel_size > 1 and self.data_parallel_external_lb:
                self.data_parallel_rank = inferred_data_parallel_rank
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d for external lb",
                    self.data_parallel_rank,
                    self.node_rank,
                )
            elif self.data_parallel_size_local is None:
                # Infer data parallel size local for internal dplb:
                self.data_parallel_size_local = max(
                    local_world_size // world_size_within_dp, 1
                )
        data_parallel_external_lb = (
            self.data_parallel_external_lb or self.data_parallel_rank is not None
        )
1489
        # Local DP rank = 1, use pure-external LB.
1490
        if data_parallel_external_lb:
1491
            assert self.data_parallel_rank is not None, (
1492
                "data_parallel_rank or node_rank must be specified if "
1493
1494
                "data_parallel_external_lb is enable."
            )
1495
            assert self.data_parallel_size_local in (1, None), (
1496
1497
                "data_parallel_size_local must be 1 or None when data_parallel_rank "
                "is set"
1498
            )
1499
            data_parallel_size_local = 1
1500
1501
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1502
1503
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1504
1505
1506
1507
1508
1509
1510

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
1511
1512
1513
1514
1515
                logger.warning(
                    "data_parallel_hybrid_lb is not eligible when "
                    "data_parallel_size_local = 1, autoswitch to "
                    "data_parallel_external_lb."
                )
1516
1517
1518
1519
1520
1521
1522
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

1523
1524
1525
1526
1527
1528
1529
1530
1531
            self.data_parallel_rank = (
                self.data_parallel_start_rank or inferred_data_parallel_rank
            )
            if self.nnodes > 1:
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d",
                    self.data_parallel_rank,
                    self.node_rank,
                )
1532
        else:
1533
            assert not self.data_parallel_hybrid_lb, (
1534
1535
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )
1536

1537
1538
1539
1540
1541
1542
1543
1544
1545
            if self.data_parallel_backend == "ray" and (
                envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
            ):
                # Data parallel size defaults to 1 if DP ranks are spanning
                # multiple nodes
                data_parallel_size_local = 1
            else:
                # Otherwise local DP size defaults to global DP size if not set
                data_parallel_size_local = self.data_parallel_size
1546
1547
1548

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1549
1550
1551
1552
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
1553
1554
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1555
1556
1557
1558
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
1559
1560
                    self.data_parallel_backend,
                )
1561
1562
1563
                data_parallel_address = (
                    self.master_addr or ParallelConfig.data_parallel_master_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1564
1565
        else:
            data_parallel_address = self.data_parallel_address
1566
1567
1568

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
1569
        data_parallel_rpc_port = (
1570
            self.data_parallel_rpc_port
1571
1572
1573
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )
1574

1575
1576
1577
1578
        if self.tokens_only and not model_config.skip_tokenizer_init:
            model_config.skip_tokenizer_init = True
            logger.info("Skipping tokenizer initialization for tokens-only mode.")

1579
        parallel_config = ParallelConfig(
1580
1581
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1582
            prefill_context_parallel_size=self.prefill_context_parallel_size,
1583
            data_parallel_size=self.data_parallel_size,
1584
1585
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1586
            data_parallel_size_local=data_parallel_size_local,
1587
1588
1589
1590
            master_addr=self.master_addr,
            master_port=self.master_port,
            nnodes=self.nnodes,
            node_rank=self.node_rank,
1591
1592
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1593
            data_parallel_backend=self.data_parallel_backend,
1594
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1595
            is_moe_model=model_config.is_moe,
1596
            enable_expert_parallel=self.enable_expert_parallel,
1597
            all2all_backend=self.all2all_backend,
1598
            enable_dbo=self.enable_dbo,
1599
            ubatch_size=self.ubatch_size,
1600
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
1601
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
1602
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
1603
            enable_eplb=self.enable_eplb,
1604
            eplb_config=self.eplb_config,
1605
            expert_placement_strategy=self.expert_placement_strategy,
1606
1607
1608
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1609
            ray_runtime_env=ray_runtime_env,
1610
            placement_group=placement_group,
1611
1612
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1613
            worker_extension_cls=self.worker_extension_cls,
1614
            decode_context_parallel_size=self.decode_context_parallel_size,
1615
            dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
1616
            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
1617
1618
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
1619
        )
1620

1621
        speculative_config = self.create_speculative_config(
1622
1623
1624
1625
            target_model_config=model_config,
            target_parallel_config=parallel_config,
        )

1626
        scheduler_config = SchedulerConfig(
1627
            runner_type=model_config.runner_type,
1628
1629
1630
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1631
            enable_chunked_prefill=self.enable_chunked_prefill,
1632
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1633
            is_multimodal_model=model_config.is_multimodal_model,
1634
            is_encoder_decoder=model_config.is_encoder_decoder,
1635
            policy=self.scheduling_policy,
1636
            scheduler_cls=self.scheduler_cls,
1637
1638
1639
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1640
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
1641
            async_scheduling=self.async_scheduling,
1642
            stream_interval=self.stream_interval,
1643
        )
1644

1645
1646
1647
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_dtype=self.lora_dtype,
1658
                enable_tower_connector_lora=self.enable_tower_connector_lora,
1659
1660
1661
1662
1663
1664
1665
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )
1666

1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
        if (
            lora_config is not None
            and speculative_config is not None
            and scheduler_config.max_num_batched_tokens
            < (
                scheduler_config.max_num_seqs
                * (speculative_config.num_speculative_tokens + 1)
            )
        ):
            raise ValueError(
                "Consider increasing max_num_batched_tokens or "
                "decreasing num_speculative_tokens"
            )

1681
1682
1683
1684
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1685
1686
1687
1688
1689
1690
1691
1692
        # Attention config overrides
        attention_config = copy.deepcopy(self.attention_config)
        if self.attention_backend is not None:
            if attention_config.backend is not None:
                raise ValueError(
                    "attention_backend and attention_config.backend "
                    "are mutually exclusive"
                )
1693
1694
1695
1696
1697
1698
1699
            # Convert string to enum if needed (CLI parsing returns a string)
            if isinstance(self.attention_backend, str):
                attention_config.backend = AttentionBackendEnum[
                    self.attention_backend.upper()
                ]
            else:
                attention_config.backend = self.attention_backend
1700

1701
        load_config = self.create_load_config()
1702

1703
1704
        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
1705
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser
1706

1707
1708
1709
1710
1711
        if self.reasoning_parser_plugin:
            self.structured_outputs_config.reasoning_parser_plugin = (
                self.reasoning_parser_plugin
            )

1712
        observability_config = ObservabilityConfig(
1713
            show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
1714
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1715
            collect_detailed_traces=self.collect_detailed_traces,
1716
1717
            kv_cache_metrics=self.kv_cache_metrics,
            kv_cache_metrics_sample=self.kv_cache_metrics_sample,
1718
            cudagraph_metrics=self.cudagraph_metrics,
1719
            enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
1720
            enable_mfu_metrics=self.enable_mfu_metrics,
1721
            enable_mm_processor_stats=self.enable_mm_processor_stats,
1722
            enable_logging_iteration_details=self.enable_logging_iteration_details,
1723
        )
1724

1725
        # Compilation config overrides
1726
        compilation_config = copy.deepcopy(self.compilation_config)
1727
        if self.cudagraph_capture_sizes is not None:
1728
            if compilation_config.cudagraph_capture_sizes is not None:
1729
1730
1731
1732
                raise ValueError(
                    "cudagraph_capture_sizes and compilation_config."
                    "cudagraph_capture_sizes are mutually exclusive"
                )
1733
            compilation_config.cudagraph_capture_sizes = self.cudagraph_capture_sizes
1734
        if self.max_cudagraph_capture_size is not None:
1735
            if compilation_config.max_cudagraph_capture_size is not None:
1736
1737
1738
1739
                raise ValueError(
                    "max_cudagraph_capture_size and compilation_config."
                    "max_cudagraph_capture_size are mutually exclusive"
                )
1740
            compilation_config.max_cudagraph_capture_size = (
1741
1742
                self.max_cudagraph_capture_size
            )
1743
        config = VllmConfig(
1744
1745
1746
1747
1748
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
1749
1750
            load_config=load_config,
            attention_config=attention_config,
1751
1752
            lora_config=lora_config,
            speculative_config=speculative_config,
1753
            structured_outputs_config=self.structured_outputs_config,
1754
            observability_config=observability_config,
1755
            compilation_config=compilation_config,
1756
            kv_transfer_config=self.kv_transfer_config,
1757
            kv_events_config=self.kv_events_config,
1758
            ec_transfer_config=self.ec_transfer_config,
1759
            profiler_config=self.profiler_config,
1760
            additional_config=self.additional_config,
1761
            optimization_level=self.optimization_level,
1762
        )
1763

1764
1765
        return config

1766
1767
    def _check_feature_supported(self, model_config: ModelConfig):
        """Reject engine arguments that request features that are unsupported.

        Raises:
            NotImplementedError: (via ``_raise_unsupported_error``) when a
                non-default ``--logits-processor-pattern`` is given, when either
                partial-prefill limit differs from its ``SchedulerConfig``
                default, or when pipeline parallelism is requested with an
                executor backend that is neither flagged ``supports_pp`` nor one
                of the known pipeline-capable backends.
        """
        # NOTE: `model_config` is accepted for interface compatibility; none of
        # the checks below consult it.
        if self.logits_processor_pattern != EngineArgs.logits_processor_pattern:
            _raise_unsupported_error(feature_name="--logits-processor-pattern")

        # Concurrent partial prefills are not supported yet, so both limits
        # must stay at their SchedulerConfig defaults.
        partial_prefill_overridden = (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        )
        if partial_prefill_overridden:
            _raise_unsupported_error(feature_name="Concurrent Partial Prefill")

        # Everything below only matters for pipeline parallelism.
        if self.pipeline_parallel_size <= 1:
            return

        backend = self.distributed_executor_backend
        pp_capable_backends = (
            ParallelConfig.distributed_executor_backend,
            "ray",
            "mp",
            "external_launcher",
        )
        # A custom executor class may advertise PP support via `supports_pp`.
        if getattr(backend, "supports_pp", False) or backend in pp_capable_backends:
            return

        _raise_unsupported_error(
            feature_name=(
                "Pipeline Parallelism without Ray distributed "
                "executor or multiprocessing executor or external "
                "launcher"
            )
        )
1795

1796
1797
1798
1799
1800
1801
    @classmethod
    def get_batch_defaults(
        cls,
        world_size: int,
    ) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]:
        """Compute platform-dependent default batch limits.

        Args:
            world_size: Worker count used only to scale the CPU defaults (the
                caller in this file passes ``pipeline_parallel_size *
                tensor_parallel_size``).

        Returns:
            ``(max_num_batched_tokens_defaults, max_num_seqs_defaults)``. Each is
            a dict keyed by ``UsageContext`` (``LLM_CLASS`` and
            ``OPENAI_API_SERVER``). Callers fall back to their own default for
            any context not present.
        """
        from vllm.usage.usage_lib import UsageContext

        llm_ctx = UsageContext.LLM_CLASS
        api_ctx = UsageContext.OPENAI_API_SERVER

        # The device query may fail when the platform importing vLLM differs
        # from the one vLLM runs on (e.g. scaling vLLM with Ray from a node
        # without GPUs). Zero memory / an empty name select the generic
        # (non-H100/H200) defaults below.
        try:
            total_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            total_memory = 0
            device_name = ""

        # NOTE(Kuntai): a large `max_num_batched_tokens` reduces throughput on
        # A100 (see PR #17885), so A100s are excluded from the large defaults.
        if total_memory >= 70 * GiB_bytes and "a100" not in device_name:
            # Large-memory GPUs such as H100 and MI300x.
            batched_tokens: dict[UsageContext | None, int] = {
                llm_ctx: 16384,
                api_ctx: 8192,
            }
            num_seqs: dict[UsageContext | None, int] = {
                llm_ctx: 1024,
                api_ctx: 1024,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            batched_tokens = {llm_ctx: 8192, api_ctx: 2048}
            num_seqs = {llm_ctx: 256, api_ctx: 256}

        if current_platform.is_tpu():
            # Per-chip overrides; on TPU only the batched-token limit changes.
            tpu_batched_tokens = {
                "V6E": {llm_ctx: 2048, api_ctx: 1024},
                "V5E": {llm_ctx: 1024, api_ctx: 512},
                "V5P": {llm_ctx: 512, api_ctx: 256},
            }.get(current_platform.get_device_name())
            if tpu_batched_tokens is not None:
                batched_tokens = tpu_batched_tokens

        if current_platform.is_cpu():
            # CPU defaults scale linearly with the number of workers.
            batched_tokens = {
                llm_ctx: 4096 * world_size,
                api_ctx: 2048 * world_size,
            }
            num_seqs = {
                llm_ctx: 256 * world_size,
                api_ctx: 128 * world_size,
            }

        return batched_tokens, num_seqs

1880
1881
    def _set_default_chunked_prefill_and_prefix_caching_args(
        self, model_config: ModelConfig
    ) -> None:
        """Resolve ``enable_chunked_prefill`` and ``enable_prefix_caching``.

        Flags left as ``None`` default to what the model reports it supports
        (``model_config.is_chunked_prefill_supported`` /
        ``model_config.is_prefix_caching_supported``). Explicit user choices are
        kept, but a warning is logged when they contradict the model's support:

        * generate runner, chunked prefill disabled although supported;
        * pooling runner, chunked prefill enabled although unsupported;
        * pooling runner, prefix caching enabled although unsupported.

        Finally, both features are forced off on POWER, s390x and RISC-V CPU
        platforms, regardless of the user's choice.
        """
        default_chunked_prefill = model_config.is_chunked_prefill_supported
        default_prefix_caching = model_config.is_prefix_caching_supported

        if self.enable_chunked_prefill is None:
            self.enable_chunked_prefill = default_chunked_prefill

            logger.debug(
                "%s chunked prefill by default",
                "Enabling" if default_chunked_prefill else "Disabling",
            )
        elif (
            model_config.runner_type == "generate"
            and not self.enable_chunked_prefill
            and default_chunked_prefill
        ):
            logger.warning_once(
                "This model does not officially support disabling chunked prefill. "
                "Disabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )
        elif (
            model_config.runner_type == "pooling"
            and self.enable_chunked_prefill
            and not default_chunked_prefill
        ):
            logger.warning_once(
                "This model does not officially support chunked prefill. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )

        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = default_prefix_caching

            logger.debug(
                "%s prefix caching by default",
                "Enabling" if default_prefix_caching else "Disabling",
            )
        elif (
            model_config.runner_type == "pooling"
            and self.enable_prefix_caching
            and not default_prefix_caching
        ):
            logger.warning_once(
                "This model does not officially support prefix caching. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )

        # Disable chunked prefill and prefix caching for:
        # POWER (ppc64le)/s390x/RISCV CPUs in V1.
        # ARM is not in this list, so the log messages must not name it.
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.S390X,
            CpuArchEnum.RISCV,
        ):
            logger.info(
                "Chunked prefill is not supported for POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_chunked_prefill = False
            logger.info(
                "Prefix caching is not supported for POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_prefix_caching = False

    def _set_default_max_num_seqs_and_batched_tokens_args(
        self,
        usage_context: UsageContext | None,
        model_config: ModelConfig,
    ):
        """Fill in ``max_num_batched_tokens`` / ``max_num_seqs`` when unset.

        Only attributes that are ``None`` on entry are assigned; explicit user
        values are never modified. Defaults come from ``get_batch_defaults``
        for ``usage_context`` (falling back to the ``SchedulerConfig``
        constants) and are then clamped:

        * ``max_num_batched_tokens`` is raised to at least ``max_model_len``
          when chunked prefill is off, then capped at
          ``max_num_seqs * max_model_len``;
        * ``max_num_seqs`` is capped at ``max_num_batched_tokens``.
        """
        batched_defaults, seqs_defaults = self.get_batch_defaults(
            self.pipeline_parallel_size * self.tensor_parallel_size
        )

        # Record which values the user left unset before any are filled in.
        batched_tokens_unset = self.max_num_batched_tokens is None
        num_seqs_unset = self.max_num_seqs is None

        # The sequence default is needed first: it bounds the token default.
        if num_seqs_unset:
            self.max_num_seqs = seqs_defaults.get(
                usage_context,
                SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
            )

        if batched_tokens_unset:
            max_model_len = model_config.max_model_len
            batched_tokens = batched_defaults.get(
                usage_context,
                SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )
            if not self.enable_chunked_prefill:
                # Without chunked prefill, never default below max_model_len;
                # a shorter max_model_len keeps the larger default for
                # throughput.
                batched_tokens = max(max_model_len, batched_tokens)
            # Never exceed max_num_seqs * max_model_len: some models
            # (e.g., Whisper) have embeddings tied to the maximum length.
            self.max_num_batched_tokens = min(
                self.max_num_seqs * max_model_len,
                batched_tokens,
            )
            logger.debug(
                "Defaulting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                usage_context.value if usage_context else None,
            )

        if num_seqs_unset:
            assert self.max_num_batched_tokens is not None  # For type checking
            self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
            logger.debug(
                "Defaulting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                usage_context.value if usage_context else None,
            )
2012

2013

2014
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments for the asynchronous vLLM engine.

    Extends ``EngineArgs`` with options that are specific to the async engine.
    """

    enable_log_requests: bool = False

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
        """Register the async-engine command-line options.

        Args:
            parser: Parser to extend.
            async_args_only: If true, only the async-specific options are
                added. Otherwise all ``EngineArgs`` options are registered first,
                via ``EngineArgs.add_cli_args``.

        Returns:
            The parser carrying the registered options. This is the object
            returned by ``EngineArgs.add_cli_args`` when it was called.
        """
        # Plugins must be loaded before the parser is built, because a plugin
        # may add choices to existing options, e.g. a new quantization method
        # for --quantization or a new device for --device.
        load_general_plugins()

        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        log_requests_default = AsyncEngineArgs.enable_log_requests
        parser.add_argument(
            "--enable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=log_requests_default,
            help="Enable logging requests.",
        )
        # Deprecated inverse flag, kept for backward compatibility.
        parser.add_argument(
            "--disable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=not log_requests_default,
            help="[DEPRECATED] Disable logging requests.",
            deprecated=True,
        )

        current_platform.pre_register_and_update(parser)
        return parser
2045
2046


2047
2048
2049
2050
2051
2052
def _raise_unsupported_error(feature_name: str):
    msg = (
        f"{feature_name} is not supported. We recommend to "
        f"remove {feature_name} from your config."
    )
    raise NotImplementedError(msg)
2053
2054


2055
def human_readable_int(value: str) -> int:
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.

    Lowercase suffixes (k, m, g, t) are decimal multipliers (powers of 1000)
    and accept fractional numbers; any fractional remainder of the product is
    truncated. Uppercase suffixes (K, M, G, T) are binary multipliers (powers
    of 1024) and accept whole numbers only. Surrounding whitespace is ignored,
    and a value without a suffix is parsed with ``int()``.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600

    Raises:
        argparse.ArgumentTypeError: If a decimal number is combined with a
            binary suffix (e.g. '1.5K').
        ValueError: If the value is neither a suffixed number nor a plain
            integer string.
    """
    from fractions import Fraction

    value = value.strip()

    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
    if match:
        decimal_multiplier = {
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
            "t": 10**12,
        }
        binary_multiplier = {
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
            "T": 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            # Use exact rational arithmetic. With floats, the product can land
            # just below the intended integer: 1.001 * 1000 evaluates to
            # 1000.9999999999999, which int() would truncate to 1000.
            return int(Fraction(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?"
                ) from e

    # Regular plain number.
    return int(value)


def human_readable_int_or_auto(value: str) -> int:
    """Parse a human-readable integer, also accepting an auto sentinel.

    After stripping surrounding whitespace, the strings '-1' and 'auto'
    (case-insensitive) both map to -1, the special value for auto-detection.
    Any other input is delegated to ``human_readable_int``.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '-1' or 'auto' -> -1 (special value for auto-detection)
    """
    stripped = value.strip()
    wants_auto = stripped == "-1" or stripped.lower() == "auto"
    return -1 if wants_auto else human_readable_int(stripped)