arg_utils.py 80.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import copy
6
import dataclasses
7
import functools
8
import json
9
import sys
10
from collections.abc import Callable
11
from dataclasses import MISSING, dataclass, fields, is_dataclass
12
from itertools import permutations
13
from types import UnionType
14
15
16
17
18
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Literal,
19
    TypeAlias,
20
21
22
23
24
25
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
26

27
import huggingface_hub
28
import regex as re
29
import torch
30
from pydantic import TypeAdapter, ValidationError
31
from pydantic.fields import FieldInfo
32
from typing_extensions import TypeIs, deprecated
33

34
import vllm.envs as envs
35
from vllm.attention.backends.registry import _Backend
36
37
38
39
40
41
42
43
44
45
46
from vllm.config import (
    CacheConfig,
    CompilationConfig,
    ConfigType,
    DeviceConfig,
    EPLBConfig,
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LoRAConfig,
    ModelConfig,
47
    MultiModalConfig,
48
49
50
51
52
53
54
55
56
    ObservabilityConfig,
    ParallelConfig,
    PoolerConfig,
    SchedulerConfig,
    SpeculativeConfig,
    StructuredOutputsConfig,
    VllmConfig,
    get_attr_docs,
)
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from vllm.config.cache import BlockSize, CacheDType, MambaDType, PrefixCachingHashAlgo
from vllm.config.device import Device
from vllm.config.model import (
    ConvertOption,
    HfOverrides,
    LogprobsMode,
    ModelDType,
    RunnerOption,
    TaskOption,
    TokenizerMode,
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.scheduler import SchedulerPolicy
72
from vllm.config.utils import get_field
73
from vllm.logger import init_logger
74
from vllm.platforms import CpuArchEnum, current_platform
75
from vllm.plugins import load_general_plugins
76
from vllm.ray.lazy_utils import is_ray_initialized
77
from vllm.reasoning import ReasoningParserManager
78
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
79
80
81
82
83
from vllm.transformers_utils.config import (
    get_model_path,
    is_interleaved,
    maybe_override_with_speculators,
)
84
from vllm.transformers_utils.utils import check_gguf_file
85
from vllm.utils import FlexibleArgumentParser, is_in_ray_actor
86
from vllm.utils.mem_constants import GiB_bytes
87
from vllm.utils.network_utils import get_ip
88
from vllm.v1.sample.logits_processor import LogitsProcessor
89

90
91
if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
92
    from vllm.model_executor.model_loader import LoadFormats
93
    from vllm.usage.usage_lib import UsageContext
94
    from vllm.v1.executor import Executor
95
else:
96
    Executor = Any
97
    QuantizationMethods = Any
98
    LoadFormats = Any
99
100
    UsageContext = Any

101
102
logger = init_logger(__name__)

# Generic parameter shared by the parsing helpers below.
T = TypeVar("T")

# `object` is part of both aliases so that special typing forms (which are not
# `type` instances, e.g. `Literal[...]`) are still accepted as type hints.
TypeHint: TypeAlias = type[Any] | object
TypeHintT: TypeAlias = type[T] | object
107

108

109
110
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap `return_type` so conversion failures become argparse errors.

    The returned callable applies `return_type` to its string argument. Any
    `ValueError` it raises is re-raised as `argparse.ArgumentTypeError`, which
    argparse reports as a usage error.
    """

    def _parse_type(val: str) -> T:
        try:
            converted = return_type(val)
        except ValueError as exc:
            message = f"Value {val} cannot be converted to {return_type}."
            raise argparse.ArgumentTypeError(message) from exc
        return converted

    return _parse_type


121
122
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], T | None]:
    """Like `parse_type`, but map the strings `""` and `"None"` to `None`."""

    def _optional_type(val: str) -> T | None:
        is_null = val in ("", "None")
        return None if is_null else parse_type(return_type)(val)

    return _optional_type
128
129


130
def union_dict_and_str(val: str) -> str | dict[str, str] | None:
    """Decode `val` as JSON if it looks like a JSON object, else return a str.

    A value that is wrapped in `{...}` (ignoring surrounding whitespace) is
    decoded with `json.loads`. Invalid JSON raises
    `argparse.ArgumentTypeError`. Any other value is returned as a plain
    string.
    """
    looks_like_json_object = re.match(r"(?s)^\s*{.*}\s*$", val)
    if looks_like_json_object:
        return optional_type(json.loads)(val)
    return str(val)
134
135


136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return whether `type_hint` is `type` itself or a parametrisation of it.

    For example, both `list` and `list[int]` match `list`.
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return whether any hint in `type_hints` matches `type` (see `is_type`)."""
    for type_hint in type_hints:
        if is_type(type_hint, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT | None:
    """Return a hint from `type_hints` that matches `type` (see `is_type`).

    Returns `None` when no hint matches. The return annotation now says so
    explicitly, so type checkers flag callers that forget this case.
    """
    return next((th for th in type_hints if is_type(th, type)), None)


151
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Build argparse `type` and `choices`/`metavar` from a `Literal` hint.

    The `Literal` member of `type_hints` supplies the allowed values. They
    must all share one Python type, which becomes the argparse `type`. If
    `type_hints` also contains `str`, the sorted values are passed as
    `metavar` (shown in help only) instead of `choices` (enforced).

    Raises:
        ValueError: If the literal values are not all of the same type.
    """
    options = get_args(get_type(type_hints, Literal))
    option_type = type(options[0])
    mismatched = any(not isinstance(option, option_type) for option in options)
    if mismatched:
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}"
        )
    key = "choices"
    if contains_type(type_hints, str):
        key = "metavar"
    return {"type": option_type, key: sorted(options)}
166
167


168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
def collection_to_kwargs(type_hints: set[TypeHint], type: TypeHint) -> dict[str, Any]:
    """Build argparse `type`/`nargs` kwargs for a `list`, `set` or `tuple` hint.

    All element types must be identical, ignoring the `...` of a variadic
    tuple. If the element type is a union it must include `str`, and values
    are then parsed as plain strings. `nargs` is `"+"` for lists, sets and
    variadic tuples, and the exact arity for fixed-length tuples.
    """
    member_types = get_args(get_type(type_hints, type))
    elem_type = member_types[0]

    # Handle Ellipsis: `tuple[X, ...]` lists `...` next to the element type.
    non_ellipsis = (t for t in member_types if t is not Ellipsis)
    assert all(t is elem_type for t in non_ellipsis), (
        f"All non-Ellipsis elements must be of the same type. Got {member_types}."
    )

    # Union for Union[X, Y] and UnionType for X | Y
    if get_origin(elem_type) in {Union, UnionType}:
        assert str in get_args(elem_type), (
            "If element can have multiple types, one must be 'str' "
            f"(i.e. 'list[int | str]'). Got {elem_type}."
        )
        elem_type = str

    fixed_length_tuple = type is tuple and Ellipsis not in member_types
    nargs = len(member_types) if fixed_length_tuple else "+"
    return {"type": elem_type, "nargs": nargs}


193
194
195
196
197
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return whether `type_hint` is defined outside the `builtins` module."""
    module_name = type_hint.__module__
    return module_name != "builtins"


198
199
200
201
202
203
204
205
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Flatten `type_hint` into the set of its constituent types.

    `Annotated[X, ...]` contributes the hints of `X`. `Union[X, Y]` and
    `X | Y` contribute the hints of each member, recursively. Any other hint
    is returned as a single-element set.
    """
    origin = get_origin(type_hint)
    args = get_args(type_hint)
    if origin is Annotated:
        return get_type_hints(args[0])
    # Union for Union[X, Y] and UnionType for X | Y
    if origin in {Union, UnionType}:
        collected: set[TypeHint] = set()
        for member in args:
            collected |= get_type_hints(member)
        return collected
    return {type_hint}


216
217
218
219
def is_online_quantization(quantization: Any) -> bool:
    """Return whether `quantization` is an online quantization method.

    Currently the only such method is `"inc"`.
    """
    online_methods = ("inc",)
    return quantization in online_methods


220
# True when help text will be rendered: some CLI argument contains `--help`,
# or the process was launched through mkdocs. Attribute docs are only
# collected in that case, to save time otherwise.
NEEDS_HELP = (
    # vllm SUBCOMMAND --help
    any("--help" in arg for arg in sys.argv)
    # mkdocs SUBCOMMAND, or python -m mkdocs SUBCOMMAND
    or (argv0 := sys.argv[0]).endswith(("mkdocs", "mkdocs/__main__.py"))
)


227
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Compute `argparse.add_argument` kwargs for every field of `cls`.

    Results are memoised by `functools.lru_cache`, so the returned dict is
    shared between calls and must not be mutated. Use `get_kwargs`, which
    returns a deep copy.

    Args:
        cls: A config dataclass. Each of its fields yields one entry.

    Returns:
        A mapping from field name to the kwargs for that field's argument.
        Each entry always has `default` and `help`, plus `type`, `action`,
        `choices`, `metavar` or `nargs` as dictated by the field's type hints.

    Raises:
        ValueError: If a field's type hints are not supported.
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}

    # Loop-invariant values, hoisted out of the per-field loop.
    json_tip = (
        "Should either be a valid JSON string or JSON keys passed individually."
    )
    # Integer fields that accept human-readable values (see `human_readable_int`)
    human_readable_ints = {
        "max_model_len",
        "max_num_batched_tokens",
        "kv_cache_memory_bytes",
    }

    kwargs: dict[str, dict[str, Any]] = {}
    for field in fields(cls):
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, CLI values are parsed as JSON with
        # pydantic's TypeAdapter (see `parse_dataclass` below).
        dataclass_cls = next((th for th in type_hints if is_dataclass(th)), None)

        # Get the default value of the field. `default` is reset for every
        # field. Previously, a field with neither `default` nor
        # `default_factory` reused the previous field's default, or raised
        # NameError if it was the first field. `None` is also what argparse
        # uses when no default is given.
        default: Any = None
        if field.default is not MISSING:
            default = field.default
            # Handle pydantic.Field defaults
            if isinstance(default, FieldInfo):
                default = (
                    default.default
                    if default.default_factory is None
                    else default.default_factory()
                )
        elif field.default_factory is not MISSING:
            default = field.default_factory()

        # Get the help text for the field, escaping % for argparse
        name = field.name
        help_text = cls_docs.get(name, "").strip().replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        field_kwargs: dict[str, Any] = {"default": default, "help": help_text}
        kwargs[name] = field_kwargs

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:

            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            field_kwargs["type"] = parse_dataclass
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            field_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            field_kwargs.update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            field_kwargs.update(collection_to_kwargs(type_hints, tuple))
        elif contains_type(type_hints, list):
            field_kwargs.update(collection_to_kwargs(type_hints, list))
        elif contains_type(type_hints, set):
            field_kwargs.update(collection_to_kwargs(type_hints, set))
        elif contains_type(type_hints, int):
            field_kwargs["type"] = int
            # Special case for large integers
            if name in human_readable_ints:
                field_kwargs["type"] = human_readable_int
                field_kwargs["help"] += f"\n\n{human_readable_int.__doc__}"
        elif contains_type(type_hints, float):
            field_kwargs["type"] = float
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
            field_kwargs["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            field_kwargs["type"] = parse_type(json.loads)
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
            field_kwargs["type"] = str
        else:
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(field_kwargs.get("type")) is Literal:
            field_kwargs.update(literal_to_kwargs({field_kwargs["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            field_kwargs["type"] = optional_type(field_kwargs["type"])
            if field_kwargs.get("choices"):
                # NOTE(review): argparse checks `choices` membership after
                # `type` conversion. `optional_type` turns "None" into the
                # `None` object, which may not match the string "None" here.
                # Confirm whether passing `None` on the CLI is accepted.
                field_kwargs["choices"].append("None")
    return kwargs
327
328


329
def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
    """Return argparse kwargs for the given Config dataclass.

    Attribute documentation is only included in the help text when `--help`
    or `mkdocs` appears on the command line.

    The expensive work is memoised in `_compute_kwargs`. This wrapper returns
    a deep copy, so callers may mutate the result without corrupting the
    cached version.
    """
    cached = _compute_kwargs(cls)
    return copy.deepcopy(cached)


342
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
343
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
344
    """Arguments for vLLM engine."""
345

346
    model: str = ModelConfig.model
347
348
349
    served_model_name: str | list[str] | None = ModelConfig.served_model_name
    tokenizer: str | None = ModelConfig.tokenizer
    hf_config_path: str | None = ModelConfig.hf_config_path
350
351
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
352
    task: TaskOption | None = ModelConfig.task
353
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
354
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
355
356
357
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
358
359
    allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains
    download_dir: str | None = LoadConfig.download_dir
360
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
361
    load_format: str | LoadFormats = LoadConfig.load_format
362
363
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
364
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
365
366
    seed: int | None = ModelConfig.seed
    max_model_len: int | None = ModelConfig.max_model_len
367
    cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes")
368
369
370
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
371
    distributed_executor_backend: (
372
        str | DistributedExecutorBackend | type[Executor] | None
373
    ) = ParallelConfig.distributed_executor_backend
374
    # number of P/D disaggregation (or other disaggregation) workers
375
376
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
377
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
378
    data_parallel_size: int = ParallelConfig.data_parallel_size
379
380
381
382
383
    data_parallel_rank: int | None = None
    data_parallel_start_rank: int | None = None
    data_parallel_size_local: int | None = None
    data_parallel_address: str | None = None
    data_parallel_rpc_port: int | None = None
384
    data_parallel_hybrid_lb: bool = False
Rui Qiao's avatar
Rui Qiao committed
385
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
386
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
387
    all2all_backend: str | None = ParallelConfig.all2all_backend
388
    enable_dbo: bool = ParallelConfig.enable_dbo
389
390
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
391
392
393
    disable_nccl_for_dp_synchronization: bool = (
        ParallelConfig.disable_nccl_for_dp_synchronization
    )
394
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
395
    enable_eplb: bool = ParallelConfig.enable_eplb
396
    expert_placement_strategy: ExpertPlacementStrategy = (
397
        ParallelConfig.expert_placement_strategy
398
    )
399
400
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
401
402
403
404
    num_redundant_experts: int = EPLBConfig.num_redundant_experts
    eplb_window_size: int = EPLBConfig.window_size
    eplb_step_interval: int = EPLBConfig.step_interval
    eplb_log_balancedness: bool = EPLBConfig.log_balancedness
405
    max_parallel_loading_workers: int | None = (
406
407
        ParallelConfig.max_parallel_loading_workers
    )
408
409
    block_size: BlockSize | None = CacheConfig.block_size
    enable_prefix_caching: bool | None = CacheConfig.enable_prefix_caching
410
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
411
        CacheConfig.prefix_caching_hash_algo
412
    )
413
414
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
415
416
417
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
418
419
    kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
    max_num_batched_tokens: int | None = SchedulerConfig.max_num_batched_tokens
420
421
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
422
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
423
    max_num_seqs: int | None = SchedulerConfig.max_num_seqs
424
    max_logprobs: int = ModelConfig.max_logprobs
425
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
426
    disable_log_stats: bool = False
427
    aggregate_engine_logging: bool = False
428
429
    revision: str | None = ModelConfig.revision
    code_revision: str | None = ModelConfig.code_revision
430
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
431
432
    rope_theta: float | None = ModelConfig.rope_theta
    hf_token: bool | str | None = ModelConfig.hf_token
433
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
434
435
    tokenizer_revision: str | None = ModelConfig.tokenizer_revision
    quantization: QuantizationMethods | None = ModelConfig.quantization
436
    enforce_eager: bool = ModelConfig.enforce_eager
437
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
438
    limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
439
440
        MultiModalConfig, "limit_per_prompt"
    )
441
    enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds
442
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
443
444
445
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
446
    mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
447
    disable_mm_preprocessor_cache: bool = False  # DEPRECATED
448
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
449
    mm_processor_cache_type: MMCacheType | None = (
450
        MultiModalConfig.mm_processor_cache_type
451
452
    )
    mm_shm_cache_max_object_size_mb: int = (
453
        MultiModalConfig.mm_shm_cache_max_object_size_mb
454
    )
455
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
456
457
458
    mm_encoder_attn_backend: _Backend | str | None = (
        MultiModalConfig.mm_encoder_attn_backend
    )
459
    io_processor_plugin: str | None = None
460
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
461
    video_pruning_rate: float = MultiModalConfig.video_pruning_rate
462
    # LoRA fields
463
    enable_lora: bool = False
464
465
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
466
    default_mm_loras: dict[str, str] | None = LoRAConfig.default_mm_loras
467
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
468
469
    max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
    lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
470
471
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size

472
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
473
    num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
474
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
475
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
476
    ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
477

478
    enable_chunked_prefill: bool | None = SchedulerConfig.enable_chunked_prefill
479
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
480

481
    disable_hybrid_kv_cache_manager: bool = (
482
483
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
484

485
    structured_outputs_config: StructuredOutputsConfig = get_field(
486
487
        VllmConfig, "structured_outputs_config"
    )
488
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
489

490
    # Deprecated guided decoding fields
491
492
493
494
    guided_decoding_backend: str | None = None
    guided_decoding_disable_fallback: bool | None = None
    guided_decoding_disable_any_whitespace: bool | None = None
    guided_decoding_disable_additional_properties: bool | None = None
495

496
    logits_processor_pattern: str | None = ModelConfig.logits_processor_pattern
497

498
    speculative_config: dict[str, Any] | None = None
499

500
    show_hidden_metrics_for_version: str | None = (
501
        ObservabilityConfig.show_hidden_metrics_for_version
502
    )
503
504
    otlp_traces_endpoint: str | None = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: list[DetailedTraceModules] | None = (
505
        ObservabilityConfig.collect_detailed_traces
506
    )
507
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
508
    scheduler_cls: str | type[object] = SchedulerConfig.scheduler_cls
509

510
511
    pooler_config: PoolerConfig | None = ModelConfig.pooler_config
    override_pooler_config: dict | PoolerConfig | None = (
512
        ModelConfig.override_pooler_config
513
514
    )
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
515
516
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
517

518
519
    kv_transfer_config: KVTransferConfig | None = None
    kv_events_config: KVEventsConfig | None = None
520

521
522
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
523
524
525
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
526
    model_impl: str = ModelConfig.model_impl
527
    override_attention_dtype: str = ModelConfig.override_attention_dtype
528

529
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
530
531
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
532

533
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
534

535
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
536
    pt_load_map_location: str = LoadConfig.pt_load_map_location
537

538
539
    # DEPRECATED
    enable_multimodal_encoder_data_parallel: bool = False
540

541
    logits_processors: list[str | type[LogitsProcessor]] | None = (
542
543
        ModelConfig.logits_processors
    )
544
545
    """Custom logitproc types"""

546
547
    async_scheduling: bool = SchedulerConfig.async_scheduling

548
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
549

550
    def __post_init__(self):
        # Accept plain dicts for these fields, e.g.
        # `EngineArgs(compilation_config={...})`, so callers need not build
        # the config objects themselves.
        for attr_name, config_type in (
            ("compilation_config", CompilationConfig),
            ("eplb_config", EPLBConfig),
        ):
            raw_value = getattr(self, attr_name)
            if isinstance(raw_value, dict):
                setattr(self, attr_name, config_type(**raw_value))

        # Setup plugins
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        # In HF offline mode, replace the model id with its local model path.
        if not huggingface_hub.constants.HF_HUB_OFFLINE:
            return
        original_model_id = self.model
        self.model = get_model_path(self.model, self.revision)
        logger.info(
            "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
            original_model_id,
            self.model,
        )
571
572

    @staticmethod
573
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
574
        """Shared CLI arguments for vLLM engine."""
575

576
        # Model arguments
577
578
579
580
581
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
582
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
583
            model_group.add_argument("--model", **model_kwargs["model"])
584
585
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
586
        model_group.add_argument("--task", **model_kwargs["task"], deprecated=True)
587
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
588
589
590
591
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
592
593
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
594
595
596
597
598
599
600
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
601
        model_group.add_argument("--revision", **model_kwargs["revision"])
602
603
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling", **model_kwargs["rope_scaling"])
604
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
629
630
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--override-pooler-config",
            **model_kwargs["override_pooler_config"],
            deprecated=True,
        )
        model_group.add_argument(
            "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
        )
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
658
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
659
660
661
662
663
664
665
666
667
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
        )
668

669
670
671
672
673
674
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
675
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
676
677
678
679
680
681
682
683
684
685
686
687
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )
688

689
690
691
692
693
        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
694
        )
695
        structured_outputs_group.add_argument(
696
            "--reasoning-parser",
697
            # This choice is a special case because it's not static
698
            choices=list(ReasoningParserManager.reasoning_parsers),
699
700
            **structured_outputs_kwargs["reasoning_parser"],
        )
701
702
703
704
705
706
707
708
709
710
711
        # Deprecated guided decoding arguments
        for arg, type in [
            ("--guided-decoding-backend", str),
            ("--guided-decoding-disable-fallback", bool),
            ("--guided-decoding-disable-any-whitespace", bool),
            ("--guided-decoding-disable-additional-properties", bool),
        ]:
            structured_outputs_group.add_argument(
                arg,
                type=type,
                help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
712
713
                deprecated=True,
            )
714

715
        # Parallel arguments
716
717
718
719
720
721
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
722
            "--distributed-executor-backend",
723
724
            **parallel_kwargs["distributed_executor_backend"],
        )
725
        parallel_group.add_argument(
726
727
728
729
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
730
        parallel_group.add_argument(
731
732
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
733
        parallel_group.add_argument(
734
735
736
737
738
739
740
741
742
743
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
        parallel_group.add_argument(
            "--data-parallel-rank",
            "-dpn",
744
            type=int,
745
746
747
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
748
        parallel_group.add_argument(
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
            "--data-parallel-start-rank",
            "-dpr",
            type=int,
            help="Starting data parallel rank for secondary nodes.",
        )
        parallel_group.add_argument(
            "--data-parallel-size-local",
            "-dpl",
            type=int,
            help="Number of data parallel replicas to run on this node.",
        )
        parallel_group.add_argument(
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
            type=int,
            help="Port for data parallel RPC communication.",
        )
        parallel_group.add_argument(
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
779
        parallel_group.add_argument(
780
781
782
783
784
            "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"]
        )
        parallel_group.add_argument(
            "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
        )
785
786
787
        parallel_group.add_argument(
            "--all2all-backend", **parallel_kwargs["all2all_backend"]
        )
788
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
789
790
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
791
792
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
793
794
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
795
796
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
797
798
799
800
        parallel_group.add_argument(
            "--disable-nccl-for-dp-synchronization",
            **parallel_kwargs["disable_nccl_for_dp_synchronization"],
        )
801
802
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
803
804
        parallel_group.add_argument(
            "--expert-placement-strategy",
805
806
            **parallel_kwargs["expert_placement_strategy"],
        )
807
808
809
        parallel_group.add_argument(
            "--num-redundant-experts",
            type=int,
810
811
812
            help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
            deprecated=True,
        )
813
814
815
816
        parallel_group.add_argument(
            "--eplb-window-size",
            type=int,
            help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
817
818
            deprecated=True,
        )
819
820
821
        parallel_group.add_argument(
            "--eplb-step-interval",
            type=int,
822
823
824
            help="[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
            deprecated=True,
        )
825
826
827
        parallel_group.add_argument(
            "--eplb-log-balancedness",
            action=argparse.BooleanOptionalAction,
828
829
830
            help="[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
            deprecated=True,
        )
831

832
        parallel_group.add_argument(
833
            "--max-parallel-loading-workers",
834
835
            **parallel_kwargs["max_parallel_loading_workers"],
        )
836
        parallel_group.add_argument(
837
838
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
839
        parallel_group.add_argument(
840
            "--disable-custom-all-reduce",
841
842
843
844
845
846
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )
847
848
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
849
            action="store_true",
850
851
            deprecated=True,
        )
852

853
854
855
856
857
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
858
        )
859
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
860
861
862
863
864
865
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
866
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        cache_group.add_argument(
            "--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"]
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
890

891
        # Multimodal related configs
892
893
894
895
896
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
897
        multimodal_group.add_argument(
898
899
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
900
901
902
        multimodal_group.add_argument(
            "--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"]
        )
903
904
905
906
907
908
909
910
911
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
912
        multimodal_group.add_argument(
913
914
            "--disable-mm-preprocessor-cache", action="store_true", deprecated=True
        )
915
        multimodal_group.add_argument(
916
917
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
918
919
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
920
921
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
922
        multimodal_group.add_argument(
923
924
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
925
926
927
928
        multimodal_group.add_argument(
            "--mm-encoder-attn-backend",
            **multimodal_kwargs["mm_encoder_attn_backend"],
        )
929
930
931
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
932
        multimodal_group.add_argument(
933
934
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
935

936
        multimodal_group.add_argument(
937
938
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )
939

940
        # LoRA related configs
941
942
943
944
945
946
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
947
            "--enable-lora",
948
            action=argparse.BooleanOptionalAction,
949
950
            help="If True, enable handling of LoRA adapters.",
        )
951
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
952
953
954
955
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
        lora_group.add_argument(
            "--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"]
        )
956
        lora_group.add_argument(
957
            "--lora-dtype",
958
959
            **lora_kwargs["lora_dtype"],
        )
960
961
962
963
964
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
965

966
967
968
969
970
971
972
973
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
974
975
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
976
        observability_group.add_argument(
977
978
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
979
980
981
982
983
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
984
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
985
986
987
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
988
989
            **observability_kwargs["collect_detailed_traces"],
        )
990

991
992
993
994
995
996
997
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
998
999
            "--max-num-batched-tokens", **scheduler_kwargs["max_num_batched_tokens"]
        )
1000
        scheduler_group.add_argument(
1001
1002
1003
1004
1005
            "--max-num-seqs", **scheduler_kwargs["max_num_seqs"]
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
1006
1007
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
1008
1009
1010
1011
1012
            **scheduler_kwargs["max_long_partial_prefills"],
        )
        scheduler_group.add_argument(
            "--cuda-graph-sizes", **scheduler_kwargs["cuda_graph_sizes"]
        )
1013
1014
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
1015
1016
1017
1018
1019
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
        scheduler_group.add_argument(
            "--num-lookahead-slots", **scheduler_kwargs["num_lookahead_slots"]
        )
1020
1021
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
1022
        scheduler_group.add_argument(
1023
1024
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
1025
        scheduler_group.add_argument(
1026
1027
1028
1029
1030
1031
1032
1033
            "--enable-chunked-prefill", **scheduler_kwargs["enable_chunked_prefill"]
        )
        scheduler_group.add_argument(
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
1034
1035
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
1036
1037
1038
1039
1040
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
1041
1042

        # vLLM arguments
1043
        vllm_kwargs = get_kwargs(VllmConfig)
1044
1045
1046
1047
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
1048
1049
1050
1051
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
        vllm_group.add_argument(
            "--speculative-config", **vllm_kwargs["speculative_config"]
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument(
            "--compilation-config", "-O", **vllm_kwargs["compilation_config"]
        )
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
1068

1069
        # Other arguments
1070
1071
1072
1073
1074
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
1075

1076
1077
1078
1079
1080
1081
        parser.add_argument(
            "--aggregate-engine-logging",
            action="store_true",
            help="Log aggregate rather than per-engine statistics "
            "when using data parallelism.",
        )
1082
        return parser
1083
1084

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from a parsed argparse namespace.

        Each dataclass field of ``cls`` that is also present as an attribute
        on ``args`` is forwarded to the constructor under the same name.
        Namespace attributes that are not fields of ``cls`` are ignored, and
        fields absent from ``args`` are left to the constructor's defaults.
        """
        field_names = (f.name for f in dataclasses.fields(cls))
        init_kwargs = {}
        for name in field_names:
            if hasattr(args, name):
                init_kwargs[name] = getattr(args, name)
        return cls(**init_kwargs)
1093

1094
    def create_model_config(self) -> ModelConfig:
        """Build the ModelConfig from these engine arguments.

        Before constructing the config, a few fields on ``self`` are
        normalized in place:

        * If ``self.model`` is a GGUF file, both ``quantization`` and
          ``load_format`` are forced to ``"gguf"``.
        * In CI (``envs.VLLM_CI_USE_S3``), for non-``AsyncEngineArgs``
          instances whose model is listed in ``MODELS_ON_S3`` and whose
          ``load_format`` is ``"auto"``, ``self.model`` is rewritten to point
          at ``MODEL_WEIGHTS_S3_BUCKET``.
        * Deprecated multimodal options are translated into their
          replacements with a deprecation warning:
          ``disable_mm_preprocessor_cache`` -> ``mm_processor_cache_gb = 0``,
          ``VLLM_MM_INPUT_CACHE_GIB`` -> ``mm_processor_cache_gb``, and
          ``enable_multimodal_encoder_data_parallel`` ->
          ``mm_encoder_tp_mode = "data"``.

        Returns:
            A new ModelConfig populated from the (possibly updated) fields.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (
            not isinstance(self, AsyncEngineArgs)
            and envs.VLLM_CI_USE_S3
            and self.model in MODELS_ON_S3
            and self.load_format == "auto"
        ):
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"

        if self.disable_mm_preprocessor_cache:
            logger.warning(
                "`--disable-mm-preprocessor-cache` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb 0` instead.",
            )

            self.mm_processor_cache_gb = 0
        elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
            # NOTE(review): 4 presumably is the env var's default, so any
            # other value means the user set it explicitly — confirm in
            # vllm.envs.
            logger.warning(
                "`VLLM_MM_INPUT_CACHE_GIB` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb %d` instead.",
                envs.VLLM_MM_INPUT_CACHE_GIB,
            )

            self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB

        if self.enable_multimodal_encoder_data_parallel:
            logger.warning(
                "`--enable-multimodal-encoder-data-parallel` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-encoder-tp-mode data` instead."
            )

            self.mm_encoder_tp_mode = "data"

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            task=self.task,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            enable_mm_embeds=self.enable_mm_embeds,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
            pooler_config=self.pooler_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            video_pruning_rate=self.video_pruning_rate,
            io_processor_plugin=self.io_processor_plugin,
        )
1189

1190
    def validate_tensorizer_args(self):
        """Mirror recognized tensorizer options into the nested sub-dict.

        Every top-level key of ``self.model_loader_extra_config`` that appears
        in ``TensorizerConfig._fields`` is copied, with its value, into
        ``self.model_loader_extra_config["tensorizer_config"]``. That nested
        dict must already exist if any key matches; ``create_load_config``
        creates it before calling this method.
        """
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

        extra_config = self.model_loader_extra_config
        matching_keys = [k for k in extra_config if k in TensorizerConfig._fields]
        for k in matching_keys:
            extra_config["tensorizer_config"][k] = extra_config[k]
1198

1199
    def create_load_config(self) -> LoadConfig:
        """Build the LoadConfig describing how model weights are loaded.

        Side effects on ``self``:

        * ``quantization == "bitsandbytes"`` forces
          ``load_format = "bitsandbytes"``.
        * For ``load_format == "tensorizer"``, ``model_loader_extra_config``
          is first converted via ``to_serializable()`` when it provides that
          method. A fresh ``"tensorizer_config"`` entry holding
          ``{"tensorizer_dir": self.model}`` is then set on it, and
          ``validate_tensorizer_args`` copies matching keys into that entry.

        The load device is ``"cpu"`` when ``is_online_quantization`` reports
        the quantization method as online, otherwise ``None``.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            extra = self.model_loader_extra_config
            if hasattr(extra, "to_serializable"):
                extra = extra.to_serializable()
                self.model_loader_extra_config = extra
            extra["tensorizer_config"] = {"tensorizer_dir": self.model}
            self.validate_tensorizer_args()

        load_device = "cpu" if is_online_quantization(self.quantization) else None
        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            device=load_device,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )
1224

1225
1226
1227
1228
1229
1230
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> SpeculativeConfig | None:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        `speculative_config` is a dict of SpeculativeConfig arguments. It is
        either parsed from the JSON string given to `--speculative-config` on
        the CLI, or passed directly as a dictionary from the engine. The
        target-side parameters below are not part of that user-supplied dict.
        They are merged in here before SpeculativeConfig is constructed.

        Args:
            target_model_config: ModelConfig of the target model.
            target_parallel_config: ParallelConfig of the target model.
            enable_chunked_prefill: Whether chunked prefill is enabled.
            disable_log_stats: Whether stats logging is disabled.

        Returns:
            A SpeculativeConfig, or None when `speculative_config` is unset.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # Merge into a fresh dict instead of `.update()`-ing
        # `self.speculative_config` in place, so a dict supplied by the caller
        # is not polluted with engine-internal config objects.
        spec_kwargs = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        }
        return SpeculativeConfig(**spec_kwargs)
