arg_utils.py 86.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import copy
6
import dataclasses
7
import functools
8
import json
9
import sys
10
from collections.abc import Callable
11
from dataclasses import MISSING, dataclass, fields, is_dataclass
12
from itertools import permutations
13
from types import UnionType
14
15
16
17
18
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Literal,
19
    TypeAlias,
20
21
22
23
24
25
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
26

27
import huggingface_hub
28
import regex as re
29
import torch
30
from pydantic import TypeAdapter, ValidationError
31
from pydantic.fields import FieldInfo
32
from typing_extensions import TypeIs
33

34
import vllm.envs as envs
35
from vllm.attention.backends.registry import AttentionBackendEnum
36
from vllm.config import (
37
    AttentionConfig,
38
39
40
41
    CacheConfig,
    CompilationConfig,
    ConfigType,
    DeviceConfig,
42
    ECTransferConfig,
43
44
45
46
47
48
    EPLBConfig,
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LoRAConfig,
    ModelConfig,
49
    MultiModalConfig,
50
51
52
    ObservabilityConfig,
    ParallelConfig,
    PoolerConfig,
53
    ProfilerConfig,
54
55
56
57
58
59
    SchedulerConfig,
    SpeculativeConfig,
    StructuredOutputsConfig,
    VllmConfig,
    get_attr_docs,
)
60
61
62
63
64
65
66
from vllm.config.cache import (
    BlockSize,
    CacheDType,
    KVOffloadingBackend,
    MambaDType,
    PrefixCachingHashAlgo,
)
67
68
69
70
71
72
73
from vllm.config.device import Device
from vllm.config.model import (
    ConvertOption,
    HfOverrides,
    LogprobsMode,
    ModelDType,
    RunnerOption,
74
    TokenizerMode,
75
76
77
78
79
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.scheduler import SchedulerPolicy
80
from vllm.config.utils import get_field
81
from vllm.config.vllm import OptimizationLevel
82
from vllm.logger import init_logger, suppress_logging
83
from vllm.platforms import CpuArchEnum, current_platform
84
from vllm.plugins import load_general_plugins
85
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
86
87
88
89
from vllm.transformers_utils.config import (
    is_interleaved,
    maybe_override_with_speculators,
)
90
from vllm.transformers_utils.gguf_utils import is_gguf
91
from vllm.transformers_utils.repo_utils import get_model_path
92
from vllm.transformers_utils.utils import is_cloud_storage
93
from vllm.utils.argparse_utils import FlexibleArgumentParser
94
from vllm.utils.mem_constants import GiB_bytes
95
from vllm.utils.network_utils import get_ip
96
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
97
from vllm.v1.sample.logits_processor import LogitsProcessor
98

99
100
if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
101
    from vllm.model_executor.model_loader import LoadFormats
102
    from vllm.usage.usage_lib import UsageContext
103
    from vllm.v1.executor import Executor
104
else:
105
    Executor = Any
106
    QuantizationMethods = Any
107
    LoadFormats = Any
108
109
    UsageContext = Any

110

111
112
logger = init_logger(__name__)

113
114
# object is used to allow for special typing forms
T = TypeVar("T")
115
116
TypeHint: TypeAlias = type[Any] | object
TypeHintT: TypeAlias = type[T] | object
117

118

119
120
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    def _parse_type(val: str) -> T:
121
122
123
124
        try:
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
125
126
                f"Value {val} cannot be converted to {return_type}."
            ) from e
127

128
129
130
    return _parse_type


131
132
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], T | None]:
    def _optional_type(val: str) -> T | None:
133
134
135
136
        if val == "" or val == "None":
            return None
        return parse_type(return_type)(val)

137
    return _optional_type
138
139


140
def union_dict_and_str(val: str) -> str | dict[str, str] | None:
141
    if not re.match(r"(?s)^\s*{.*}\s*$", val):
142
        return str(val)
143
    return optional_type(json.loads)(val)
144
145


146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Check if the type hint is a specific type."""
    return type_hint is type or get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Check if the type hints contain a specific type."""
    return any(is_type(type_hint, type) for type_hint in type_hints)


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Get the specific type from the type hints."""
    return next((th for th in type_hints if is_type(th, type)), None)


161
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
162
163
164
165
    """Get the `type` and `choices` from a `Literal` type hint in `type_hints`.

    If `type_hints` also contains `str`, we use `metavar` instead of `choices`.
    """
166
    type_hint = get_type(type_hints, Literal)
167
168
169
    options = get_args(type_hint)
    option_type = type(options[0])
    if not all(isinstance(option, option_type) for option in options):
170
        raise ValueError(
171
            "All options must be of the same type. "
172
173
            f"Got {options} with types {[type(c) for c in options]}"
        )
174
175
    kwarg = "metavar" if contains_type(type_hints, str) else "choices"
    return {"type": option_type, kwarg: sorted(options)}
176
177


178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def collection_to_kwargs(type_hints: set[TypeHint], type: TypeHint) -> dict[str, Any]:
    type_hint = get_type(type_hints, type)
    types = get_args(type_hint)
    elem_type = types[0]

    # Handle Ellipsis
    assert all(t is elem_type for t in types if t is not Ellipsis), (
        f"All non-Ellipsis elements must be of the same type. Got {types}."
    )

    # Handle Union types
    if get_origin(elem_type) in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
        assert str in get_args(elem_type), (
            "If element can have multiple types, one must be 'str' "
            f"(i.e. 'list[int | str]'). Got {elem_type}."
        )
        elem_type = str

    return {
        "type": elem_type,
        "nargs": "+" if type is not tuple or Ellipsis in types else len(types),
    }


203
204
205
206
207
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Check if the class is not a built-in type."""
    return type_hint.__module__ != "builtins"


