arg_utils.py 88.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5

6
import argparse
7
import copy
8
import dataclasses
9
import functools
10
import json
11
import sys
12
from collections.abc import Callable
13
from dataclasses import MISSING, dataclass, fields, is_dataclass
14
from itertools import permutations
15
from types import UnionType
16
17
18
19
20
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Literal,
21
    TypeAlias,
22
23
24
25
26
27
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
28

29
import huggingface_hub
30
import regex as re
31
import torch
32
from pydantic import TypeAdapter, ValidationError
33
from pydantic.fields import FieldInfo
34
from typing_extensions import TypeIs
35

36
import vllm.envs as envs
37
from vllm.config import (
38
    AttentionConfig,
39
40
41
42
    CacheConfig,
    CompilationConfig,
    ConfigType,
    DeviceConfig,
43
    ECTransferConfig,
44
45
46
47
48
49
    EPLBConfig,
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LoRAConfig,
    ModelConfig,
50
    MultiModalConfig,
51
52
53
    ObservabilityConfig,
    ParallelConfig,
    PoolerConfig,
54
    ProfilerConfig,
55
56
57
58
59
60
    SchedulerConfig,
    SpeculativeConfig,
    StructuredOutputsConfig,
    VllmConfig,
    get_attr_docs,
)
61
62
63
64
from vllm.config.cache import (
    BlockSize,
    CacheDType,
    KVOffloadingBackend,
65
    MambaCacheMode,
66
67
68
    MambaDType,
    PrefixCachingHashAlgo,
)
69
70
71
72
73
74
75
76
77
78
79
80
81
from vllm.config.device import Device
from vllm.config.model import (
    ConvertOption,
    HfOverrides,
    LogprobsMode,
    ModelDType,
    RunnerOption,
    TokenizerMode,
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.scheduler import SchedulerPolicy
82
from vllm.config.utils import get_field
83
from vllm.config.vllm import OptimizationLevel
84
from vllm.logger import init_logger, suppress_logging
85
from vllm.platforms import CpuArchEnum, current_platform
86
from vllm.plugins import load_general_plugins
87
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
88
89
90
91
from vllm.transformers_utils.config import (
    is_interleaved,
    maybe_override_with_speculators,
)
92
from vllm.transformers_utils.gguf_utils import is_gguf
93
from vllm.transformers_utils.repo_utils import get_model_path
94
from vllm.transformers_utils.utils import is_cloud_storage
95
from vllm.utils.argparse_utils import FlexibleArgumentParser
96
from vllm.utils.mem_constants import GiB_bytes
97
from vllm.utils.network_utils import get_ip
98
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
99
from vllm.v1.attention.backends.registry import AttentionBackendEnum
100
from vllm.v1.sample.logits_processor import LogitsProcessor
101

102
103
if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
104
    from vllm.model_executor.model_loader import LoadFormats
105
    from vllm.usage.usage_lib import UsageContext
106
    from vllm.v1.executor import Executor
107
else:
108
    Executor = Any
109
    QuantizationMethods = Any
110
    LoadFormats = Any
111
112
    UsageContext = Any

113

114
115
logger = init_logger(__name__)

# Type variable shared by the generic argument-parsing helpers.
T = TypeVar("T")

# `object` is part of these aliases so that special typing forms (for example
# `Literal[...]`, `Annotated[...]` or `X | None`) are accepted as type hints.
TypeHint: TypeAlias = type[Any] | object
TypeHintT: TypeAlias = type[T] | object
120

121

122
123
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap `return_type` so conversion failures become argparse errors.

    The returned callable passes its string argument to `return_type`. A
    `ValueError` raised during conversion is re-raised as
    `argparse.ArgumentTypeError`, chained to the original exception.
    """

    def _parse_type(val: str) -> T:
        try:
            converted = return_type(val)
        except ValueError as err:
            message = f"Value {val} cannot be converted to {return_type}."
            raise argparse.ArgumentTypeError(message) from err
        return converted

    return _parse_type


134
135
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], T | None]:
    """Like `parse_type`, but the strings `""` and `"None"` parse to `None`.

    Any other value is converted with `parse_type(return_type)`, so conversion
    failures surface as `argparse.ArgumentTypeError`.
    """

    def _optional_type(val: str) -> T | None:
        if val not in ("", "None"):
            return parse_type(return_type)(val)
        return None

    return _optional_type
141
142


143
def union_dict_and_str(val: str) -> str | dict[str, str] | None:
    """Decode `val` as JSON if it looks like a JSON object, else keep the text.

    A value wrapped in braces (surrounding whitespace allowed, newlines
    permitted inside) is decoded with `json.loads` via `optional_type`. Any
    other value is returned unchanged as a string.
    """
    looks_like_object = re.match(r"(?s)^\s*{.*}\s*$", val) is not None
    if looks_like_object:
        return optional_type(json.loads)(val)
    return str(val)
147
148


149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Check if the type hint is a specific type."""
    return type_hint is type or get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return True if any hint in `type_hints` is, or parametrizes, `type`."""
    for type_hint in type_hints:
        if is_type(type_hint, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return the first hint in `type_hints` matching `type`, or None.

    A hint matches when it is `type` itself or a parametrization of it (e.g.
    `list[int]` for `list`). `type_hints` is a set, so if several hints match,
    which one is returned follows set iteration order.
    """
    for type_hint in type_hints:
        if is_type(type_hint, type):
            return type_hint
    return None


164
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Get the `type` and `choices` from a `Literal` type hint in `type_hints`.

    If `type_hints` also contains `str`, we use `metavar` instead of `choices`,
    so values outside the literal options are still accepted.

    Raises:
        ValueError: If the `Literal` options are not all of the same type.
    """
    options = get_args(get_type(type_hints, Literal))
    option_type = type(options[0])
    if any(not isinstance(option, option_type) for option in options):
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}"
        )

    key = "choices"
    if contains_type(type_hints, str):
        key = "metavar"
    return {"type": option_type, key: sorted(options)}
179
180


181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
def collection_to_kwargs(type_hints: set[TypeHint], type: TypeHint) -> dict[str, Any]:
    """Get the argparse `type` and `nargs` for a collection type hint.

    The hint in `type_hints` that matches `type` (`list`, `tuple` or `set`) is
    inspected. Every non-`Ellipsis` element type must be identical. If the
    element type is a union, it must include `str`, and `str` is then used as
    the argparse `type`.

    Returns:
        `{"type": elem_type, "nargs": ...}` where `nargs` is `"+"` for
        variable-length collections (including `tuple[X, ...]`), or the
        number of elements for a fixed-length `tuple`.

    Raises:
        ValueError: If the hint has no element types, the element types
            differ, or a union element type does not include `str`.
    """
    type_hint = get_type(type_hints, type)
    types = get_args(type_hint)
    # A bare `list`/`tuple` (or `tuple[()]`) carries no element type.
    if not types:
        raise ValueError(
            f"Collection type hint {type_hint} must specify its element type(s)."
        )
    elem_type = types[0]

    # Handle Ellipsis: `tuple[X, ...]` is allowed, but every real element must
    # be `X`. Raise instead of `assert` so the check survives `python -O`.
    if any(t is not elem_type for t in types if t is not Ellipsis):
        raise ValueError(
            f"All non-Ellipsis elements must be of the same type. Got {types}."
        )

    # Handle Union types
    if get_origin(elem_type) in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
        if str not in get_args(elem_type):
            raise ValueError(
                "If element can have multiple types, one must be 'str' "
                f"(i.e. 'list[int | str]'). Got {elem_type}."
            )
        elem_type = str

    nargs = "+" if type is not tuple or Ellipsis in types else len(types)
    return {"type": elem_type, "nargs": nargs}


206
207
208
209
210
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True if `type_hint` is defined outside the `builtins` module."""
    module_name = type_hint.__module__
    return module_name != "builtins"


211
212
213
214
215
216
217
218
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints."""
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
219
220
    elif origin in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
221
222
223
224
225
226
227
228
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


229
# True when help text will be rendered, i.e. attribute docs are worth loading.
NEEDS_HELP = (
    # `vllm SUBCOMMAND --help`; this is a substring test, so any argument
    # containing "--help" counts.
    any("--help" in arg for arg in sys.argv)
    # `mkdocs SUBCOMMAND` or `python -m mkdocs SUBCOMMAND`
    or (argv0 := sys.argv[0]).endswith(("mkdocs", "mkdocs/__main__.py"))
)


236
def _field_default(field: dataclasses.Field) -> Any:
    """Return the argparse default for the dataclass field `field`.

    An explicit `default` is used first. A `pydantic.Field(...)` default,
    which appears as a `FieldInfo` object, is unwrapped. Otherwise the
    dataclass `default_factory` is called. A field with neither a default nor
    a factory yields `None`.
    """
    if field.default is not MISSING:
        default = field.default
        # Handle pydantic.Field defaults
        if not isinstance(default, FieldInfo):
            return default
        if default.default_factory is None:
            return default.default
        # VllmConfig's Fields have default_factory set to config classes.
        # These could emit logs on init, which would be confusing.
        with suppress_logging():
            return default.default_factory()
    if field.default_factory is not MISSING:
        return field.default_factory()
    # Required field: there is no default to report, so use None rather than
    # leaving the default undefined (or reusing another field's default).
    return None


@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Compute argparse `add_argument` kwargs for every field of `cls`.

    Returns a mapping from field name to a kwargs dict that always contains
    `default` and `help`. Depending on the field's type hint, the dict also
    contains some of `type`, `action`, `choices`, `metavar` and `nargs`.

    The result is cached by `functools.lru_cache` and must not be mutated;
    `get_kwargs` returns a deep copy for callers.

    Raises:
        ValueError: If a field's type hint is not supported.
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}
    json_tip = (
        "Should either be a valid JSON string or JSON keys passed individually."
    )

    kwargs: dict[str, dict[str, Any]] = {}
    for field in fields(cls):
        name = field.name
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, its value is parsed from a JSON string
        dataclass_cls = next((th for th in type_hints if is_dataclass(th)), None)

        # Escape % for argparse
        help_text = cls_docs.get(name, "").strip().replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        field_kwargs: dict[str, Any] = {
            "default": _field_default(field),
            "help": help_text,
        }
        kwargs[name] = field_kwargs

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:
            # `dc_cls` is bound as a default argument so each closure keeps
            # its own dataclass instead of the loop's last value.
            def parse_dataclass(val: str, dc_cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(dc_cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            field_kwargs["type"] = parse_dataclass
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            field_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            field_kwargs.update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            field_kwargs.update(collection_to_kwargs(type_hints, tuple))
        elif contains_type(type_hints, list):
            field_kwargs.update(collection_to_kwargs(type_hints, list))
        elif contains_type(type_hints, set):
            field_kwargs.update(collection_to_kwargs(type_hints, set))
        elif contains_type(type_hints, int):
            if name == "max_model_len":
                field_kwargs["type"] = human_readable_int_or_auto
                field_kwargs["help"] += f"\n\n{human_readable_int_or_auto.__doc__}"
            elif name in ("max_num_batched_tokens", "kv_cache_memory_bytes"):
                field_kwargs["type"] = human_readable_int
                field_kwargs["help"] += f"\n\n{human_readable_int.__doc__}"
            else:
                field_kwargs["type"] = int
        elif contains_type(type_hints, float):
            field_kwargs["type"] = float
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
            field_kwargs["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            field_kwargs["type"] = parse_type(json.loads)
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
            field_kwargs["type"] = str
        else:
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(field_kwargs.get("type")) is Literal:
            field_kwargs.update(literal_to_kwargs({field_kwargs["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            field_kwargs["type"] = optional_type(field_kwargs["type"])
            if field_kwargs.get("choices"):
                field_kwargs["choices"].append("None")
    return kwargs
336
337


338
def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Return argparse kwargs for the given Config dataclass.

    Attribute documentation is only loaded into the help text when the
    command line contains `--help` or the program is `mkdocs`; otherwise it
    is skipped to save time.

    The expensive work is memoized by `_compute_kwargs`. This wrapper returns
    a deep copy, so callers may mutate the result without corrupting the
    cached value.
    """
    cached_kwargs = _compute_kwargs(cls)
    return copy.deepcopy(cached_kwargs)
zhuwenwen's avatar
zhuwenwen committed
349
   
350

351
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
352
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
353
    """Arguments for vLLM engine."""
354

355
    model: str = ModelConfig.model
356
    enable_return_routed_experts: bool = ModelConfig.enable_return_routed_experts
357
    model_weights: str = ModelConfig.model_weights
358
359
360
    served_model_name: str | list[str] | None = ModelConfig.served_model_name
    tokenizer: str | None = ModelConfig.tokenizer
    hf_config_path: str | None = ModelConfig.hf_config_path
361
362
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
363
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
364
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
365
    tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
366
367
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
368
369
    allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains
    download_dir: str | None = LoadConfig.download_dir
370
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
371
    load_format: str | LoadFormats = LoadConfig.load_format
372
373
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
374
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
375
    seed: int = ModelConfig.seed
376
    max_model_len: int | None = ModelConfig.max_model_len
377
378
379
380
381
382
    cudagraph_capture_sizes: list[int] | None = (
        CompilationConfig.cudagraph_capture_sizes
    )
    max_cudagraph_capture_size: int | None = get_field(
        CompilationConfig, "max_cudagraph_capture_size"
    )
383
384
385
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
386
    distributed_executor_backend: (
387
        str | DistributedExecutorBackend | type[Executor] | None
388
    ) = ParallelConfig.distributed_executor_backend
389
    # number of P/D disaggregation (or other disaggregation) workers
390
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
391
392
393
394
    master_addr: str = ParallelConfig.master_addr
    master_port: int = ParallelConfig.master_port
    nnodes: int = ParallelConfig.nnodes
    node_rank: int = ParallelConfig.node_rank
395
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
396
    prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
397
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
398
    dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
399
    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
400
    data_parallel_size: int = ParallelConfig.data_parallel_size
401
402
403
404
405
    data_parallel_rank: int | None = None
    data_parallel_start_rank: int | None = None
    data_parallel_size_local: int | None = None
    data_parallel_address: str | None = None
    data_parallel_rpc_port: int | None = None
406
    data_parallel_hybrid_lb: bool = False
407
    data_parallel_external_lb: bool = False
Rui Qiao's avatar
Rui Qiao committed
408
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
409
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
410
    all2all_backend: str = ParallelConfig.all2all_backend
411
    enable_dbo: bool = ParallelConfig.enable_dbo
412
    ubatch_size: int = ParallelConfig.ubatch_size
413
414
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
415
    disable_nccl_for_dp_synchronization: bool | None = (
416
417
        ParallelConfig.disable_nccl_for_dp_synchronization
    )
418
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
419
    enable_eplb: bool = ParallelConfig.enable_eplb
420
    expert_placement_strategy: ExpertPlacementStrategy = (
421
        ParallelConfig.expert_placement_strategy
422
    )
423
424
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
425
    max_parallel_loading_workers: int | None = (
426
427
        ParallelConfig.max_parallel_loading_workers
    )
428
    block_size: BlockSize | None = CacheConfig.block_size
429
    enable_prefix_caching: bool | None = None
430
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
431
        CacheConfig.prefix_caching_hash_algo
432
    )