1255

1256
1257
    def create_engine_config(
        self,
1258
        usage_context: UsageContext | None = None,
1259
        headless: bool = False,
1260
1261
1262
1263
1264
1265
1266
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1267

1268
1269
1270
1271
1272
1273
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1274
        current_platform.pre_register_and_update()
1275

1276
        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
1277

1278
1279
1280
1281
        model_config = self.create_model_config()
        self.model = model_config.model
        self.tokenizer = model_config.tokenizer

1282
1283
1284
1285
1286
1287
1288
1289
1290
        (self.model, self.tokenizer, self.speculative_config) = (
            maybe_override_with_speculators(
                model=self.model,
                tokenizer=self.tokenizer,
                revision=self.revision,
                trust_remote_code=self.trust_remote_code,
                vllm_speculative_config=self.speculative_config,
            )
        )
1291

1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

1309
1310
        # Set default arguments for V1 Engine.
        self._set_default_args(usage_context, model_config)
1311
1312
        # Disable chunked prefill and prefix caching for:
        # POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.S390X,
            CpuArchEnum.ARM,
            CpuArchEnum.RISCV,
        ):
            logger.info(
                "Chunked prefill is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
1324
            self.enable_chunked_prefill = False
1325
1326
1327
1328
1329
1330
1331
            logger.info(
                "Prefix caching is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_prefix_caching = False

1332
1333
        assert self.enable_chunked_prefill is not None

1334
        sliding_window: int | None = None
1335
1336
1337
1338
1339
1340
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1341
1342
1343
        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
1344
        # reuses the GPUs of TP group, and split one TP group into
1345
        # tp_size//dcp_size DCP groups.
1346
        assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, (
1347
1348
1349
1350
            f"tp_size={self.tensor_parallel_size} must be divisible by"
            f"dcp_size={self.decode_context_parallel_size}."
        )

1351
        cache_config = CacheConfig(
1352
            block_size=self.block_size,
1353
            gpu_memory_utilization=self.gpu_memory_utilization,
1354
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1355
1356
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1357
            is_attention_free=model_config.is_attention_free,
1358
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1359
            sliding_window=sliding_window,
1360
            enable_prefix_caching=self.enable_prefix_caching,
1361
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1362
            cpu_offload_gb=self.cpu_offload_gb,
1363
            calculate_kv_scales=self.calculate_kv_scales,
1364
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1365
1366
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1367
        )
1368

1369
1370
1371
1372
1373
1374
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
1375

1376
            ray_runtime_env = ray.get_runtime_context().runtime_env
1377
1378
1379
1380
1381
1382
1383
            # Avoid logging sensitive environment variables
            sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {}
            if "env_vars" in sanitized_env:
                sanitized_env["env_vars"] = {
                    k: "***" for k in sanitized_env["env_vars"]
                }
            logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env)
1384

1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1396
        assert not headless or not self.data_parallel_hybrid_lb, (
1397
1398
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
1399

1400
        data_parallel_external_lb = self.data_parallel_rank is not None
1401
        # Local DP rank = 1, use pure-external LB.
1402
1403
        if data_parallel_external_lb:
            assert self.data_parallel_size_local in (1, None), (
1404
1405
                "data_parallel_size_local must be 1 when data_parallel_rank is set"
            )
1406
            data_parallel_size_local = 1
1407
1408
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1409
1410
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

            self.data_parallel_rank = self.data_parallel_start_rank or 0
1426
        else:
1427
            assert not self.data_parallel_hybrid_lb, (
1428
1429
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )
1430

1431
1432
1433
1434
1435
1436
1437
1438
1439
            if self.data_parallel_backend == "ray" and (
                envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
            ):
                # Data parallel size defaults to 1 if DP ranks are spanning
                # multiple nodes
                data_parallel_size_local = 1
            else:
                # Otherwise local DP size defaults to global DP size if not set
                data_parallel_size_local = self.data_parallel_size
1440
1441
1442

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1443
1444
1445
1446
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
1447
1448
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1449
1450
1451
1452
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
1453
1454
                    self.data_parallel_backend,
                )
Rui Qiao's avatar
Rui Qiao committed
1455
1456
1457
                data_parallel_address = ParallelConfig.data_parallel_master_ip
        else:
            data_parallel_address = self.data_parallel_address
1458
1459
1460

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
1461
        data_parallel_rpc_port = (
1462
            self.data_parallel_rpc_port
1463
1464
1465
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )
1466

1467
1468
        if self.async_scheduling:
            if self.pipeline_parallel_size > 1:
1469
1470
1471
                raise ValueError(
                    "Async scheduling is not supported with pipeline-parallel-size > 1."
                )
1472
1473
1474
1475
1476
1477

            # Currently, async scheduling does not support speculative decoding.
            # TODO(woosuk): Support it.
            if self.speculative_config is not None:
                raise ValueError(
                    "Currently, speculative decoding is not supported with "
1478
1479
                    "async scheduling."
                )
1480

1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
        # Forward the deprecated CLI args to the EPLB config.
        if self.num_redundant_experts is not None:
            self.eplb_config.num_redundant_experts = self.num_redundant_experts
        if self.eplb_window_size is not None:
            self.eplb_config.window_size = self.eplb_window_size
        if self.eplb_step_interval is not None:
            self.eplb_config.step_interval = self.eplb_step_interval
        if self.eplb_log_balancedness is not None:
            self.eplb_config.log_balancedness = self.eplb_log_balancedness

1491
        parallel_config = ParallelConfig(
1492
1493
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1494
            data_parallel_size=self.data_parallel_size,
1495
1496
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1497
1498
1499
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1500
            data_parallel_backend=self.data_parallel_backend,
1501
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1502
            enable_expert_parallel=self.enable_expert_parallel,
1503
            all2all_backend=self.all2all_backend,
1504
1505
            enable_dbo=self.enable_dbo,
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
1506
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
1507
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
1508
            enable_eplb=self.enable_eplb,
1509
            eplb_config=self.eplb_config,
1510
            expert_placement_strategy=self.expert_placement_strategy,
1511
1512
1513
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1514
            ray_runtime_env=ray_runtime_env,
1515
            placement_group=placement_group,
1516
1517
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1518
            worker_extension_cls=self.worker_extension_cls,
1519
            decode_context_parallel_size=self.decode_context_parallel_size,
1520
1521
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
1522
        )
1523

1524
1525
1526
1527
1528
1529
1530
1531
1532
        if self.async_scheduling and (
            parallel_config.distributed_executor_backend not in ("mp", "uni")
        ):
            raise ValueError(
                "Currently, async scheduling only supports `mp` or `uni` "
                "distributed executor backend, but you choose "
                f"`{parallel_config.distributed_executor_backend}`."
            )

1533
        speculative_config = self.create_speculative_config(
1534
1535
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1536
            enable_chunked_prefill=self.enable_chunked_prefill,
1537
            disable_log_stats=self.disable_log_stats,
1538
1539
        )

1540
1541
1542
1543
1544
        # make sure num_lookahead_slots is set appropriately depending on
        # whether speculative decoding is enabled
        num_lookahead_slots = self.num_lookahead_slots
        if speculative_config is not None:
            num_lookahead_slots = speculative_config.num_lookahead_slots
1545

1546
        scheduler_config = SchedulerConfig(
1547
            runner_type=model_config.runner_type,
1548
1549
1550
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1551
            cuda_graph_sizes=self.cuda_graph_sizes,
1552
            num_lookahead_slots=num_lookahead_slots,
1553
            enable_chunked_prefill=self.enable_chunked_prefill,
1554
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1555
            is_multimodal_model=model_config.is_multimodal_model,
1556
            is_encoder_decoder=model_config.is_encoder_decoder,
1557
            policy=self.scheduling_policy,
1558
            scheduler_cls=self.scheduler_cls,
1559
1560
1561
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1562
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
1563
            async_scheduling=self.async_scheduling,
1564
        )
1565

1566
1567
1568
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )
1587