208
209
210
211
212
213
214
215
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints."""
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
216
217
    elif origin in {Union, UnionType}:
        # Union for Union[X, Y] and UnionType for X | Y
218
219
220
221
222
223
224
225
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


226
227
228
229
def is_online_quantization(quantization: Any) -> bool:
    return quantization in ["inc"]


230
NEEDS_HELP = (
231
232
    any("--help" in arg for arg in sys.argv)  # vllm SUBCOMMAND --help
    or (argv0 := sys.argv[0]).endswith("mkdocs")  # mkdocs SUBCOMMAND
233
234
235
236
    or argv0.endswith("mkdocs/__main__.py")  # python -m mkdocs SUBCOMMAND
)


237
@functools.lru_cache(maxsize=30)
238
def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
239
240
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}
241
242
    kwargs = {}
    for field in fields(cls):
243
        # Get the set of possible types for the field
244
        type_hints: set[TypeHint] = get_type_hints(field.type)
245
246
247
248
249

        # If the field is a dataclass, we can use the model_validate_json
        generator = (th for th in type_hints if is_dataclass(th))
        dataclass_cls = next(generator, None)

250
        # Get the default value of the field
251
252
        if field.default is not MISSING:
            default = field.default
253
254
            # Handle pydantic.Field defaults
            if isinstance(default, FieldInfo):
255
256
257
258
259
260
261
                if default.default_factory is None:
                    default = default.default
                else:
                    # VllmConfig's Fields have default_factory set to config classes.
                    # These could emit logs on init, which would be confusing.
                    with suppress_logging():
                        default = default.default_factory()
262
        elif field.default_factory is not MISSING:
263
            default = field.default_factory()
264
265
266

        # Get the help text for the field
        name = field.name
267
        help = cls_docs.get(name, "").strip()
268
269
270
271
272
273
274
        # Escape % for argparse
        help = help.replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        kwargs[name] = {"default": default, "help": help}

        # Set other kwargs based on the type hints
275
276
277
        json_tip = (
            "Should either be a valid JSON string or JSON keys passed individually."
        )
278
        if dataclass_cls is not None:
279
280
281
282
283
284
285
286

            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            kwargs[name]["type"] = parse_dataclass
287
            kwargs[name]["help"] += f"\n\n{json_tip}"
288
        elif contains_type(type_hints, bool):
289
290
291
            # Creates --no-<name> and --<name> flags
            kwargs[name]["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
292
            kwargs[name].update(literal_to_kwargs(type_hints))
293
        elif contains_type(type_hints, tuple):
294
            kwargs[name].update(collection_to_kwargs(type_hints, tuple))
295
        elif contains_type(type_hints, list):
296
297
298
            kwargs[name].update(collection_to_kwargs(type_hints, list))
        elif contains_type(type_hints, set):
            kwargs[name].update(collection_to_kwargs(type_hints, set))
299
        elif contains_type(type_hints, int):
300
301
302
303
            if name == "max_model_len":
                kwargs[name]["type"] = human_readable_int_or_auto
                kwargs[name]["help"] += f"\n\n{human_readable_int_or_auto.__doc__}"
            elif name in ("max_num_batched_tokens", "kv_cache_memory_bytes"):
304
                kwargs[name]["type"] = human_readable_int
305
                kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
306
307
            else:
                kwargs[name]["type"] = int
308
309
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
310
311
312
313
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
314
            kwargs[name]["type"] = union_dict_and_str
315
        elif contains_type(type_hints, dict):
316
            kwargs[name]["type"] = parse_type(json.loads)
317
            kwargs[name]["help"] += f"\n\n{json_tip}"
318
319
320
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
321
322
            kwargs[name]["type"] = str
        else:
323
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")
324

325
326
327
328
329
        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(kwargs[name].get("type")) is Literal:
            kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))

330
331
332
333
334
335
336
        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            kwargs[name]["type"] = optional_type(kwargs[name]["type"])
            if kwargs[name].get("choices"):
                kwargs[name]["choices"].append("None")
    return kwargs
337
338


339
def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
340
341
    """Return argparse kwargs for the given Config dataclass.

342
343
344
    If `--help` or `mkdocs` are not present in the command line command, the
    attribute documentation will not be included in the help output.