433
434
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
435
436
437
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
438
    kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
439
    max_num_batched_tokens: int | None = None
440
441
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
442
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
443
    max_num_seqs: int | None = None
444
    max_logprobs: int = ModelConfig.max_logprobs
445
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
446
    disable_log_stats: bool = False
447
    aggregate_engine_logging: bool = False
448
449
450
    revision: str | None = ModelConfig.revision
    code_revision: str | None = ModelConfig.code_revision
    hf_token: bool | str | None = ModelConfig.hf_token
451
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
452
453
    tokenizer_revision: str | None = ModelConfig.tokenizer_revision
    quantization: QuantizationMethods | None = ModelConfig.quantization
454
    allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
455
    enforce_eager: bool = ModelConfig.enforce_eager
456
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
457
    limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
458
459
        MultiModalConfig, "limit_per_prompt"
    )
460
    enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
461
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
462
463
464
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
465
    mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
466
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
467
    mm_processor_cache_type: MMCacheType | None = (
468
        MultiModalConfig.mm_processor_cache_type
469
470
    )
    mm_shm_cache_max_object_size_mb: int = (
471
        MultiModalConfig.mm_shm_cache_max_object_size_mb
472
    )
473
    mm_encoder_only: bool = MultiModalConfig.mm_encoder_only