1588
1589
1590
1591
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1592
        load_config = self.create_load_config()
1593

1594
1595
        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
1596
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser
1597
1598
1599
1600

        # Forward the deprecated CLI args to the StructuredOutputsConfig
        so_config = self.structured_outputs_config
        if self.guided_decoding_backend is not None:
1601
            so_config.guided_decoding_backend = self.guided_decoding_backend
1602
        if self.guided_decoding_disable_fallback is not None:
1603
1604
1605
            so_config.guided_decoding_disable_fallback = (
                self.guided_decoding_disable_fallback
            )
1606
        if self.guided_decoding_disable_any_whitespace is not None:
1607
1608
1609
            so_config.guided_decoding_disable_any_whitespace = (
                self.guided_decoding_disable_any_whitespace
            )
1610
        if self.guided_decoding_disable_additional_properties is not None:
1611
1612
1613
            so_config.guided_decoding_disable_additional_properties = (
                self.guided_decoding_disable_additional_properties
            )
1614

1615
        observability_config = ObservabilityConfig(
1616
            show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version),
1617
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1618
            collect_detailed_traces=self.collect_detailed_traces,
1619
        )
1620

1621
        config = VllmConfig(
1622
1623
1624
1625
1626
1627
1628
1629
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
1630
            structured_outputs_config=self.structured_outputs_config,
1631
            observability_config=observability_config,
1632
            compilation_config=self.compilation_config,
1633
            kv_transfer_config=self.kv_transfer_config,
1634
            kv_events_config=self.kv_events_config,
1635
            additional_config=self.additional_config,
1636
        )