345
346
347
348
349
350
351
    The heavy computation is cached via functools.lru_cache, and a deep copy
    is returned so callers can mutate the dictionary without affecting the
    cached version.
    """
    return copy.deepcopy(_compute_kwargs(cls))


352
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
353
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
354
    """Arguments for vLLM engine."""
355

356
    model: str = ModelConfig.model
357
    model_weights: str = ModelConfig.model_weights
358
    served_model_name: str | list[str] | None = ModelConfig.served_model_name
359
    tokenizer: str | None = ModelConfig.tokenizer
360
    hf_config_path: str | None = ModelConfig.hf_config_path
361
362
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
363
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
364
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
365
    tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
366
    trust_remote_code: bool = ModelConfig.trust_remote_code
367
368
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
    allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains
369
    download_dir: str | None = LoadConfig.download_dir
370
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
371
    load_format: str | LoadFormats = LoadConfig.load_format
372
373
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
374
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
375
    seed: int = ModelConfig.seed
376
    max_model_len: int | None = ModelConfig.max_model_len
377
378
379
380
381
382
    cudagraph_capture_sizes: list[int] | None = (
        CompilationConfig.cudagraph_capture_sizes
    )
    max_cudagraph_capture_size: int | None = get_field(
        CompilationConfig, "max_cudagraph_capture_size"
    )
383
384
385
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
386
    distributed_executor_backend: (
387
        str | DistributedExecutorBackend | type[Executor] | None
388
    ) = ParallelConfig.distributed_executor_backend
389
    # number of P/D disaggregation (or other disaggregation) workers
390
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
391
392
393
394
    master_addr: str = ParallelConfig.master_addr
    master_port: int = ParallelConfig.master_port
    nnodes: int = ParallelConfig.nnodes
    node_rank: int = ParallelConfig.node_rank
395
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
396
    prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
397
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
398
    dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
399
    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
400
    data_parallel_size: int = ParallelConfig.data_parallel_size
401
402
403
404
405
    data_parallel_rank: int | None = None
    data_parallel_start_rank: int | None = None
    data_parallel_size_local: int | None = None
    data_parallel_address: str | None = None
    data_parallel_rpc_port: int | None = None
406
    data_parallel_hybrid_lb: bool = False
407
    data_parallel_external_lb: bool = False
Rui Qiao's avatar
Rui Qiao committed
408
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
409
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
410
    all2all_backend: str = ParallelConfig.all2all_backend
411
    enable_dbo: bool = ParallelConfig.enable_dbo
412
    ubatch_size: int = ParallelConfig.ubatch_size
413
414
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
415
416
417
    disable_nccl_for_dp_synchronization: bool = (
        ParallelConfig.disable_nccl_for_dp_synchronization
    )
418
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
419
    enable_eplb: bool = ParallelConfig.enable_eplb
420
    expert_placement_strategy: ExpertPlacementStrategy = (
421
        ParallelConfig.expert_placement_strategy
422
    )
423
424
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
425
    max_parallel_loading_workers: int | None = (
426
427
        ParallelConfig.max_parallel_loading_workers
    )
428
    block_size: BlockSize | None = CacheConfig.block_size
429
    enable_prefix_caching: bool | None = None
430
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
431
        CacheConfig.prefix_caching_hash_algo
432
    )
433
434
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
435
436
437
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
438
    kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
439
    max_num_batched_tokens: int | None = None
440
441
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
442
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
443
    max_num_seqs: int | None = None
444
    max_logprobs: int = ModelConfig.max_logprobs
445
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
446
    disable_log_stats: bool = False
447
    aggregate_engine_logging: bool = False
448
449
450
    revision: str | None = ModelConfig.revision
    code_revision: str | None = ModelConfig.code_revision
    hf_token: bool | str | None = ModelConfig.hf_token
451
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
452
    tokenizer_revision: str | None = ModelConfig.tokenizer_revision
453
    quantization: QuantizationMethods | None = ModelConfig.quantization
454
    enforce_eager: bool = ModelConfig.enforce_eager
455
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
456
    limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
457
458
        MultiModalConfig, "limit_per_prompt"
    )
459
    enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
460
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
461
462
463
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
464
    mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
465
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
466
    mm_processor_cache_type: MMCacheType | None = (
467
        MultiModalConfig.mm_processor_cache_type
468
469
    )
    mm_shm_cache_max_object_size_mb: int = (
470
        MultiModalConfig.mm_shm_cache_max_object_size_mb
471
    )
472
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
473
    mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
474
475
        MultiModalConfig.mm_encoder_attn_backend
    )
476
    io_processor_plugin: str | None = None
477
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
478
    video_pruning_rate: float = MultiModalConfig.video_pruning_rate
479
    # LoRA fields
480
    enable_lora: bool = False
481
482
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
483
    default_mm_loras: dict[str, str] | None = LoRAConfig.default_mm_loras
484
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
485
486
    max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
    lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
487
    enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
488

489
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
490
    num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
491
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
492
    ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
493

494
    enable_chunked_prefill: bool | None = None
495
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
496

497
    disable_hybrid_kv_cache_manager: bool | None = (
498
499
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
500

501
    structured_outputs_config: StructuredOutputsConfig = get_field(
502
503
        VllmConfig, "structured_outputs_config"
    )
504
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
505
    reasoning_parser_plugin: str | None = None
506

507
    logits_processor_pattern: str | None = ModelConfig.logits_processor_pattern
508

509
    speculative_config: dict[str, Any] | None = None
510

511
    show_hidden_metrics_for_version: str | None = (
512
        ObservabilityConfig.show_hidden_metrics_for_version
513
    )
514
515
    otlp_traces_endpoint: str | None = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: list[DetailedTraceModules] | None = (
516
        ObservabilityConfig.collect_detailed_traces
517
    )
518
519
520
521
    kv_cache_metrics: bool = ObservabilityConfig.kv_cache_metrics
    kv_cache_metrics_sample: float = get_field(
        ObservabilityConfig, "kv_cache_metrics_sample"
    )
522
    cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
523
524
525
    enable_layerwise_nvtx_tracing: bool = (
        ObservabilityConfig.enable_layerwise_nvtx_tracing
    )
526
    enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
527
    enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
528
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
529
    scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
530

531
    pooler_config: PoolerConfig | None = ModelConfig.pooler_config
532
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
533
    attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
534
535
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
536

537
538
    profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config")

539
540
    kv_transfer_config: KVTransferConfig | None = None
    kv_events_config: KVEventsConfig | None = None
541

542
543
    ec_transfer_config: ECTransferConfig | None = None

544
545
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
546
547
548
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
549
    model_impl: str = ModelConfig.model_impl
550
    override_attention_dtype: str = ModelConfig.override_attention_dtype
551
    attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
552

553
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
554
555
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
556
    mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
557

558
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
559

560
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
561
    pt_load_map_location: str = LoadConfig.pt_load_map_location
562

563
    logits_processors: list[str | type[LogitsProcessor]] | None = (
564
565
        ModelConfig.logits_processors
    )
566
567
    """Custom logitproc types"""

568
    async_scheduling: bool | None = SchedulerConfig.async_scheduling
569

570
571
    stream_interval: int = SchedulerConfig.stream_interval

572
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
573
    optimization_level: OptimizationLevel = VllmConfig.optimization_level
574

575
576
577
578
    kv_offloading_size: float | None = CacheConfig.kv_offloading_size
    kv_offloading_backend: KVOffloadingBackend | None = (
        CacheConfig.kv_offloading_backend
    )
579
    tokens_only: bool = False
580

581
    def __post_init__(self):
582
583
584
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
585
        if isinstance(self.compilation_config, dict):
586
            self.compilation_config = CompilationConfig(**self.compilation_config)
587
588
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)
589
        if isinstance(self.eplb_config, dict):
590
            self.eplb_config = EPLBConfig(**self.eplb_config)
591
        # Setup plugins
592
        from vllm.plugins import load_general_plugins
593

594
        load_general_plugins()
595
        # when use hf offline,replace model and tokenizer id to local model path
596
597
598
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
            if model_id is not self.model:
                logger.info(
                    "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
                    model_id,
                    self.model,
                )
            if self.tokenizer is not None:
                tokenizer_id = self.tokenizer
                self.tokenizer = get_model_path(self.tokenizer, self.tokenizer_revision)
                if tokenizer_id is not self.tokenizer:
                    logger.info(
                        "HF_HUB_OFFLINE is True, replace tokenizer_id [%s] "
                        "to tokenizer_path [%s]",
                        tokenizer_id,
                        self.tokenizer,
                    )
615
616

    @staticmethod
617
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
618
        """Shared CLI arguments for vLLM engine."""
619

620
        # Model arguments
621
622
623
624
625
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
626
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
627
            model_group.add_argument("--model", **model_kwargs["model"])
628
629
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
630
631
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
632
633
634
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
635
636
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
637
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
638
639
640
641
642
643
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
644
        model_group.add_argument("--revision", **model_kwargs["revision"])
645
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
646
647
648
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
649
650
651
652
653
654
655
656
657
658
659
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
660
661
662
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
663
664
665
666
667
668
669
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
670
671
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
        )
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
694
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
695
696
697
698
699
700
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
701
702
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
703
        )
704

705
706
707
708
709
710
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
711
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
712
713
714
715
716
717
718
719
720
721
722
723
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )
724

725
726
727
728
729
730
731
732
733
734
        # Attention arguments
        attention_kwargs = get_kwargs(AttentionConfig)
        attention_group = parser.add_argument_group(
            title="AttentionConfig",
            description=AttentionConfig.__doc__,
        )
        attention_group.add_argument(
            "--attention-backend", **attention_kwargs["backend"]
        )

735
736
737
738
739
        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
740
        )
741
        structured_outputs_group.add_argument(
742
            "--reasoning-parser",
743
            # Choices need to be validated after parsing to include plugins
744
745
            **structured_outputs_kwargs["reasoning_parser"],
        )
746
747
748
749
        structured_outputs_group.add_argument(
            "--reasoning-parser-plugin",
            **structured_outputs_kwargs["reasoning_parser_plugin"],
        )
750

751
        # Parallel arguments
752
753
754
755
756
757
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
758
            "--distributed-executor-backend",
759
760
            **parallel_kwargs["distributed_executor_backend"],
        )
761
        parallel_group.add_argument(
762
763
764
765
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
766
767
768
769
        parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"])
        parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
        parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
        parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
770
        parallel_group.add_argument(
771
772
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
773
        parallel_group.add_argument(
774
775
776
777
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
778
779
780
781
        parallel_group.add_argument(
            "--dcp-kv-cache-interleave-size",
            **parallel_kwargs["dcp_kv_cache_interleave_size"],
        )
782
783
784
785
786
787
788
789
790
        parallel_group.add_argument(
            "--cp-kv-cache-interleave-size",
            **parallel_kwargs["cp_kv_cache_interleave_size"],
        )
        parallel_group.add_argument(
            "--prefill-context-parallel-size",
            "-pcp",
            **parallel_kwargs["prefill_context_parallel_size"],
        )
791
792
793
794
795
796
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
        parallel_group.add_argument(
            "--data-parallel-rank",
            "-dpn",
797
            type=int,
798
799
800
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
801
        parallel_group.add_argument(
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
            "--data-parallel-start-rank",
            "-dpr",
            type=int,
            help="Starting data parallel rank for secondary nodes.",
        )
        parallel_group.add_argument(
            "--data-parallel-size-local",
            "-dpl",
            type=int,
            help="Number of data parallel replicas to run on this node.",
        )
        parallel_group.add_argument(
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
            type=int,
            help="Port for data parallel RPC communication.",
        )
        parallel_group.add_argument(
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
832
        parallel_group.add_argument(
833
834
835
836
837
838
839
840
            "--data-parallel-hybrid-lb",
            "-dph",
            **parallel_kwargs["data_parallel_hybrid_lb"],
        )
        parallel_group.add_argument(
            "--data-parallel-external-lb",
            "-dpe",
            **parallel_kwargs["data_parallel_external_lb"],
841
842
        )
        parallel_group.add_argument(
843
844
845
            "--enable-expert-parallel",
            "-ep",
            **parallel_kwargs["enable_expert_parallel"],
846
        )
847
848
849
        parallel_group.add_argument(
            "--all2all-backend", **parallel_kwargs["all2all_backend"]
        )
850
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
851
852
853
854
        parallel_group.add_argument(
            "--ubatch-size",
            **parallel_kwargs["ubatch_size"],
        )
855
856
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
857
858
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
859
860
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
861
862
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
863
864
865
866
        parallel_group.add_argument(
            "--disable-nccl-for-dp-synchronization",
            **parallel_kwargs["disable_nccl_for_dp_synchronization"],
        )
867
868
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
869
870
        parallel_group.add_argument(
            "--expert-placement-strategy",
871
872
            **parallel_kwargs["expert_placement_strategy"],
        )
873

874
        parallel_group.add_argument(
875
            "--max-parallel-loading-workers",
876
877
            **parallel_kwargs["max_parallel_loading_workers"],
        )
878
        parallel_group.add_argument(
879
880
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
881
        parallel_group.add_argument(
882
            "--disable-custom-all-reduce",
883
884
885
886
887
888
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )
889

890
891
892
893
894
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
895
        )
896
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
897
898
899
900
901
902
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
903
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
904
905
906
907
908
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        cache_group.add_argument(
909
910
911
912
913
            "--enable-prefix-caching",
            **{
                **cache_kwargs["enable_prefix_caching"],
                "default": None,
            },
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
931
932
933
        cache_group.add_argument(
            "--mamba-block-size", **cache_kwargs["mamba_block_size"]
        )
934
935
936
937
938
939
        cache_group.add_argument(
            "--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
        )
        cache_group.add_argument(
            "--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
        )
940

941
        # Multimodal related configs
942
943
944
945
946
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
947
        multimodal_group.add_argument(
948
949
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
950
951
952
        multimodal_group.add_argument(
            "--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
        )
953
954
955
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
956
957
958
959
960
961
        multimodal_group.add_argument(
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
962
        multimodal_group.add_argument(
963
964
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
965
966
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
967
968
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
969
        multimodal_group.add_argument(
970
971
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
972
973
974
975
        multimodal_group.add_argument(
            "--mm-encoder-attn-backend",
            **multimodal_kwargs["mm_encoder_attn_backend"],
        )
976
977
978
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
979
        multimodal_group.add_argument(
980
981
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
982

983
        multimodal_group.add_argument(
984
985
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )
986

987
        # LoRA related configs
988
989
990
991
992
993
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
994
            "--enable-lora",
995
            action=argparse.BooleanOptionalAction,
996
997
            help="If True, enable handling of LoRA adapters.",
        )
998
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
999
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
1000
        lora_group.add_argument(
1001
            "--lora-dtype",
1002
1003
            **lora_kwargs["lora_dtype"],
        )
1004
1005
1006
1007
        lora_group.add_argument(
            "--enable-tower-connector-lora",
            **lora_kwargs["enable_tower_connector_lora"],
        )
1008
1009
1010
1011
1012
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
1013

1014
1015
1016
1017
1018
1019
1020
1021
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
1022
1023
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
1024
        observability_group.add_argument(
1025
1026
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
1027
1028
1029
1030
1031
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
1032
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
1033
1034
1035
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
1036
1037
            **observability_kwargs["collect_detailed_traces"],
        )
1038
1039
1040
1041
1042
1043
1044
        observability_group.add_argument(
            "--kv-cache-metrics", **observability_kwargs["kv_cache_metrics"]
        )
        observability_group.add_argument(
            "--kv-cache-metrics-sample",
            **observability_kwargs["kv_cache_metrics_sample"],
        )
1045
1046
1047
1048
        observability_group.add_argument(
            "--cudagraph-metrics",
            **observability_kwargs["cudagraph_metrics"],
        )
1049
1050
1051
1052
        observability_group.add_argument(
            "--enable-layerwise-nvtx-tracing",
            **observability_kwargs["enable_layerwise_nvtx_tracing"],
        )
1053
1054
1055
1056
        observability_group.add_argument(
            "--enable-mfu-metrics",
            **observability_kwargs["enable_mfu_metrics"],
        )
1057

1058
1059
1060
1061
1062
1063
1064
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
1065
1066
1067
1068
1069
            "--max-num-batched-tokens",
            **{
                **scheduler_kwargs["max_num_batched_tokens"],
                "default": None,
            },
1070
        )
1071
        scheduler_group.add_argument(
1072
1073
1074
1075
1076
            "--max-num-seqs",
            **{
                **scheduler_kwargs["max_num_seqs"],
                "default": None,
            },
1077
1078
1079
1080
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
1081
1082
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
1083
1084
            **scheduler_kwargs["max_long_partial_prefills"],
        )
1085
1086
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
1087
1088
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
1089
1090
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
1091
        scheduler_group.add_argument(
1092
1093
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
1094
        scheduler_group.add_argument(
1095
1096
1097
1098
1099
            "--enable-chunked-prefill",
            **{
                **scheduler_kwargs["enable_chunked_prefill"],
                "default": None,
            },
1100
1101
1102
1103
1104
1105
1106
        )
        scheduler_group.add_argument(
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
1107
1108
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
1109
1110
1111
1112
1113
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
1114
1115
1116
        scheduler_group.add_argument(
            "--stream-interval", **scheduler_kwargs["stream_interval"]
        )
1117

1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
        # Compilation arguments
        compilation_kwargs = get_kwargs(CompilationConfig)
        compilation_group = parser.add_argument_group(
            title="CompilationConfig",
            description=CompilationConfig.__doc__,
        )
        compilation_group.add_argument(
            "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
        )
        compilation_group.add_argument(
            "--max-cudagraph-capture-size",
            **compilation_kwargs["max_cudagraph_capture_size"],
        )

1132
        # vLLM arguments
1133
        vllm_kwargs = get_kwargs(VllmConfig)
1134
1135
1136
1137
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
1138
1139
1140
1141
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
1142
1143
1144
1145
1146
1147
1148
        vllm_group.add_argument(
            "--speculative-config", **vllm_kwargs["speculative_config"]
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
1149
1150
1151
        vllm_group.add_argument(
            "--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
        )
1152
        vllm_group.add_argument(
1153
            "--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
1154
        )
1155
1156
1157
        vllm_group.add_argument(
            "--attention-config", "-ac", **vllm_kwargs["attention_config"]
        )
1158
1159
1160
1161
1162
1163
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
1164
        vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
1165
1166
1167
1168
        vllm_group.add_argument(
            "--optimization-level", **vllm_kwargs["optimization_level"]
        )

1169
        # Other arguments
1170
1171
1172
1173
1174
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
1175

1176
1177
1178
1179
1180
1181
        parser.add_argument(
            "--aggregate-engine-logging",
            action="store_true",
            help="Log aggregate rather than per-engine statistics "
            "when using data parallelism.",
        )
1182
        return parser
1183
1184

    @classmethod
1185
    def from_cli_args(cls, args: argparse.Namespace):
1186
1187
1188
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
1189
1190
1191
        engine_args = cls(
            **{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)}
        )
Zhuohan Li's avatar
Zhuohan Li committed
1192
        return engine_args
1193

1194
    def create_model_config(self) -> ModelConfig:
1195
1196
        # gguf file needs a specific model loader
        if is_gguf(self.model):
1197
1198
            self.quantization = self.load_format = "gguf"

1199
1200
1201
1202
1203
1204
1205
        if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.warning(
                "The global random seed is set to %d. Since "
                "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
                "affect the random state of the Python process that "
                "launched vLLM.",
                self.seed,
1206
1207
            )

1208
        return ModelConfig(
1209
            model=self.model,
1210
            model_weights=self.model_weights,
1211
            hf_config_path=self.hf_config_path,
1212
1213
            runner=self.runner,
            convert=self.convert,
1214
1215
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
1216
            trust_remote_code=self.trust_remote_code,
1217
1218
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
1219
1220
1221
1222
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
1223
            hf_token=self.hf_token,
1224
            hf_overrides=self.hf_overrides,
1225
            tokenizer_revision=self.tokenizer_revision,
1226
1227
1228
1229
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_logprobs=self.max_logprobs,
1230
            logprobs_mode=self.logprobs_mode,
1231
            disable_sliding_window=self.disable_sliding_window,
1232
            disable_cascade_attn=self.disable_cascade_attn,
1233
            skip_tokenizer_init=self.skip_tokenizer_init,
1234
            enable_prompt_embeds=self.enable_prompt_embeds,
1235
            served_model_name=self.served_model_name,
1236
            limit_mm_per_prompt=self.limit_mm_per_prompt,
1237
            enable_mm_embeds=self.enable_mm_embeds,
1238
            interleave_mm_strings=self.interleave_mm_strings,
1239
            media_io_kwargs=self.media_io_kwargs,
1240
            skip_mm_profiling=self.skip_mm_profiling,
1241
            config_format=self.config_format,
1242
            mm_processor_kwargs=self.mm_processor_kwargs,
1243
            mm_processor_cache_gb=self.mm_processor_cache_gb,
1244
            mm_processor_cache_type=self.mm_processor_cache_type,
1245
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
1246
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
1247
            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
1248
            pooler_config=self.pooler_config,
1249
            logits_processor_pattern=self.logits_processor_pattern,
1250
            generation_config=self.generation_config,
1251
            override_generation_config=self.override_generation_config,
1252
            enable_sleep_mode=self.enable_sleep_mode,
1253
            model_impl=self.model_impl,
1254
            override_attention_dtype=self.override_attention_dtype,
1255
            logits_processors=self.logits_processors,
1256
            video_pruning_rate=self.video_pruning_rate,
1257
            io_processor_plugin=self.io_processor_plugin,
1258
        )
1259

1260
    def validate_tensorizer_args(self):
1261
1262
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

1263
1264
        for key in self.model_loader_extra_config:
            if key in TensorizerConfig._fields:
1265
1266
1267
                self.model_loader_extra_config["tensorizer_config"][key] = (
                    self.model_loader_extra_config[key]
                )
1268

1269
    def create_load_config(self) -> LoadConfig:
1270
1271
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"
1272

1273
1274
1275
        if self.load_format == "tensorizer":
            if hasattr(self.model_loader_extra_config, "to_serializable"):
                self.model_loader_extra_config = (
1276
1277
                    self.model_loader_extra_config.to_serializable()
                )
1278
            self.model_loader_extra_config["tensorizer_config"] = {}
1279
1280
1281
            self.model_loader_extra_config["tensorizer_config"]["tensorizer_dir"] = (
                self.model
            )
1282
            self.validate_tensorizer_args()
1283

1284
1285
1286
        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
1287
            safetensors_load_strategy=self.safetensors_load_strategy,
1288
            device="cpu" if is_online_quantization(self.quantization) else None,
1289
1290
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
1291
            use_tqdm_on_load=self.use_tqdm_on_load,
1292
            pt_load_map_location=self.pt_load_map_location,
1293
        )
1294

1295
1296
1297
1298
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
1299
    ) -> SpeculativeConfig | None:
1300
1301
1302
1303
1304
1305
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        This function utilizes `speculative_config` to create a
        SpeculativeConfig object. The `speculative_config` can either be
        provided as a JSON string input via CLI arguments or directly as a
1306
        dictionary from the engine.
1307
1308
        """
        if self.speculative_config is None:
1309
            return None
1310

1311
1312
1313
        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
1314
1315
1316
1317
1318
1319
        self.speculative_config.update(
            {
                "target_model_config": target_model_config,
                "target_parallel_config": target_parallel_config,
            }
        )
1320
        return SpeculativeConfig(**self.speculative_config)
1321

1322
1323
    def create_engine_config(
        self,
1324
        usage_context: UsageContext | None = None,
1325
        headless: bool = False,
1326
1327
1328
1329
    ) -> VllmConfig:
        """
        Create the VllmConfig.

1330
        NOTE: If VllmConfig is incompatible, we raise an error.
1331
        """
1332
        current_platform.pre_register_and_update()
1333

1334
        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
1335

1336
1337
        # Check if the model is a speculator and override model/tokenizer/config
        # BEFORE creating ModelConfig, so the config is created with the target model
1338
1339
1340
1341
        # Skip speculator detection for cloud storage models (eg: S3, GCS) since
        # HuggingFace cannot load configs directly from S3 URLs. S3 models can still
        # use speculators with explicit --speculative-config.
        if not is_cloud_storage(self.model):
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
            (self.model, self.tokenizer, self.speculative_config) = (
                maybe_override_with_speculators(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    revision=self.revision,
                    trust_remote_code=self.trust_remote_code,
                    vllm_speculative_config=self.speculative_config,
                )
            )