474
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
475
    mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
476
477
        MultiModalConfig.mm_encoder_attn_backend
    )
478
    io_processor_plugin: str | None = None
479
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
480
    video_pruning_rate: float = MultiModalConfig.video_pruning_rate
481
    # LoRA fields
482
    enable_lora: bool = False
483
484
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
485
    default_mm_loras: dict[str, str] | None = LoRAConfig.default_mm_loras
486
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
487
488
    max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
    lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
489
    enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
490

491
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
492
    num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
493
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
494
    ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
495

496
    enable_chunked_prefill: bool | None = None
497
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
498

499
    disable_hybrid_kv_cache_manager: bool | None = (
500
501
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
502

503
    structured_outputs_config: StructuredOutputsConfig = get_field(
504
505
        VllmConfig, "structured_outputs_config"
    )
506
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
507
    reasoning_parser_plugin: str | None = None
508

509
    logits_processor_pattern: str | None = ModelConfig.logits_processor_pattern
510

511
    speculative_config: dict[str, Any] | None = None
512

513
    show_hidden_metrics_for_version: str | None = (
514
        ObservabilityConfig.show_hidden_metrics_for_version
515
    )
516
517
    otlp_traces_endpoint: str | None = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: list[DetailedTraceModules] | None = (
518
        ObservabilityConfig.collect_detailed_traces
519
    )
520
521
522
523
    kv_cache_metrics: bool = ObservabilityConfig.kv_cache_metrics
    kv_cache_metrics_sample: float = get_field(
        ObservabilityConfig, "kv_cache_metrics_sample"
    )
524
    cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
525
526
527
    enable_layerwise_nvtx_tracing: bool = (
        ObservabilityConfig.enable_layerwise_nvtx_tracing
    )
528
    enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
529
530
531
    enable_logging_iteration_details: bool = (
        ObservabilityConfig.enable_logging_iteration_details
    )
532
    enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
533
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
534
    scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
535

536
    pooler_config: PoolerConfig | None = ModelConfig.pooler_config
537
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
538
    attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
539
540
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
541

542
543
    profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")

544
545
    kv_transfer_config: KVTransferConfig | None = None
    kv_events_config: KVEventsConfig | None = None
546

547
    ec_transfer_config: ECTransferConfig | None = None
548

549
550
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
551
552
553
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
554
    model_impl: str = ModelConfig.model_impl
555
    override_attention_dtype: str = ModelConfig.override_attention_dtype
556
    attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
557

558
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
559
560
561
    kv_cache_dtype_skip_layers: list[str] = get_field(
        CacheConfig, "kv_cache_dtype_skip_layers"
    )
562
563
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
564
    mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
565
    mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
566

567
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
568

569
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
570
    pt_load_map_location: str = LoadConfig.pt_load_map_location
王敏's avatar
王敏 committed
571

572
    logits_processors: list[str | type[LogitsProcessor]] | None = (
573
574
        ModelConfig.logits_processors
    )
575
576
    """Custom logitproc types"""

577
    async_scheduling: bool | None = SchedulerConfig.async_scheduling
578

579
580
    stream_interval: int = SchedulerConfig.stream_interval

581
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
582
    optimization_level: OptimizationLevel = VllmConfig.optimization_level
583

584
    kv_offloading_size: float | None = CacheConfig.kv_offloading_size
585
    kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
586
    tokens_only: bool = False
587

588
    def __post_init__(self):
        """Normalize nested config arguments and resolve offline model paths.

        - Dict values for `compilation_config`, `attention_config` and
          `eplb_config` are converted into their config classes.
        - General plugins are loaded.
        - When `huggingface_hub.constants.HF_HUB_OFFLINE` is set, `model` (and
          `tokenizer`, if not None) are replaced with the result of
          `get_model_path` for the matching revision. Each change is logged.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(**self.compilation_config)
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)

        # Setup plugins
        # NOTE(review): also imported at module level; the call-time import
        # is kept as in the original.
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        # In HF offline mode, replace the model and tokenizer IDs with local
        # model paths.
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
            # Compare by value: an equal but newly built string is no change.
            if model_id != self.model:
                logger.info(
                    "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
                    model_id,
                    self.model,
                )
            if self.tokenizer is not None:
                tokenizer_id = self.tokenizer
                self.tokenizer = get_model_path(self.tokenizer, self.tokenizer_revision)
                if tokenizer_id != self.tokenizer:
                    logger.info(
                        "HF_HUB_OFFLINE is True, replace tokenizer_id [%s] "
                        "to tokenizer_path [%s]",
                        tokenizer_id,
                        self.tokenizer,
                    )
622
623

    @staticmethod
624
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
625
        """Shared CLI arguments for vLLM engine."""
626

627
        # Model arguments
628
629
630
631
632
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
633
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
634
            model_group.add_argument("--model", **model_kwargs["model"])
635
636
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
637
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
638
639
640
641
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
642
643
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
644
645
646
647
648
649
650
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
651
        model_group.add_argument("--revision", **model_kwargs["revision"])
652
653
654
655
656
657
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
658
659
660
661
        model_group.add_argument(
            "--allow-deprecated-quantization",
            **model_kwargs["allow_deprecated_quantization"],
        )
662
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
663
664
665
666
        model_group.add_argument(
            "--enable-return-routed-experts",
            **model_kwargs["enable_return_routed_experts"],
        )
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
685
686
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
        )
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
709
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
710
711
712
713
714
715
716
717
718
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
        )
719

720
721
722
723
724
725
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
726
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
727
728
729
730
731
732
733
734
735
736
737
738
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )
739

740
741
742
743
744
745
746
747
748
749
        # Attention arguments
        attention_kwargs = get_kwargs(AttentionConfig)
        attention_group = parser.add_argument_group(
            title="AttentionConfig",
            description=AttentionConfig.__doc__,
        )
        attention_group.add_argument(
            "--attention-backend", **attention_kwargs["backend"]
        )

750
751
752
753
754
        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
755
        )
756
        structured_outputs_group.add_argument(
757
            "--reasoning-parser",
758
            # Choices need to be validated after parsing to include plugins
759
760
            **structured_outputs_kwargs["reasoning_parser"],
        )
761
762
763
764
        structured_outputs_group.add_argument(
            "--reasoning-parser-plugin",
            **structured_outputs_kwargs["reasoning_parser_plugin"],
        )
765

766
        # Parallel arguments
767
768
769
770
771
772
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
773
            "--distributed-executor-backend",
774
775
            **parallel_kwargs["distributed_executor_backend"],
        )
776
        parallel_group.add_argument(
777
778
779
780
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
781
782
783
784
        parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"])
        parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
        parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
        parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
785
        parallel_group.add_argument(
786
787
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
788
        parallel_group.add_argument(
789
790
791
792
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
793
        parallel_group.add_argument(
794
795
796
            "--dcp-kv-cache-interleave-size",
            **parallel_kwargs["dcp_kv_cache_interleave_size"],
        )
797
        parallel_group.add_argument(
798
799
800
            "--cp-kv-cache-interleave-size",
            **parallel_kwargs["cp_kv_cache_interleave_size"],
        )
801
        parallel_group.add_argument(
802
803
804
805
            "--prefill-context-parallel-size",
            "-pcp",
            **parallel_kwargs["prefill_context_parallel_size"],
        )
806
        parallel_group.add_argument(
807
808
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
809
        parallel_group.add_argument(
810
811
            "--data-parallel-rank",
            "-dpn",
812
            type=int,
813
814
815
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
816
        parallel_group.add_argument(
817
818
            "--data-parallel-start-rank",
            "-dpr",
819
            type=int,
820
821
            help="Starting data parallel rank for secondary nodes.",
        )
822
        parallel_group.add_argument(
823
824
            "--data-parallel-size-local",
            "-dpl",
825
            type=int,
826
827
            help="Number of data parallel replicas to run on this node.",
        )
828
        parallel_group.add_argument(
829
830
831
832
833
834
835
836
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
837
            type=int,
838
839
            help="Port for data parallel RPC communication.",
        )
840
        parallel_group.add_argument(
841
842
843
844
845
846
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
847
        parallel_group.add_argument(
848
849
850
851
852
853
854
855
            "--data-parallel-hybrid-lb",
            "-dph",
            **parallel_kwargs["data_parallel_hybrid_lb"],
        )
        parallel_group.add_argument(
            "--data-parallel-external-lb",
            "-dpe",
            **parallel_kwargs["data_parallel_external_lb"],
856
857
        )
        parallel_group.add_argument(
858
859
860
            "--enable-expert-parallel",
            "-ep",
            **parallel_kwargs["enable_expert_parallel"],
861
        )
862
863
864
        parallel_group.add_argument(
            "--all2all-backend", **parallel_kwargs["all2all_backend"]
        )
865
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
866
867
868
869
        parallel_group.add_argument(
            "--ubatch-size",
            **parallel_kwargs["ubatch_size"],
        )
870
871
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
872
873
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
874
875
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
876
877
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
878
879
880
881
        parallel_group.add_argument(
            "--disable-nccl-for-dp-synchronization",
            **parallel_kwargs["disable_nccl_for_dp_synchronization"],
        )
882
883
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
884
885
        parallel_group.add_argument(
            "--expert-placement-strategy",
886
887
            **parallel_kwargs["expert_placement_strategy"],
        )
888

889
        parallel_group.add_argument(
890
            "--max-parallel-loading-workers",
891
892
            **parallel_kwargs["max_parallel_loading_workers"],
        )
893
        parallel_group.add_argument(
894
895
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
896
        parallel_group.add_argument(
897
            "--disable-custom-all-reduce",
898
899
900
901
902
903
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )
904

905
906
907
908
909
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
910
        )
911
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
912
913
914
915
916
917
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
918
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
919
920
921
922
923
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        cache_group.add_argument(
924
925
926
927
928
            "--enable-prefix-caching",
            **{
                **cache_kwargs["enable_prefix_caching"],
                "default": None,
            },
929
930
931
932
933
934
935
936
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
937
938
939
        cache_group.add_argument(
            "--kv-cache-dtype-skip-layers", **cache_kwargs["kv_cache_dtype_skip_layers"]
        )
940
941
942
943
944
945
946
947
948
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
949
950
951
        cache_group.add_argument(
            "--mamba-block-size", **cache_kwargs["mamba_block_size"]
        )
952
953
954
        cache_group.add_argument(
            "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
        )
955
956
957
958
959
960
        cache_group.add_argument(
            "--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
        )
        cache_group.add_argument(
            "--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
        )
961

962
        # Multimodal related configs
963
964
965
966
967
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
968
        multimodal_group.add_argument(
969
970
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
971
972
973
        multimodal_group.add_argument(
            "--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
        )
974
975
976
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
977
        multimodal_group.add_argument(
978
979
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
980
        multimodal_group.add_argument(
981
982
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
983
        multimodal_group.add_argument(
984
985
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
986
987
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
988
989
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
990
991
992
        multimodal_group.add_argument(
            "--mm-encoder-only", **multimodal_kwargs["mm_encoder_only"]
        )
993
        multimodal_group.add_argument(
994
995
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
996
        multimodal_group.add_argument(
997
998
999
            "--mm-encoder-attn-backend",
            **multimodal_kwargs["mm_encoder_attn_backend"],
        )
1000
1001
1002
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
1003
        multimodal_group.add_argument(
1004
1005
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
1006

1007
        multimodal_group.add_argument(
1008
1009
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )
1010

1011
        # LoRA related configs
1012
1013
1014
1015
1016
1017
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
1018
            "--enable-lora",
1019
            action=argparse.BooleanOptionalAction,
1020
1021
            help="If True, enable handling of LoRA adapters.",
        )
1022
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
1023
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
1024
        lora_group.add_argument(
1025
            "--lora-dtype",
1026
1027
            **lora_kwargs["lora_dtype"],
        )
1028
1029
1030
1031
        lora_group.add_argument(
            "--enable-tower-connector-lora",
            **lora_kwargs["enable_tower_connector_lora"],
        )
1032
1033
1034
1035
1036
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
1037

1038
1039
1040
1041
1042
1043
1044
1045
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
1046
1047
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
1048
        observability_group.add_argument(
1049
1050
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
1051
1052
1053
1054
1055
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
1056
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
1057
1058
1059
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
1060
1061
            **observability_kwargs["collect_detailed_traces"],
        )
1062
1063
1064
1065
1066
1067
1068
        observability_group.add_argument(
            "--kv-cache-metrics", **observability_kwargs["kv_cache_metrics"]
        )
        observability_group.add_argument(
            "--kv-cache-metrics-sample",
            **observability_kwargs["kv_cache_metrics_sample"],
        )
1069
1070
1071
1072
        observability_group.add_argument(
            "--cudagraph-metrics",
            **observability_kwargs["cudagraph_metrics"],
        )
1073
1074
1075
1076
        observability_group.add_argument(
            "--enable-layerwise-nvtx-tracing",
            **observability_kwargs["enable_layerwise_nvtx_tracing"],
        )
1077
1078
1079
1080
        observability_group.add_argument(
            "--enable-mfu-metrics",
            **observability_kwargs["enable_mfu_metrics"],
        )
1081
1082
1083
1084
        observability_group.add_argument(
            "--enable-logging-iteration-details",
            **observability_kwargs["enable_logging_iteration_details"],
        )
1085

1086
1087
1088
1089
1090
1091
1092
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
1093
            "--max-num-batched-tokens",
1094
1095
1096
1097
            **{
                **scheduler_kwargs["max_num_batched_tokens"],
                "default": None,
            },
1098
        )
1099
        scheduler_group.add_argument(
1100
1101
1102
1103
1104
            "--max-num-seqs",
            **{
                **scheduler_kwargs["max_num_seqs"],
                "default": None,
            },
1105
1106
1107
1108
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
1109
1110
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
1111
1112
            **scheduler_kwargs["max_long_partial_prefills"],
        )
1113
1114
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
1115
1116
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
1117
1118
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
1119
        scheduler_group.add_argument(
1120
1121
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
1122
        scheduler_group.add_argument(
1123
            "--enable-chunked-prefill",
1124
1125
1126
1127
            **{
                **scheduler_kwargs["enable_chunked_prefill"],
                "default": None,
            },
1128
        )
1129
        scheduler_group.add_argument(
1130
1131
1132
1133
1134
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
1135
1136
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
1137
1138
1139
1140
1141
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
1142
1143
1144
        scheduler_group.add_argument(
            "--stream-interval", **scheduler_kwargs["stream_interval"]
        )
1145

1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
        # Compilation arguments
        compilation_kwargs = get_kwargs(CompilationConfig)
        compilation_group = parser.add_argument_group(
            title="CompilationConfig",
            description=CompilationConfig.__doc__,
        )
        compilation_group.add_argument(
            "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
        )
        compilation_group.add_argument(
            "--max-cudagraph-capture-size",
            **compilation_kwargs["max_cudagraph_capture_size"],
        )
1159
1160

        # vLLM arguments
1161
        vllm_kwargs = get_kwargs(VllmConfig)
1162
1163
1164
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
1165
        )
1166
1167
1168
1169
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
1170
1171
1172
1173
1174
1175
1176
        vllm_group.add_argument(
            "--speculative-config", **vllm_kwargs["speculative_config"]
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
1177
1178
1179
        vllm_group.add_argument(
            "--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
        )
1180
        vllm_group.add_argument(
1181
            "--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
1182
        )
1183
1184
1185
        vllm_group.add_argument(
            "--attention-config", "-ac", **vllm_kwargs["attention_config"]
        )
1186
1187
1188
1189
1190
1191
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
1192
        vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
1193
1194
1195
        vllm_group.add_argument(
            "--optimization-level", **vllm_kwargs["optimization_level"]
        )
1196

1197
        # Other arguments
1198
1199
1200
1201
1202
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
1203

1204
1205
1206
1207
1208
1209
        parser.add_argument(
            "--aggregate-engine-logging",
            action="store_true",
            help="Log aggregate rather than per-engine statistics "
            "when using data parallelism.",
        )
1210
        return parser
1211
1212

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance from parsed command-line arguments.

        Every dataclass field of ``cls`` that also exists as an attribute on
        ``args`` is forwarded to the constructor. Fields absent from the
        namespace fall back to their dataclass defaults, and namespace
        attributes that are not fields of ``cls`` are ignored.
        """
        init_kwargs = {}
        for spec in dataclasses.fields(cls):
            if hasattr(args, spec.name):
                init_kwargs[spec.name] = getattr(args, spec.name)
        return cls(**init_kwargs)

    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` described by these engine arguments.

        Side effects on ``self``: when ``is_gguf(self.model)`` is true, both
        ``self.quantization`` and ``self.load_format`` are overwritten with
        ``"gguf"`` before the config is built.

        A warning about the global random seed is logged when
        ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` is disabled.
        """
        if is_gguf(self.model):
            # GGUF checkpoints are routed to the dedicated "gguf" loader and
            # "gguf" quantization.
            self.quantization = "gguf"
            self.load_format = "gguf"

        if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            # Per the message below, seeding may then affect the random state
            # of the Python process that launched vLLM.
            logger.warning(
                "The global random seed is set to %d. Since "
                "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
                "affect the random state of the Python process that "
                "launched vLLM.",
                self.seed,
            )

        return ModelConfig(
            # Model source / loading
            model=self.model,
            model_weights=self.model_weights,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            revision=self.revision,
            code_revision=self.code_revision,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            config_format=self.config_format,
            trust_remote_code=self.trust_remote_code,
            model_impl=self.model_impl,
            # Tokenizer
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            tokenizer_revision=self.tokenizer_revision,
            skip_tokenizer_init=self.skip_tokenizer_init,
            # Dtype, length, quantization and execution toggles
            dtype=self.dtype,
            seed=self.seed,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            allow_deprecated_quantization=self.allow_deprecated_quantization,
            enforce_eager=self.enforce_eager,
            override_attention_dtype=self.override_attention_dtype,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            enable_chunked_prefill=self.enable_chunked_prefill,
            enable_sleep_mode=self.enable_sleep_mode,
            enable_return_routed_experts=self.enable_return_routed_experts,
            # Outputs, logits and generation
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            logits_processor_pattern=self.logits_processor_pattern,
            logits_processors=self.logits_processors,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            served_model_name=self.served_model_name,
            pooler_config=self.pooler_config,
            io_processor_plugin=self.io_processor_plugin,
            enable_prompt_embeds=self.enable_prompt_embeds,
            # Multimodal inputs
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            enable_mm_embeds=self.enable_mm_embeds,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_only=self.mm_encoder_only,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
            video_pruning_rate=self.video_pruning_rate,
        )

    def validate_tensorizer_args(self):
        """Mirror Tensorizer options into the nested tensorizer config.

        Every top-level key of ``self.model_loader_extra_config`` that names
        an entry of ``TensorizerConfig._fields`` is copied into
        ``self.model_loader_extra_config["tensorizer_config"]``. The
        top-level entries themselves are left in place.
        """
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

        extra_config = self.model_loader_extra_config
        tensorizer_fields = TensorizerConfig._fields
        # Only the nested dict is written below, so collecting the matching
        # keys up front is equivalent to checking them while iterating.
        matching_keys = [key for key in extra_config if key in tensorizer_fields]
        for key in matching_keys:
            extra_config["tensorizer_config"][key] = extra_config[key]

    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` for these engine arguments.

        Side effects on ``self``:

        * ``quantization == "bitsandbytes"`` forces ``load_format`` to
          ``"bitsandbytes"``.
        * ``load_format == "tensorizer"`` first converts
          ``model_loader_extra_config`` via ``to_serializable()`` when that
          method exists. It then replaces the ``"tensorizer_config"`` entry
          with a fresh dict whose ``"tensorizer_dir"`` is ``self.model``, and
          finally calls ``validate_tensorizer_args()``.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            extra_config = self.model_loader_extra_config
            if hasattr(extra_config, "to_serializable"):
                extra_config = extra_config.to_serializable()
                self.model_loader_extra_config = extra_config
            # Any pre-existing "tensorizer_config" entry is discarded here.
            extra_config["tensorizer_config"] = {"tensorizer_dir": self.model}
            self.validate_tensorizer_args()

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )

    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
    ) -> SpeculativeConfig | None:
        """Initialize and return a SpeculativeConfig from `speculative_config`.

        `speculative_config` is either provided as a JSON string via the CLI
        (parsed into a dict by `--speculative-config`) or passed directly as a
        dictionary from the engine.

        Args:
            target_model_config: Forwarded to `SpeculativeConfig` as
                `target_model_config`.
            target_parallel_config: Forwarded to `SpeculativeConfig` as
                `target_parallel_config`.

        Returns:
            `None` when `self.speculative_config` is `None`, otherwise the
            constructed `SpeculativeConfig`.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # Build a fresh mapping rather than calling `.update()` on
        # `self.speculative_config`. That dict may be owned by the caller, and
        # mutating it would leave config objects inside it (breaking reuse or
        # JSON serialization of the args).
        speculative_kwargs = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
        }
        return SpeculativeConfig(**speculative_kwargs)

    def create_engine_config(
        self,
1355
        usage_context: UsageContext | None = None,
1356
        headless: bool = False,
1357
1358
1359
    ) -> VllmConfig:
        """
        Create the VllmConfig.
1360

1361
        NOTE: If VllmConfig is incompatible, we raise an error.
1362
        """