1637

1638
1639
        return config

1640
1641
1642
1643
1644
1645
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        Each unsupported-feature check reports through ``_raise_or_fallback``
        (which raises when ``VLLM_USE_V1`` is explicitly enabled, otherwise
        logs a warning) and then returns False.

        Args:
            model_config: Model configuration; only consulted for the CPU
                sliding-window check.

        Returns:
            True if no V1-incompatible option was detected, False otherwise.

        Raises:
            NotImplementedError: if draft-model speculative decoding is
                requested, or (via ``_raise_or_fallback``) if an unsupported
                feature is used while ``VLLM_USE_V1`` is explicitly enabled.
        """

        #############################################################
        # Unsupported Feature Flags on V1.

        if self.logits_processor_pattern != EngineArgs.logits_processor_pattern:
            _raise_or_fallback(
                feature_name="--logits-processor-pattern", recommend_to_remove=False
            )
            return False

        # No Concurrent Partial Prefills so far: any non-default value for the
        # partial-prefill limits disables V1.
        if (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        ):
            _raise_or_fallback(
                feature_name="Concurrent Partial Prefill", recommend_to_remove=False
            )
            return False

        # Draft-model speculative decoding is rejected outright; other methods
        # (e.g. ngram, medusa, eagle, mtp) are let through.
        if self.speculative_config is not None:
            # speculative_config could still be a dict at this point
            if isinstance(self.speculative_config, dict):
                method = self.speculative_config.get("method", None)
            else:
                method = self.speculative_config.method

            if method == "draft_model":
                raise NotImplementedError(
                    "Draft model speculative decoding is not supported yet. "
                    "Please consider using other speculative decoding methods "
                    "such as ngram, medusa, eagle, or mtp."
                )

        # Attention backends accepted by V1 when VLLM_ATTENTION_BACKEND is set.
        V1_BACKENDS = [
            "FLASH_ATTN",
            "PALLAS",
            "TRITON_ATTN",
            "TRITON_MLA",
            "CUTLASS_MLA",
            "FLASHMLA",
            "FLASH_ATTN_MLA",
            "FLASHINFER",
            "FLASHINFER_MLA",
            "ROCM_AITER_MLA",
            "TORCH_SDPA",
            "FLEX_ATTENTION",
            "TREE_ATTN",
            "XFORMERS",
            "ROCM_ATTN",
            "ROCM_AITER_UNIFIED_ATTN",
        ]
        if (
            envs.is_set("VLLM_ATTENTION_BACKEND")
            and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS
        ):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

        if self.pipeline_parallel_size > 1:
            # A custom executor object may advertise PP support through a
            # `supports_pp` attribute; plain backend names have no such
            # attribute and default to False here.
            supports_pp = getattr(
                self.distributed_executor_backend, "supports_pp", False
            )
            if not supports_pp and self.distributed_executor_backend not in (
                ParallelConfig.distributed_executor_backend,
                "ray",
                "mp",
                "external_launcher",
            ):
                name = (
                    "Pipeline Parallelism without Ray distributed "
                    "executor or multiprocessing executor or external "
                    "launcher"
                )
                _raise_or_fallback(feature_name=name, recommend_to_remove=False)
                return False

        if current_platform.is_cpu() and model_config.get_sliding_window() is not None:
            _raise_or_fallback(
                feature_name="sliding window (CPU backend)", recommend_to_remove=False
            )
            return False

        #############################################################

        return True

1735
1736
1737
    def _set_default_args(
        self, usage_context: UsageContext, model_config: ModelConfig
    ) -> None:
        """Set Default Arguments for V1 Engine.

        First resolves chunked-prefill / prefix-caching settings from the
        runner type, then, for values the user left unset, picks
        ``max_num_batched_tokens`` and ``max_num_seqs`` from per-usage-context
        tables that depend on the detected hardware (large-memory GPUs, TPU
        chip name, CPU world size).

        Args:
            usage_context: Key into the default-value tables; contexts not in
                a table leave the corresponding value untouched.
            model_config: Supplies the runner type, pooler config, hybrid flag
                and ``max_model_len`` used to derive defaults.
        """
        # Non-pooling runners: chunked prefill is forced on (any value the
        # user passed is overwritten) and prefix caching is on by default.
        # Pooling runners: both default to on only when incremental prefill
        # is supported (causal model with LAST pooling), otherwise off.
        if model_config.runner_type != "pooling":
            self.enable_chunked_prefill = True

            if self.enable_prefix_caching is None:
                # Disable prefix caching default for hybrid models
                # since the feature is still experimental.
                if model_config.is_hybrid:
                    self.enable_prefix_caching = False
                else:
                    self.enable_prefix_caching = True
        else:
            pooling_type = model_config.pooler_config.pooling_type
            is_causal = getattr(model_config.hf_config, "is_causal", True)
            incremental_prefill_supported = (
                pooling_type is not None
                and pooling_type.lower() == "last"
                and is_causal
            )

            action = "Enabling" if incremental_prefill_supported else "Disabling"

            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = incremental_prefill_supported
                logger.info("(%s) chunked prefill by default", action)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = incremental_prefill_supported
                logger.info("(%s) prefix caching by default", action)

        # When no user override, set the default values based on the usage
        # context.
        # Use different default values for different hardware.

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # Only used to choose between the default tables below. Bind
            # device_name too, so the "a100" check never relies on `and`
            # short-circuiting to avoid an unbound-name error.
            device_memory = 0
            device_name = ""

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        from vllm.usage.usage_lib import UsageContext

        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
            # For GPUs like H100 and MI300x, use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 1024,
                UsageContext.OPENAI_API_SERVER: 1024,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256,
                UsageContext.OPENAI_API_SERVER: 256,
            }

        # tpu specific default values, keyed by chip name; chips not listed
        # fall back to the generic table above.
        if current_platform.is_tpu():
            default_max_num_batched_tokens_tpu = {
                UsageContext.LLM_CLASS: {
                    "V6E": 2048,
                    "V5E": 1024,
                    "V5P": 512,
                },
                UsageContext.OPENAI_API_SERVER: {
                    "V6E": 1024,
                    "V5E": 512,
                    "V5P": 256,
                },
            }

        # cpu specific default values, scaled by the number of workers
        # (pipeline-parallel size * tensor-parallel size).
        if current_platform.is_cpu():
            world_size = self.pipeline_parallel_size * self.tensor_parallel_size
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 4096 * world_size,
                UsageContext.OPENAI_API_SERVER: 2048 * world_size,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256 * world_size,
                UsageContext.OPENAI_API_SERVER: 128 * world_size,
            }

        use_context_value = usage_context.value if usage_context else None
        if (
            self.max_num_batched_tokens is None
            and usage_context in default_max_num_batched_tokens
        ):
            if current_platform.is_tpu():
                chip_name = current_platform.get_device_name()
                if chip_name in default_max_num_batched_tokens_tpu[usage_context]:
                    self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[
                        usage_context
                    ][chip_name]
                else:
                    self.max_num_batched_tokens = default_max_num_batched_tokens[
                        usage_context
                    ]
            else:
                # Without chunked prefill a whole prompt must fit in one
                # batch, so the budget defaults to the model's context length.
                if not self.enable_chunked_prefill:
                    self.max_num_batched_tokens = model_config.max_model_len
                else:
                    self.max_num_batched_tokens = default_max_num_batched_tokens[
                        usage_context
                    ]
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                use_context_value,
            )

        if self.max_num_seqs is None and usage_context in default_max_num_seqs:
            # Never default to more sequences than batched tokens.
            self.max_num_seqs = min(
                default_max_num_seqs[usage_context],
                self.max_num_batched_tokens or sys.maxsize,
            )

            logger.debug(
                "Setting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                use_context_value,
            )
1879

1880

1881
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends ``EngineArgs`` with the request-logging option and keeps the
    deprecated ``disable_log_requests`` spelling working as an inverse alias.
    """

    # Whether incoming requests are logged (CLI: --enable-log-requests).
    enable_log_requests: bool = False

    @property
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
        "`enable_log_requests` instead."
    )
    def disable_log_requests(self) -> bool:
        # Deprecated inverse view of enable_log_requests.
        return not self.enable_log_requests

    @disable_log_requests.setter
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
        "`enable_log_requests` instead."
    )
    def disable_log_requests(self, value: bool):
        # Writes through to enable_log_requests with the value inverted.
        self.enable_log_requests = not value

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
        """Register the async-engine CLI arguments on ``parser``.

        Args:
            parser: Parser to extend.
            async_args_only: If True, skip ``EngineArgs.add_cli_args`` and add
                only the async-specific arguments.

        Returns:
            The parser with the arguments added.
        """
        # Initialize plugin to update the parser, for example, The plugin may
        # add a new kind of quantization method to --quantization argument or
        # a new device to --device argument.
        load_general_plugins()
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument(
            "--enable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=AsyncEngineArgs.enable_log_requests,
            help="Enable logging requests.",
        )
        parser.add_argument(
            "--disable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=not AsyncEngineArgs.enable_log_requests,
            help="[DEPRECATED] Disable logging requests.",
            deprecated=True,
        )
        # Give the active platform a chance to adjust the parser last.
        current_platform.pre_register_and_update(parser)
        return parser