1352
        model_config = self.create_model_config()
1353
        self.model = model_config.model
1354
        self.model_weights = model_config.model_weights
1355
1356
        self.tokenizer = model_config.tokenizer

1357
        self._check_feature_supported(model_config)
1358
1359
1360
1361
        self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
        self._set_default_max_num_seqs_and_batched_tokens_args(
            usage_context, model_config
        )
1362

1363
        sliding_window: int | None = None
1364
1365
1366
1367
1368
1369
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1370
1371
1372
        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
1373
        # reuses the GPUs of TP group, and split one TP group into
1374
        # tp_size//dcp_size DCP groups.
1375
        assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, (
1376
1377
1378
1379
            f"tp_size={self.tensor_parallel_size} must be divisible by"
            f"dcp_size={self.decode_context_parallel_size}."
        )

1380
1381
1382
1383
1384
        # Resolve "auto" kv_cache_dtype to actual value from model config
        resolved_cache_dtype = resolve_kv_cache_dtype_string(
            self.kv_cache_dtype, model_config
        )

1385
        cache_config = CacheConfig(
1386
            block_size=self.block_size,
1387
            gpu_memory_utilization=self.gpu_memory_utilization,
1388
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1389
            swap_space=self.swap_space,
1390
            cache_dtype=resolved_cache_dtype,
1391
            is_attention_free=model_config.is_attention_free,
1392
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1393
            sliding_window=sliding_window,
1394
            enable_prefix_caching=self.enable_prefix_caching,
1395
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1396
            cpu_offload_gb=self.cpu_offload_gb,
1397
            calculate_kv_scales=self.calculate_kv_scales,
1398
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1399
1400
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1401
            mamba_block_size=self.mamba_block_size,
1402
1403
            kv_offloading_size=self.kv_offloading_size,
            kv_offloading_backend=self.kv_offloading_backend,
1404
        )