1363
        current_platform.pre_register_and_update()
1364

1365
        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
1366

1367
1368
        # Check if the model is a speculator and override model/tokenizer/config
        # BEFORE creating ModelConfig, so the config is created with the target model
1369
1370
1371
1372
        # Skip speculator detection for cloud storage models (eg: S3, GCS) since
        # HuggingFace cannot load configs directly from S3 URLs. S3 models can still
        # use speculators with explicit --speculative-config.
        if not is_cloud_storage(self.model):
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
            (self.model, self.tokenizer, self.speculative_config) = (
                maybe_override_with_speculators(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    revision=self.revision,
                    trust_remote_code=self.trust_remote_code,
                    vllm_speculative_config=self.speculative_config,
                )
            )

1383
        model_config = self.create_model_config()
1384
        self.model = model_config.model
1385
        self.model_weights = model_config.model_weights
1386
        self.tokenizer = model_config.tokenizer
1387

1388
        self._check_feature_supported(model_config)
1389
1390
1391
1392
        self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
        self._set_default_max_num_seqs_and_batched_tokens_args(
            usage_context, model_config
        )
1393

1394
        sliding_window: int | None = None
1395
1396
1397
1398
1399
1400
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1401
1402
1403
        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
1404
        # reuses the GPUs of TP group, and split one TP group into