1930
1931


1932
1933
1934
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool) -> None:
    """Report a feature that the V1 engine does not support.

    If ``VLLM_USE_V1`` was explicitly set to a truthy value, raise instead of
    falling back. Otherwise, log a warning that the engine falls back to V0.

    Args:
        feature_name: Human-readable name of the unsupported feature.
        recommend_to_remove: If True, the warning also recommends removing
            the feature from the config in favor of V1.

    Raises:
        NotImplementedError: if ``VLLM_USE_V1`` is set and enabled.
    """
    if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
        # The user explicitly opted into V1, so a silent fallback would
        # contradict their request.
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}."
        )
    msg = f"{feature_name} is not supported by the V1 Engine. "
    msg += "Falling back to V0. "
    if recommend_to_remove:
        msg += f"We recommend to remove {feature_name} from your config "
        msg += "in favor of the V1 Engine."
    logger.warning(msg)


1945
1946
1947
def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.

    Lower-case suffixes are decimal (k=10**3, m=10**6, g=10**9, t=10**12)
    and accept a fractional number, which is truncated toward zero after
    scaling. Upper-case suffixes are binary (K=2**10, M=2**20, G=2**30,
    T=2**40) and require a whole number. Input without a suffix is parsed
    as a plain integer. Surrounding whitespace is ignored.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600

    Raises:
        argparse.ArgumentTypeError: if a fractional number is combined with
            a binary suffix (e.g. '1.5K').
        ValueError: if the input is neither a suffixed number nor a plain
            integer.
    """
    # Local import: only needed for exact arithmetic on decimal suffixes.
    from fractions import Fraction

    value = value.strip()
    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
    if match:
        # Every suffix accepted by the regex must appear in exactly one of
        # these tables; previously 't'/'T' were missing, so e.g. '1t' fell
        # through to int('1t') and failed with a confusing error.
        decimal_multiplier = {
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
            "t": 10**12,
        }
        binary_multiplier = {
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
            "T": 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            # Exact rational arithmetic: float(number) * mult can round to
            # just below the intended integer (e.g. '8.2m'), which int()
            # would then truncate to the wrong value.
            return int(Fraction(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?"
                ) from e

    # Regular plain number.
    return int(value)