1405

1406
1407
1408
1409
1410
1411
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
1412

1413
            ray_runtime_env = ray.get_runtime_context().runtime_env
1414
1415
1416
1417
1418
1419
1420
            # Avoid logging sensitive environment variables
            sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {}
            if "env_vars" in sanitized_env:
                sanitized_env["env_vars"] = {
                    k: "***" for k in sanitized_env["env_vars"]
                }
            logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env)
1421

1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1433
        assert not headless or not self.data_parallel_hybrid_lb, (
1434
1435
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
        assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
            "data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
        )
        assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
            "nnodes > 1 is only supported with data_parallel_backend=mp"
        )
        inferred_data_parallel_rank = 0
        if self.nnodes > 1:
            world_size = (
                self.data_parallel_size
                * self.pipeline_parallel_size
                * self.tensor_parallel_size
            )
            world_size_within_dp = (
                self.pipeline_parallel_size * self.tensor_parallel_size
            )
            local_world_size = world_size // self.nnodes
            assert world_size % self.nnodes == 0, (
                f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
            )
            assert self.node_rank < self.nnodes, (
                f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
            )
            inferred_data_parallel_rank = (
                self.node_rank * local_world_size
            ) // world_size_within_dp
            if self.data_parallel_size > 1 and self.data_parallel_external_lb:
                self.data_parallel_rank = inferred_data_parallel_rank
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d for external lb",
                    self.data_parallel_rank,
                    self.node_rank,
                )
            elif self.data_parallel_size_local is None:
                # Infer data parallel size local for internal dplb:
                self.data_parallel_size_local = max(
                    local_world_size // world_size_within_dp, 1
                )
        data_parallel_external_lb = (
            self.data_parallel_external_lb or self.data_parallel_rank is not None
        )
1477
        # Local DP rank = 1, use pure-external LB.
1478
        if data_parallel_external_lb:
1479
            assert self.data_parallel_rank is not None, (
1480
                "data_parallel_rank or node_rank must be specified if "
1481
1482
                "data_parallel_external_lb is enable."
            )
1483
            assert self.data_parallel_size_local in (1, None), (
1484
1485
                "data_parallel_size_local must be 1 or None when data_parallel_rank "
                "is set"
1486
            )
1487
            data_parallel_size_local = 1
1488
1489
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1490
1491
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1492
1493
1494
1495
1496
1497
1498

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
1499
1500
1501
1502
1503
                logger.warning(
                    "data_parallel_hybrid_lb is not eligible when "
                    "data_parallel_size_local = 1, autoswitch to "
                    "data_parallel_external_lb."
                )
1504
1505
1506
1507
1508
1509
1510
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

1511
1512
1513
1514
1515
1516
1517
1518
1519
            self.data_parallel_rank = (
                self.data_parallel_start_rank or inferred_data_parallel_rank
            )
            if self.nnodes > 1:
                logger.info(
                    "Inferred data_parallel_rank %d from node_rank %d",
                    self.data_parallel_rank,
                    self.node_rank,
                )
1520
        else:
1521
            assert not self.data_parallel_hybrid_lb, (
1522
1523
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )
1524

1525
1526
1527
1528
1529
1530
1531
1532
1533
            if self.data_parallel_backend == "ray" and (
                envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
            ):
                # Data parallel size defaults to 1 if DP ranks are spanning
                # multiple nodes
                data_parallel_size_local = 1
            else:
                # Otherwise local DP size defaults to global DP size if not set
                data_parallel_size_local = self.data_parallel_size
1534
1535
1536

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1537
1538
1539
1540
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
1541
1542
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1543
1544
1545
1546
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
1547
1548
                    self.data_parallel_backend,
                )
1549
1550
1551
                data_parallel_address = (
                    self.master_addr or ParallelConfig.data_parallel_master_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1552
1553
        else:
            data_parallel_address = self.data_parallel_address
1554
1555
1556

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
1557
        data_parallel_rpc_port = (
1558
            self.data_parallel_rpc_port
1559
1560
1561
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )
1562

1563
1564
1565
1566
        if self.tokens_only and not model_config.skip_tokenizer_init:
            model_config.skip_tokenizer_init = True
            logger.info("Skipping tokenizer initialization for tokens-only mode.")

1567
        parallel_config = ParallelConfig(
1568
1569
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1570
            prefill_context_parallel_size=self.prefill_context_parallel_size,
1571
            data_parallel_size=self.data_parallel_size,
1572
1573
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1574
            data_parallel_size_local=data_parallel_size_local,
1575
1576
1577
1578
            master_addr=self.master_addr,
            master_port=self.master_port,
            nnodes=self.nnodes,
            node_rank=self.node_rank,
1579
1580
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1581
            data_parallel_backend=self.data_parallel_backend,
1582
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1583
            is_moe_model=model_config.is_moe,
1584
            enable_expert_parallel=self.enable_expert_parallel,
1585
            all2all_backend=self.all2all_backend,
1586
            enable_dbo=self.enable_dbo,
1587
            ubatch_size=self.ubatch_size,
1588
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
1589
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
1590
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
1591
            enable_eplb=self.enable_eplb,
1592
            eplb_config=self.eplb_config,
1593
            expert_placement_strategy=self.expert_placement_strategy,
1594
1595
1596
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1597
            ray_runtime_env=ray_runtime_env,
1598
            placement_group=placement_group,
1599
1600
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1601
            worker_extension_cls=self.worker_extension_cls,
1602
            decode_context_parallel_size=self.decode_context_parallel_size,
1603
            dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
1604
            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
1605
1606
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
1607
        )
1608

1609
        speculative_config = self.create_speculative_config(
1610
1611
1612
1613
            target_model_config=model_config,
            target_parallel_config=parallel_config,
        )

1614
        scheduler_config = SchedulerConfig(
1615
            runner_type=model_config.runner_type,
1616
1617
1618
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1619
            enable_chunked_prefill=self.enable_chunked_prefill,
1620
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1621
            is_multimodal_model=model_config.is_multimodal_model,
1622
            is_encoder_decoder=model_config.is_encoder_decoder,
1623
            policy=self.scheduling_policy,
1624
            scheduler_cls=self.scheduler_cls,
1625
1626
1627
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1628
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
1629
            async_scheduling=self.async_scheduling,
1630
            stream_interval=self.stream_interval,
1631
        )
1632

1633
1634
1635
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_dtype=self.lora_dtype,
1646
                enable_tower_connector_lora=self.enable_tower_connector_lora,
1647
1648
1649
1650
1651
1652
1653
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )
1654