1405
        # tp_size//dcp_size DCP groups.
1406
        assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, (
1407
1408
1409
1410
            f"tp_size={self.tensor_parallel_size} must be divisible by"
            f"dcp_size={self.decode_context_parallel_size}."
        )

1411
1412
1413
1414
1415
        # Resolve "auto" kv_cache_dtype to actual value from model config
        resolved_cache_dtype = resolve_kv_cache_dtype_string(
            self.kv_cache_dtype, model_config
        )

1416
        cache_config = CacheConfig(
1417
            block_size=self.block_size,
1418
            gpu_memory_utilization=self.gpu_memory_utilization,
1419
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1420
            swap_space=self.swap_space,
1421
            cache_dtype=resolved_cache_dtype,
1422
            is_attention_free=model_config.is_attention_free,
1423
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1424
            sliding_window=sliding_window,
1425
            enable_prefix_caching=self.enable_prefix_caching,
1426
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1427
            cpu_offload_gb=self.cpu_offload_gb,
1428
            calculate_kv_scales=self.calculate_kv_scales,
1429
            kv_cache_dtype_skip_layers=self.kv_cache_dtype_skip_layers,
1430
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1431
1432
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1433
            mamba_block_size=self.mamba_block_size,
1434
            mamba_cache_mode=self.mamba_cache_mode,
1435
1436
            kv_offloading_size=self.kv_offloading_size,
            kv_offloading_backend=self.kv_offloading_backend,
1437
        )
1438

1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
        # TurboQuant: auto-skip first/last 2 layers (boundary protection).
        # These layers are most sensitive to quantization error.
        # Users can add extra layers via --kv-cache-dtype-skip-layers.
        # Disabled for hybrid models (attn+mamba) — mixed page sizes break
        # the required page size unification.
        if (
            resolved_cache_dtype.startswith("turboquant_")
            and not model_config.is_hybrid
        ):
            from vllm.model_executor.layers.quantization.turboquant.config import (
                TurboQuantConfig,
            )

            num_layers = model_config.hf_text_config.num_hidden_layers
            boundary = TurboQuantConfig.get_boundary_skip_layers(num_layers)
            existing = set(cache_config.kv_cache_dtype_skip_layers)
            merged = sorted(existing | set(boundary), key=lambda x: int(x))
            cache_config.kv_cache_dtype_skip_layers = merged
            logger.info(
                "TQ: skipping layers %s for boundary protection (num_layers=%d)",
                merged,
                num_layers,
            )

1463
1464
1465
1466
1467
1468
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
1469

1470
            ray_runtime_env = ray.get_runtime_context().runtime_env
1471
1472
1473
1474
1475
1476
1477
            # Avoid logging sensitive environment variables
            sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {}
            if "env_vars" in sanitized_env:
                sanitized_env["env_vars"] = {
                    k: "***" for k in sanitized_env["env_vars"]
                }
            logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env)
1478

1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1490
        assert not headless or not self.data_parallel_hybrid_lb, (
1491
1492
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
        assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
            "data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
        )
        assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
            "nnodes > 1 is only supported with data_parallel_backend=mp"
        )
        inferred_data_parallel_rank = 0
        if self.nnodes > 1:
            world_size = (
                self.data_parallel_size
                * self.pipeline_parallel_size
                * self.tensor_parallel_size
            )
            world_size_within_dp = (
                self.pipeline_parallel_size * self.tensor_parallel_size
            )
            local_world_size = world_size // self.nnodes
            assert world_size % self.nnodes == 0, (
                f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
            )
            assert self.node_rank < self.nnodes, (
                f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
            )
            inferred_data_parallel_rank = (
                self.node_rank * local_world_size
            ) // world_size_within_dp
            if self.data_parallel_size > 1 and self.data_parallel_external_lb:
                self.data_parallel_rank = inferred_data_parallel_rank
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d for external lb",
                    self.data_parallel_rank,
                    self.node_rank,
                )
            elif self.data_parallel_size_local is None:
                # Infer data parallel size local for internal dplb:
                self.data_parallel_size_local = max(
                    local_world_size // world_size_within_dp, 1
                )
        data_parallel_external_lb = (
            self.data_parallel_external_lb or self.data_parallel_rank is not None
        )
1534
        # Local DP rank = 1, use pure-external LB.
1535
        if data_parallel_external_lb:
1536
            assert self.data_parallel_rank is not None, (
1537
                "data_parallel_rank or node_rank must be specified if "
1538
1539
                "data_parallel_external_lb is enable."
            )
1540
            assert self.data_parallel_size_local in (1, None), (
1541
1542
                "data_parallel_size_local must be 1 or None when data_parallel_rank "
                "is set"
1543
            )
1544
            data_parallel_size_local = 1
1545
1546
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1547
1548
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1549
1550
1551
1552
1553
1554
1555

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
1556
1557
1558
1559
1560
                logger.warning(
                    "data_parallel_hybrid_lb is not eligible when "
                    "data_parallel_size_local = 1, autoswitch to "
                    "data_parallel_external_lb."
                )
1561
1562
1563
1564
1565
1566
1567
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

1568
1569
1570
1571
1572
1573
1574
1575
1576
            self.data_parallel_rank = (
                self.data_parallel_start_rank or inferred_data_parallel_rank
            )
            if self.nnodes > 1:
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d",
                    self.data_parallel_rank,
                    self.node_rank,
                )
1577
        else:
1578
            assert not self.data_parallel_hybrid_lb, (
1579
1580
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )
1581

1582
1583
1584
1585
1586
1587
1588
1589
1590
            if self.data_parallel_backend == "ray" and (
                envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
            ):
                # Data parallel size defaults to 1 if DP ranks are spanning
                # multiple nodes
                data_parallel_size_local = 1
            else:
                # Otherwise local DP size defaults to global DP size if not set
                data_parallel_size_local = self.data_parallel_size
1591
1592
1593

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1594
1595
1596
1597
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
1598
1599
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1600
1601
1602
1603
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
1604
1605
                    self.data_parallel_backend,
                )
1606
1607
1608
                data_parallel_address = (
                    self.master_addr or ParallelConfig.data_parallel_master_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1609
1610
        else:
            data_parallel_address = self.data_parallel_address
1611
1612
1613

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
1614
        data_parallel_rpc_port = (
1615
            self.data_parallel_rpc_port
1616
1617
1618
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )
1619

1620
1621
1622
1623
        if self.tokens_only and not model_config.skip_tokenizer_init:
            model_config.skip_tokenizer_init = True
            logger.info("Skipping tokenizer initialization for tokens-only mode.")

1624
        parallel_config = ParallelConfig(
1625
1626
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1627
            prefill_context_parallel_size=self.prefill_context_parallel_size,
1628
            data_parallel_size=self.data_parallel_size,
1629
1630
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1631
            data_parallel_size_local=data_parallel_size_local,
1632
1633
1634
1635
            master_addr=self.master_addr,
            master_port=self.master_port,
            nnodes=self.nnodes,
            node_rank=self.node_rank,
1636
1637
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1638
            data_parallel_backend=self.data_parallel_backend,
1639
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1640
            is_moe_model=model_config.is_moe,
1641
            enable_expert_parallel=self.enable_expert_parallel,
1642
            all2all_backend=self.all2all_backend,
1643
            enable_dbo=self.enable_dbo,
1644
            ubatch_size=self.ubatch_size,
1645
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
1646
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
1647
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
1648
            enable_eplb=self.enable_eplb,
1649
            eplb_config=self.eplb_config,
1650
            expert_placement_strategy=self.expert_placement_strategy,
1651
1652
1653
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1654
            ray_runtime_env=ray_runtime_env,
1655
            placement_group=placement_group,
1656
1657
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1658
            worker_extension_cls=self.worker_extension_cls,
1659
            decode_context_parallel_size=self.decode_context_parallel_size,
1660
            dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
1661
            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
1662
1663
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
1664
        )
1665

1666
        speculative_config = self.create_speculative_config(
1667
1668
1669
            target_model_config=model_config,
            target_parallel_config=parallel_config,
        )
1670
        
1671
        scheduler_config = SchedulerConfig(
1672
            runner_type=model_config.runner_type,
1673
1674
1675
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1676
            enable_chunked_prefill=self.enable_chunked_prefill,
1677
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1678
            is_multimodal_model=model_config.is_multimodal_model,
1679
            is_encoder_decoder=model_config.is_encoder_decoder,
1680
            policy=self.scheduling_policy,
1681
            scheduler_cls=self.scheduler_cls,
1682
1683
1684
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1685
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
1686
            async_scheduling=self.async_scheduling,
1687
            stream_interval=self.stream_interval,
1688
        )
1689

1690
1691
1692
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_dtype=self.lora_dtype,
1703
                enable_tower_connector_lora=self.enable_tower_connector_lora,
1704
1705
1706
1707
1708
1709
1710
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )
1711