1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
        if (
            lora_config is not None
            and speculative_config is not None
            and scheduler_config.max_num_batched_tokens
            < (
                scheduler_config.max_num_seqs
                * (speculative_config.num_speculative_tokens + 1)
            )
        ):
            raise ValueError(
                "Consider increasing max_num_batched_tokens or "
                "decreasing num_speculative_tokens"
            )

1669
1670
1671
1672
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1673
1674
1675
1676
1677
1678
1679
1680
        # Attention config overrides
        attention_config = copy.deepcopy(self.attention_config)
        if self.attention_backend is not None:
            if attention_config.backend is not None:
                raise ValueError(
                    "attention_backend and attention_config.backend "
                    "are mutually exclusive"
                )
1681
1682
1683
1684
1685
1686
1687
            # Convert string to enum if needed (CLI parsing returns a string)
            if isinstance(self.attention_backend, str):
                attention_config.backend = AttentionBackendEnum[
                    self.attention_backend.upper()
                ]
            else:
                attention_config.backend = self.attention_backend
1688

1689
        load_config = self.create_load_config()
1690

1691
1692
        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
1693
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser
1694

1695
1696
1697
1698
1699
        if self.reasoning_parser_plugin:
            self.structured_outputs_config.reasoning_parser_plugin = (
                self.reasoning_parser_plugin
            )

1700
        observability_config = ObservabilityConfig(
1701
            show_hidden_metrics_for_version=self.show_hidden_metrics_for_version,
1702
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1703
            collect_detailed_traces=self.collect_detailed_traces,
1704
1705
            kv_cache_metrics=self.kv_cache_metrics,
            kv_cache_metrics_sample=self.kv_cache_metrics_sample,
1706
            cudagraph_metrics=self.cudagraph_metrics,
1707
            enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
1708
            enable_mfu_metrics=self.enable_mfu_metrics,
1709
            enable_mm_processor_stats=self.enable_mm_processor_stats,
1710
        )
1711

1712
        # Compilation config overrides
1713
        compilation_config = copy.deepcopy(self.compilation_config)
1714
        if self.cudagraph_capture_sizes is not None:
1715
            if compilation_config.cudagraph_capture_sizes is not None:
1716
1717
1718
1719
                raise ValueError(
                    "cudagraph_capture_sizes and compilation_config."
                    "cudagraph_capture_sizes are mutually exclusive"
                )
1720
            compilation_config.cudagraph_capture_sizes = self.cudagraph_capture_sizes
1721
        if self.max_cudagraph_capture_size is not None:
1722
            if compilation_config.max_cudagraph_capture_size is not None:
1723
1724
1725
1726
                raise ValueError(
                    "max_cudagraph_capture_size and compilation_config."
                    "max_cudagraph_capture_size are mutually exclusive"
                )
1727
            compilation_config.max_cudagraph_capture_size = (
1728
1729
                self.max_cudagraph_capture_size
            )
1730
        config = VllmConfig(
1731
1732
1733
1734
1735
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
1736
1737
            load_config=load_config,
            attention_config=attention_config,
1738
1739
            lora_config=lora_config,
            speculative_config=speculative_config,
1740
            structured_outputs_config=self.structured_outputs_config,
1741
            observability_config=observability_config,
1742
            compilation_config=compilation_config,
1743
            kv_transfer_config=self.kv_transfer_config,
1744
            kv_events_config=self.kv_events_config,
1745
            ec_transfer_config=self.ec_transfer_config,
1746
            profiler_config=self.profiler_config,
1747
            additional_config=self.additional_config,
1748
            optimization_level=self.optimization_level,
1749
        )
1750

1751
1752
        return config

1753
1754
    def _check_feature_supported(self, model_config: ModelConfig):
        """Raise an error if the feature is not supported."""
1755
        if self.logits_processor_pattern != EngineArgs.logits_processor_pattern:
1756
            _raise_unsupported_error(feature_name="--logits-processor-pattern")
1757
1758

        # No Concurrent Partial Prefills so far.
1759
1760
1761
1762
1763
        if (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        ):
1764
            _raise_unsupported_error(feature_name="Concurrent Partial Prefill")
1765

1766
        # N-gram, Medusa, and Eagle are supported for speculative decoding.
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
        if self.speculative_config is not None:
            # speculative_config could still be a dict at this point
            if isinstance(self.speculative_config, dict):
                method = self.speculative_config.get("method", None)
            else:
                method = self.speculative_config.method

            if method == "draft_model":
                raise NotImplementedError(
                    "Draft model speculative decoding is not supported yet. "
                    "Please consider using other speculative decoding methods "
1778
1779
                    "such as ngram, medusa, eagle, or mtp."
                )
1780

1781
        if self.pipeline_parallel_size > 1:
1782
1783
1784
            supports_pp = getattr(
                self.distributed_executor_backend, "supports_pp", False
            )
1785
            if not supports_pp and self.distributed_executor_backend not in (
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
                ParallelConfig.distributed_executor_backend,
                "ray",
                "mp",
                "external_launcher",
            ):
                name = (
                    "Pipeline Parallelism without Ray distributed "
                    "executor or multiprocessing executor or external "
                    "launcher"
                )
1796
                _raise_unsupported_error(feature_name=name)
1797

1798
1799
1800
1801
1802
1803
    @classmethod
    def get_batch_defaults(
        cls,
        world_size: int,
    ) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]:
        from vllm.usage.usage_lib import UsageContext
1804

1805
1806
        default_max_num_batched_tokens: dict[UsageContext | None, int]
        default_max_num_seqs: dict[UsageContext | None, int]
1807

1808
1809
        # When no user override, set the default values based on the usage
        # context.
1810
        # Use different default values for different hardware.
1811
1812
1813
1814
1815
1816
1817

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
1818
            device_memory = current_platform.get_device_total_memory()
1819
            device_name = current_platform.get_device_name().lower()
1820
1821
        except Exception:
            # This is only used to set default_max_num_batched_tokens
1822
            device_memory = 0
1823
            device_name = ""
1824

1825
1826
1827
1828
        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
1829
            # For GPUs like H100 and MI300x, use larger default values.
1830
1831
1832
1833
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
1834
1835
1836
1837
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 1024,
                UsageContext.OPENAI_API_SERVER: 1024,
            }
1838
1839
1840
1841
1842
1843
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
1844
1845
1846
1847
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256,
                UsageContext.OPENAI_API_SERVER: 256,
            }
1848