1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
        if (
            lora_config is not None
            and speculative_config is not None
            and scheduler_config.max_num_batched_tokens
            < (
                scheduler_config.max_num_seqs
                * (speculative_config.num_speculative_tokens + 1)
            )
        ):
            raise ValueError(
                "Consider increasing max_num_batched_tokens or "
                "decreasing num_speculative_tokens"
            )
1725

1726
1727
1728
1729
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1730
1731
1732
1733
1734
1735
1736
1737
        # Attention config overrides
        attention_config = copy.deepcopy(self.attention_config)
        if self.attention_backend is not None:
            if attention_config.backend is not None:
                raise ValueError(
                    "attention_backend and attention_config.backend "
                    "are mutually exclusive"
                )
1738
1739
1740
1741
1742
1743
1744
            # Convert string to enum if needed (CLI parsing returns a string)
            if isinstance(self.attention_backend, str):
                attention_config.backend = AttentionBackendEnum[
                    self.attention_backend.upper()
                ]
            else:
                attention_config.backend = self.attention_backend
1745

1746
        load_config = self.create_load_config()
1747

1748
1749
        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
1750
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser
1751

1752
1753
1754
1755
        if self.reasoning_parser_plugin:
            self.structured_outputs_config.reasoning_parser_plugin = (
                self.reasoning_parser_plugin
            )
1756

1757
        observability_config = ObservabilityConfig(
1758
            show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
1759
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1760
            collect_detailed_traces=self.collect_detailed_traces,
1761
1762
            kv_cache_metrics=self.kv_cache_metrics,
            kv_cache_metrics_sample=self.kv_cache_metrics_sample,
1763
            cudagraph_metrics=self.cudagraph_metrics,
1764
            enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
1765
            enable_mfu_metrics=self.enable_mfu_metrics,
1766
            enable_mm_processor_stats=self.enable_mm_processor_stats,
1767
            enable_logging_iteration_details=self.enable_logging_iteration_details,
1768
        )
1769

1770
        # Compilation config overrides
1771
        compilation_config = copy.deepcopy(self.compilation_config)
1772
        if self.cudagraph_capture_sizes is not None:
1773
            if compilation_config.cudagraph_capture_sizes is not None:
1774
1775
1776
1777
                raise ValueError(
                    "cudagraph_capture_sizes and compilation_config."
                    "cudagraph_capture_sizes are mutually exclusive"
                )
1778
            compilation_config.cudagraph_capture_sizes = self.cudagraph_capture_sizes
1779
        if self.max_cudagraph_capture_size is not None:
1780
            if compilation_config.max_cudagraph_capture_size is not None:
1781
1782
1783
1784
                raise ValueError(
                    "max_cudagraph_capture_size and compilation_config."
                    "max_cudagraph_capture_size are mutually exclusive"
                )
1785
            compilation_config.max_cudagraph_capture_size = (
1786
1787
                self.max_cudagraph_capture_size
            )
1788
        config = VllmConfig(
1789
1790
1791
1792
1793
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
1794
1795
            load_config=load_config,
            attention_config=attention_config,
1796
1797
            lora_config=lora_config,
            speculative_config=speculative_config,
1798
            structured_outputs_config=self.structured_outputs_config,
1799
            observability_config=observability_config,
1800
            compilation_config=compilation_config,
1801
            kv_transfer_config=self.kv_transfer_config,
1802
            kv_events_config=self.kv_events_config,
1803
            ec_transfer_config=self.ec_transfer_config,
1804
            profiler_config=self.profiler_config,
1805
            additional_config=self.additional_config,
1806
            optimization_level=self.optimization_level,
1807
        )
1808

1809
1810
        return config

1811
1812
    def _check_feature_supported(self, model_config: ModelConfig):
        """Reject engine arguments that request features that are unsupported.

        Each failing check delegates to ``_raise_unsupported_error``, which
        raises ``NotImplementedError`` naming the offending feature.

        Args:
            model_config: Resolved model configuration (currently unused by
                these checks; kept for interface compatibility).
        """
        # Any non-default logits processor pattern is rejected outright.
        if self.logits_processor_pattern != EngineArgs.logits_processor_pattern:
            _raise_unsupported_error(feature_name="--logits-processor-pattern")

        # Concurrent partial prefills are not available yet, so both knobs
        # must stay at their SchedulerConfig defaults.
        partial_prefill_customized = (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        )
        if partial_prefill_customized:
            _raise_unsupported_error(feature_name="Concurrent Partial Prefill")

        if self.pipeline_parallel_size <= 1:
            return

        # Pipeline parallelism is accepted when the executor backend advertises
        # `supports_pp`, or when it is one of the built-in backends listed here.
        backend = self.distributed_executor_backend
        if getattr(backend, "supports_pp", False) or backend in (
            ParallelConfig.distributed_executor_backend,
            "ray",
            "mp",
            "external_launcher",
        ):
            return

        _raise_unsupported_error(
            feature_name=(
                "Pipeline Parallelism without Ray distributed "
                "executor or multiprocessing executor or external "
                "launcher"
            )
        )
1840

1841
1842
1843
1844
1845
1846
    @classmethod
    def get_batch_defaults(
        cls,
        world_size: int,
    ) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]:
        """Return platform-tuned scheduler defaults keyed by usage context.

        Args:
            world_size: Number of workers (pipeline x tensor parallel size);
                only used to scale the CPU defaults.

        Returns:
            A pair ``(max_num_batched_tokens_defaults, max_num_seqs_defaults)``.
            Each maps a ``UsageContext`` to its default value.
        """
        from vllm.usage.usage_lib import UsageContext

        llm = UsageContext.LLM_CLASS
        server = UsageContext.OPENAI_API_SERVER

        # Querying the device can fail when the process that imports vLLM is
        # not the one that runs it (e.g. a GPU-less Ray driver). In that case
        # fall back to values that select the "other hardware" defaults below.
        try:
            total_memory = current_platform.get_device_total_memory()
            lowered_name = current_platform.get_device_name().lower()
        except Exception:
            total_memory = 0
            lowered_name = ""

        # NOTE(Kuntai): a large max_num_batched_tokens hurts A100 throughput
        # (see PR #17885), so A100 is excluded from the big-memory tier.
        big_memory_gpu = total_memory >= 70 * GiB_bytes and "a100" not in lowered_name
        batched_tokens: dict[UsageContext | None, int]
        seqs: dict[UsageContext | None, int]
        if big_memory_gpu:
            # GPUs such as H100 / MI300x.
            batched_tokens = {llm: 16384, server: 8192}
            seqs = {llm: 1024, server: 1024}
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            batched_tokens = {llm: 8192, server: 10240}
            seqs = {llm: 256, server: 256}

        if current_platform.is_tpu():
            # (LLM_CLASS, OPENAI_API_SERVER) token budgets per TPU chip; any
            # other chip keeps the token defaults chosen above.
            tpu_token_budgets = {
                "V6E": (2048, 1024),
                "V5E": (1024, 512),
                "V5P": (512, 256),
            }
            budget = tpu_token_budgets.get(current_platform.get_device_name())
            if budget is not None:
                llm_tokens, server_tokens = budget
                batched_tokens = {llm: llm_tokens, server: server_tokens}

        if current_platform.is_cpu():
            # CPU defaults scale linearly with the number of workers.
            batched_tokens = {llm: 4096 * world_size, server: 2048 * world_size}
            seqs = {llm: 256 * world_size, server: 128 * world_size}

        return batched_tokens, seqs