1849
1850
        # tpu specific default values.
        if current_platform.is_tpu():
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
            chip_name = current_platform.get_device_name()

            if chip_name == "V6E":
                default_max_num_batched_tokens = {
                    UsageContext.LLM_CLASS: 2048,
                    UsageContext.OPENAI_API_SERVER: 1024,
                }
            elif chip_name == "V5E":
                default_max_num_batched_tokens = {
                    UsageContext.LLM_CLASS: 1024,
                    UsageContext.OPENAI_API_SERVER: 512,
                }
            elif chip_name == "V5P":
                default_max_num_batched_tokens = {
                    UsageContext.LLM_CLASS: 512,
                    UsageContext.OPENAI_API_SERVER: 256,
                }
1868

1869
1870
1871
        # cpu specific default values.
        if current_platform.is_cpu():
            default_max_num_batched_tokens = {
1872
1873
                UsageContext.LLM_CLASS: 4096 * world_size,
                UsageContext.OPENAI_API_SERVER: 2048 * world_size,
1874
1875
            }
            default_max_num_seqs = {
1876
1877
                UsageContext.LLM_CLASS: 256 * world_size,
                UsageContext.OPENAI_API_SERVER: 128 * world_size,
1878
1879
            }

1880
1881
        return default_max_num_batched_tokens, default_max_num_seqs

1882
1883
    def _set_default_chunked_prefill_and_prefix_caching_args(
        self, model_config: ModelConfig
1884
    ) -> None:
1885
1886
        default_chunked_prefill = model_config.is_chunked_prefill_supported
        default_prefix_caching = model_config.is_prefix_caching_supported
1887
1888
1889
1890
1891
1892
1893
1894

        if self.enable_chunked_prefill is None:
            self.enable_chunked_prefill = default_chunked_prefill

            logger.debug(
                "%s chunked prefill by default",
                "Enabling" if default_chunked_prefill else "Disabling",
            )
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
        elif (
            model_config.runner_type == "generate"
            and not self.enable_chunked_prefill
            and default_chunked_prefill
        ):
            logger.warning_once(
                "This model does not officially support disabling chunked prefill. "
                "Disabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
                scope="local",
            )
1906
1907
1908
1909
        elif (
            model_config.runner_type == "pooling"
            and self.enable_chunked_prefill
            and not default_chunked_prefill
1910
        ):
1911
            logger.warning_once(
1912
1913
1914
                "This model does not officially support chunked prefill. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
1915
                scope="local",
1916
1917
1918
1919
1920
            )

        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = default_prefix_caching

1921
            logger.debug(
1922
1923
1924
1925
1926
1927
1928
1929
                "%s prefix caching by default",
                "Enabling" if default_prefix_caching else "Disabling",
            )
        elif (
            model_config.runner_type == "pooling"
            and self.enable_prefix_caching
            and not default_prefix_caching
        ):
1930
            logger.warning_once(
1931
1932
1933
                "This model does not officially support prefix caching. "
                "Enabling this manually may cause the engine to crash "
                "or produce incorrect outputs.",
1934
                scope="local",
1935
1936
            )

1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
        # Disable chunked prefill and prefix caching for:
        # POWER (ppc64le)/s390x/RISCV CPUs in V1
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.S390X,
            CpuArchEnum.RISCV,
        ):
            logger.info(
                "Chunked prefill is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_chunked_prefill = False
            logger.info(
                "Prefix caching is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_prefix_caching = False

    def _set_default_max_num_seqs_and_batched_tokens_args(
1958
1959
1960
        self,
        usage_context: UsageContext | None,
        model_config: ModelConfig,
1961
    ):
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
        world_size = self.pipeline_parallel_size * self.tensor_parallel_size
        (
            default_max_num_batched_tokens,
            default_max_num_seqs,
        ) = self.get_batch_defaults(world_size)

        orig_max_num_batched_tokens = self.max_num_batched_tokens
        orig_max_num_seqs = self.max_num_seqs

        if self.max_num_batched_tokens is None:
            self.max_num_batched_tokens = default_max_num_batched_tokens.get(
                usage_context,
                SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

        if self.max_num_seqs is None:
            self.max_num_seqs = default_max_num_seqs.get(
                usage_context,
                SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
            )

        if orig_max_num_batched_tokens is None:
            if not self.enable_chunked_prefill:
                # If max_model_len is too short, use the default for higher throughput.
                self.max_num_batched_tokens = max(
                    model_config.max_model_len,
                    self.max_num_batched_tokens,
                )

            # When using default settings,
            # Ensure max_num_batched_tokens does not exceed model limit.
            # Some models (e.g., Whisper) have embeddings tied to max length.
            self.max_num_batched_tokens = min(
                self.max_num_seqs * model_config.max_model_len,
1996
1997
                self.max_num_batched_tokens,
            )
1998

1999
2000
2001
2002
            logger.debug(
                "Defaulting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                usage_context.value if usage_context else None,
2003
            )
2004

2005
2006
2007
2008
        if orig_max_num_seqs is None:
            assert self.max_num_batched_tokens is not None  # For type checking
            self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

2009
            logger.debug(
2010
                "Defaulting max_num_seqs to %d for %s usage context.",
2011
                self.max_num_seqs,
2012
                usage_context.value if usage_context else None,
2013
            )
2014

2015

2016
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
2017
class AsyncEngineArgs(EngineArgs):
Woosuk Kwon's avatar
Woosuk Kwon committed
2018
    """Arguments for asynchronous vLLM engine."""
2019

2020
2021
    enable_log_requests: bool = False

2022
    @staticmethod
2023
2024
2025
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
2026
        # Initialize plugin to update the parser, for example, The plugin may
2027
        # add a new kind of quantization method to --quantization argument or
2028
2029
        # a new device to --device argument.
        load_general_plugins()
2030
2031
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
        parser.add_argument(
            "--enable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=AsyncEngineArgs.enable_log_requests,
            help="Enable logging requests.",
        )
        parser.add_argument(
            "--disable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=not AsyncEngineArgs.enable_log_requests,
            help="[DEPRECATED] Disable logging requests.",
            deprecated=True,
        )
2045
        current_platform.pre_register_and_update(parser)
2046
        return parser
2047
2048


2049
2050
2051
2052
2053
2054
def _raise_unsupported_error(feature_name: str):
    msg = (
        f"{feature_name} is not supported. We recommend to "
        f"remove {feature_name} from your config."
    )
    raise NotImplementedError(msg)
2055
2056


2057
def human_readable_int(value: str) -> int:
2058
2059
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.
2060

2061
2062
2063
2064
2065
2066
    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    """
    value = value.strip()
2067

2068
    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
2069
2070
    if match:
        decimal_multiplier = {
2071
2072
2073
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
2074
            "t": 10**12,
2075
2076
        }
        binary_multiplier = {
2077
2078
2079
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
2080
            "T": 2**40,
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            return int(float(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
2093
2094
2095
2096
2097
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?"
                ) from e
2098
2099
2100

    # Regular plain number.
    return int(value)
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119


def human_readable_int_or_auto(value: str) -> int:
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.
    Also accepts -1 or 'auto' as a special value for auto-detection.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '-1' or 'auto' -> -1 (special value for auto-detection)
    """
    value = value.strip()

    if value == "-1" or value.lower() == "auto":
        return -1

    return human_readable_int(value)