1925
1926
    def _set_default_chunked_prefill_and_prefix_caching_args(
        self, model_config: ModelConfig
    ) -> None:
        """Resolve ``enable_chunked_prefill`` and ``enable_prefix_caching``.

        Flags left as ``None`` are filled from the model's reported support
        (``is_chunked_prefill_supported`` / ``is_prefix_caching_supported``).
        Explicit user choices that conflict with that support only produce a
        warning. Finally, both features are forced off on POWER (ppc64le),
        S390X and RISC-V CPUs.

        Args:
            model_config: Resolved model configuration that provides the
                per-model defaults and the runner type.
        """
        default_chunked_prefill = model_config.is_chunked_prefill_supported
        default_prefix_caching = model_config.is_prefix_caching_supported

        if self.enable_chunked_prefill is None:
            self.enable_chunked_prefill = default_chunked_prefill

            logger.debug(
                "%s chunked prefill by default",
                "Enabling" if default_chunked_prefill else "Disabling",
            )
        elif (
            model_config.runner_type == "generate"
            and not self.enable_chunked_prefill
            and default_chunked_prefill
        ):
            # The user turned chunked prefill off for a generative model that
            # reports supporting it.
            logger.warning_once(
                "This model does not officially support disabling chunked prefill. "
                "Disabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )
        elif (
            model_config.runner_type == "pooling"
            and self.enable_chunked_prefill
            and not default_chunked_prefill
        ):
            # The user forced chunked prefill on for a pooling model that does
            # not report supporting it.
            logger.warning_once(
                "This model does not officially support chunked prefill. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )

        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = default_prefix_caching

            logger.debug(
                "%s prefix caching by default",
                "Enabling" if default_prefix_caching else "Disabling",
            )
        elif (
            model_config.runner_type == "pooling"
            and self.enable_prefix_caching
            and not default_prefix_caching
        ):
            logger.warning_once(
                "This model does not officially support prefix caching. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )

        # Disable chunked prefill and prefix caching for:
        # POWER (ppc64le)/s390x/RISCV CPUs in V1.
        # This overrides both user-provided and defaulted values.
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.S390X,
            CpuArchEnum.RISCV,
        ):
            # The messages name only the architectures this condition actually
            # matches (the previous text also claimed ARM, which is not checked).
            logger.info(
                "Chunked prefill is not supported for POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_chunked_prefill = False
            logger.info(
                "Prefix caching is not supported for POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_prefix_caching = False

    def _set_default_max_num_seqs_and_batched_tokens_args(
        self,
        usage_context: UsageContext | None,
        model_config: ModelConfig,
    ):
        """Fill in ``max_num_batched_tokens`` / ``max_num_seqs`` left as None.

        Defaults come from ``get_batch_defaults`` for the current usage context,
        falling back to the SchedulerConfig defaults. Values the user set
        explicitly are never replaced. Defaulted values are then clamped
        against the model's ``max_model_len``.

        Args:
            usage_context: How the engine is being used; selects the default.
            model_config: Resolved model configuration (for ``max_model_len``).
        """
        tokens_by_context, seqs_by_context = self.get_batch_defaults(
            self.pipeline_parallel_size * self.tensor_parallel_size
        )

        # Remember which values were user-provided before defaults overwrite them.
        tokens_was_unset = self.max_num_batched_tokens is None
        seqs_was_unset = self.max_num_seqs is None

        if tokens_was_unset:
            self.max_num_batched_tokens = tokens_by_context.get(
                usage_context,
                SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )
        if seqs_was_unset:
            self.max_num_seqs = seqs_by_context.get(
                usage_context,
                SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
            )

        # The token clamp below reads max_num_seqs, so it must run after the
        # max_num_seqs default has been applied.
        if tokens_was_unset:
            if not self.enable_chunked_prefill:
                # Without chunked prefill a whole prompt must fit in one batch,
                # so never go below max_model_len.
                self.max_num_batched_tokens = max(
                    model_config.max_model_len,
                    self.max_num_batched_tokens,
                )

            # Cap at what max_num_seqs full-length sequences could use; some
            # models (e.g. Whisper) have embeddings tied to the maximum length.
            self.max_num_batched_tokens = min(
                self.max_num_seqs * model_config.max_model_len,
                self.max_num_batched_tokens,
            )

            logger.debug(
                "Defaulting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                usage_context.value if usage_context else None,
            )

        if seqs_was_unset:
            assert self.max_num_batched_tokens is not None  # For type checking
            # More sequences than batched tokens could never all be scheduled.
            self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

            logger.debug(
                "Defaulting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                usage_context.value if usage_context else None,
            )
2057

2058

2059
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments for the asynchronous vLLM engine.

    Adds async-only options on top of everything :class:`EngineArgs`
    already provides.
    """

    # When True, incoming requests are logged. Defaults to off.
    enable_log_requests: bool = False

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
        """Register the async-engine CLI options on ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: When True, skip the regular ``EngineArgs``
                options and register only the async-specific ones.

        Returns:
            The parser with the options registered.
        """
        # Plugins must be loaded before any option is registered, because a
        # plugin may extend the parser itself, e.g. by adding a new kind of
        # quantization method to --quantization or a new device to --device.
        load_general_plugins()

        include_engine_args = not async_args_only
        if include_engine_args:
            parser = EngineArgs.add_cli_args(parser)

        logging_on_by_default = AsyncEngineArgs.enable_log_requests
        parser.add_argument(
            "--enable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=logging_on_by_default,
            help="Enable logging requests.",
        )
        # Deprecated inverse of --enable-log-requests; kept for CLI
        # backward compatibility.
        parser.add_argument(
            "--disable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=not logging_on_by_default,
            help="[DEPRECATED] Disable logging requests.",
            deprecated=True,
        )

        # Give the current platform the last word on the parser.
        current_platform.pre_register_and_update(parser)
        return parser
2090
2091


2092
2093
2094
2095
2096
2097
def _raise_unsupported_error(feature_name: str):
    msg = (
        f"{feature_name} is not supported. We recommend to "
        f"remove {feature_name} from your config."
    )
    raise NotImplementedError(msg)
2098
2099


2100
def human_readable_int(value: str) -> int:
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.

    Lower-case suffixes (k, m, g, t) are decimal multipliers (powers of
    1000). They accept a fractional number, and the result is truncated
    toward zero. Upper-case suffixes (K, M, G, T) are binary multipliers
    (powers of 1024) and accept whole numbers only. A value without a
    suffix is parsed as a plain integer.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '2.05m' -> 2,050,000

    Raises:
        argparse.ArgumentTypeError: A fractional number is combined with a
            binary suffix (e.g. '1.5K').
        ValueError: The value is neither a plain integer nor a number with a
            recognised suffix.
    """
    # Local import: exact decimal arithmetic is only needed in this parser.
    from decimal import Decimal

    value = value.strip()

    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
    if match:
        decimal_multiplier = {
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
            "t": 10**12,
        }
        binary_multiplier = {
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
            "T": 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            # Use Decimal, not float. float(number) * mult can round to just
            # below the exact product (e.g. 2.05 * 10**6 evaluates to
            # 2049999.9999999998), which int() would then truncate to
            # 2,049,999 instead of 2,050,000.
            return int(Decimal(number) * decimal_multiplier[suffix])
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?"
                ) from e

    # Regular plain number.
    return int(value)
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162


def human_readable_int_or_auto(value: str) -> int:
    """Parse a human-readable integer, or the auto-detection sentinel.

    The input is stripped first. If it is then '-1' or 'auto' (any case),
    -1 is returned as the special "auto-detect" value. Anything else is
    delegated to :func:`human_readable_int`, so '1k' -> 1,000,
    '1K' -> 1,024 and '25.6k' -> 25,600.
    """
    text = value.strip()
    wants_auto = text == "-1" or text.lower() == "auto"
    return -1 if wants_auto else human_readable_int(text)