arg_utils.py 80 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import copy
6
import dataclasses
7
import functools
8
import json
9
import sys
10
from dataclasses import MISSING, dataclass, fields, is_dataclass
11
from itertools import permutations
12
13
14
15
16
17
18
19
20
21
22
23
24
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Callable,
    Literal,
    Optional,
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
25

26
import huggingface_hub
27
import regex as re
28
import torch
29
from pydantic import TypeAdapter, ValidationError
30
from pydantic.fields import FieldInfo
31
from typing_extensions import TypeIs, deprecated
32

33
import vllm.envs as envs
34
35
36
37
38
39
40
41
42
43
44
from vllm.config import (
    CacheConfig,
    CompilationConfig,
    ConfigType,
    DeviceConfig,
    EPLBConfig,
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LoRAConfig,
    ModelConfig,
45
    MultiModalConfig,
46
47
48
49
50
51
52
53
54
    ObservabilityConfig,
    ParallelConfig,
    PoolerConfig,
    SchedulerConfig,
    SpeculativeConfig,
    StructuredOutputsConfig,
    VllmConfig,
    get_attr_docs,
)
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from vllm.config.cache import BlockSize, CacheDType, MambaDType, PrefixCachingHashAlgo
from vllm.config.device import Device
from vllm.config.model import (
    ConvertOption,
    HfOverrides,
    LogprobsMode,
    ModelDType,
    RunnerOption,
    TaskOption,
    TokenizerMode,
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.scheduler import SchedulerPolicy
70
from vllm.config.utils import get_field
71
from vllm.logger import init_logger
72
from vllm.platforms import CpuArchEnum, current_platform
73
from vllm.plugins import load_general_plugins
74
from vllm.ray.lazy_utils import is_ray_initialized
75
from vllm.reasoning import ReasoningParserManager
76
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
77
78
79
80
81
from vllm.transformers_utils.config import (
    get_model_path,
    is_interleaved,
    maybe_override_with_speculators,
)
82
from vllm.transformers_utils.utils import check_gguf_file
83
from vllm.utils import FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor
84
from vllm.v1.sample.logits_processor import LogitsProcessor
85

86
87
88
if TYPE_CHECKING:
    from vllm.executor.executor_base import ExecutorBase
    from vllm.model_executor.layers.quantization import QuantizationMethods
89
    from vllm.model_executor.model_loader import LoadFormats
90
91
92
93
    from vllm.usage.usage_lib import UsageContext
else:
    ExecutorBase = Any
    QuantizationMethods = Any
94
    LoadFormats = Any
95
96
    UsageContext = Any

97
98
logger = init_logger(__name__)

99
100
101
102
103
# object is used to allow for special typing forms
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]

104

105
106
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    def _parse_type(val: str) -> T:
107
108
109
110
        try:
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
111
112
                f"Value {val} cannot be converted to {return_type}."
            ) from e
113

114
115
116
    return _parse_type


117
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
118
119
120
121
122
    def _optional_type(val: str) -> Optional[T]:
        if val == "" or val == "None":
            return None
        return parse_type(return_type)(val)

123
    return _optional_type
124
125


126
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
127
    if not re.match(r"(?s)^\s*{.*}\s*$", val):
128
        return str(val)
129
    return optional_type(json.loads)(val)
130
131


132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Check if the type hint is a specific type."""
    return type_hint is type or get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Check if the type hints contain a specific type."""
    return any(is_type(type_hint, type) for type_hint in type_hints)


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Get the specific type from the type hints."""
    return next((th for th in type_hints if is_type(th, type)), None)


147
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
148
149
150
151
    """Get the `type` and `choices` from a `Literal` type hint in `type_hints`.

    If `type_hints` also contains `str`, we use `metavar` instead of `choices`.
    """
152
    type_hint = get_type(type_hints, Literal)
153
154
155
    options = get_args(type_hint)
    option_type = type(options[0])
    if not all(isinstance(option, option_type) for option in options):
156
        raise ValueError(
157
            "All options must be of the same type. "
158
159
            f"Got {options} with types {[type(c) for c in options]}"
        )
160
161
    kwarg = "metavar" if contains_type(type_hints, str) else "choices"
    return {"type": option_type, kwarg: sorted(options)}
162
163


164
165
166
167
168
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Check if the class is not a built-in type."""
    return type_hint.__module__ != "builtins"


169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints."""
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
    elif origin is Union:
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


186
187
188
189
def is_online_quantization(quantization: Any) -> bool:
    return quantization in ["inc"]


190
NEEDS_HELP = (
191
192
    any("--help" in arg for arg in sys.argv)  # vllm SUBCOMMAND --help
    or (argv0 := sys.argv[0]).endswith("mkdocs")  # mkdocs SUBCOMMAND
193
194
195
196
    or argv0.endswith("mkdocs/__main__.py")  # python -m mkdocs SUBCOMMAND
)


197
198
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
199
200
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}
201
202
    kwargs = {}
    for field in fields(cls):
203
        # Get the set of possible types for the field
204
        type_hints: set[TypeHint] = get_type_hints(field.type)
205
206
207
208
209

        # If the field is a dataclass, we can use the model_validate_json
        generator = (th for th in type_hints if is_dataclass(th))
        dataclass_cls = next(generator, None)

210
        # Get the default value of the field
211
212
        if field.default is not MISSING:
            default = field.default
213
214
215
216
217
218
219
            # Handle pydantic.Field defaults
            if isinstance(default, FieldInfo):
                default = (
                    default.default
                    if default.default_factory is None
                    else default.default_factory()
                )
220
        elif field.default_factory is not MISSING:
221
            default = field.default_factory()
222
223
224

        # Get the help text for the field
        name = field.name
225
        help = cls_docs.get(name, "").strip()
226
227
228
229
230
231
232
        # Escape % for argparse
        help = help.replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        kwargs[name] = {"default": default, "help": help}

        # Set other kwargs based on the type hints
233
234
235
        json_tip = (
            "Should either be a valid JSON string or JSON keys passed individually."
        )
236
        if dataclass_cls is not None:
237
238
239
240
241
242
243
244

            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            kwargs[name]["type"] = parse_dataclass
245
            kwargs[name]["help"] += f"\n\n{json_tip}"
246
        elif contains_type(type_hints, bool):
247
248
249
            # Creates --no-<name> and --<name> flags
            kwargs[name]["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
250
            kwargs[name].update(literal_to_kwargs(type_hints))
251
252
253
254
255
256
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            types = get_args(type_hint)
            tuple_type = types[0]
            assert all(t is tuple_type for t in types if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
257
258
                f"type. Got {types}."
            )
259
260
261
262
263
            kwargs[name]["type"] = tuple_type
            kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            types = get_args(type_hint)
264
265
266
267
268
269
            list_type = types[0]
            if get_origin(list_type) is Union:
                msg = "List type must contain str if it is a Union."
                assert str in get_args(list_type), msg
                list_type = str
            kwargs[name]["type"] = list_type
270
271
272
            kwargs[name]["nargs"] = "+"
        elif contains_type(type_hints, int):
            kwargs[name]["type"] = int
273
            # Special case for large integers
274
275
276
277
278
279
            human_readable_ints = {
                "max_model_len",
                "max_num_batched_tokens",
                "kv_cache_memory_bytes",
            }
            if name in human_readable_ints:
280
                kwargs[name]["type"] = human_readable_int
281
                kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
282
283
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
284
285
286
287
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
288
            kwargs[name]["type"] = union_dict_and_str
289
        elif contains_type(type_hints, dict):
290
            kwargs[name]["type"] = parse_type(json.loads)
291
            kwargs[name]["help"] += f"\n\n{json_tip}"
292
293
294
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
295
296
            kwargs[name]["type"] = str
        else:
297
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")
298

299
300
301
302
303
        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(kwargs[name].get("type")) is Literal:
            kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))

304
305
306
307
308
309
310
        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            kwargs[name]["type"] = optional_type(kwargs[name]["type"])
            if kwargs[name].get("choices"):
                kwargs[name]["choices"].append("None")
    return kwargs
311
312


313
314
315
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Return argparse kwargs for the given Config dataclass.

316
317
318
    If `--help` or `mkdocs` are not present in the command line command, the
    attribute documentation will not be included in the help output.

319
320
321
322
323
324
325
    The heavy computation is cached via functools.lru_cache, and a deep copy
    is returned so callers can mutate the dictionary without affecting the
    cached version.
    """
    return copy.deepcopy(_compute_kwargs(cls))


326
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
327
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
328
    """Arguments for vLLM engine."""
329

330
    model: str = ModelConfig.model
331
    served_model_name: Optional[Union[str, list[str]]] = ModelConfig.served_model_name
332
333
    tokenizer: Optional[str] = ModelConfig.tokenizer
    hf_config_path: Optional[str] = ModelConfig.hf_config_path
334
335
336
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
    task: Optional[TaskOption] = ModelConfig.task
337
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
338
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
339
340
341
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
342
    allowed_media_domains: Optional[list[str]] = ModelConfig.allowed_media_domains
343
    download_dir: Optional[str] = LoadConfig.download_dir
344
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
345
    load_format: Union[str, LoadFormats] = LoadConfig.load_format
346
347
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
348
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
349
350
    seed: Optional[int] = ModelConfig.seed
    max_model_len: Optional[int] = ModelConfig.max_model_len
351
    cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes")
352
353
354
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
355
    distributed_executor_backend: Optional[
356
        Union[str, DistributedExecutorBackend, type[ExecutorBase]]
357
    ] = ParallelConfig.distributed_executor_backend
358
    # number of P/D disaggregation (or other disaggregation) workers
359
360
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
361
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
362
    data_parallel_size: int = ParallelConfig.data_parallel_size
363
    data_parallel_rank: Optional[int] = None
364
    data_parallel_start_rank: Optional[int] = None
365
366
367
    data_parallel_size_local: Optional[int] = None
    data_parallel_address: Optional[str] = None
    data_parallel_rpc_port: Optional[int] = None
368
    data_parallel_hybrid_lb: bool = False
Rui Qiao's avatar
Rui Qiao committed
369
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
370
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
371
    enable_dbo: bool = ParallelConfig.enable_dbo
372
373
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
374
375
376
    disable_nccl_for_dp_synchronization: bool = (
        ParallelConfig.disable_nccl_for_dp_synchronization
    )
377
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
378
    enable_eplb: bool = ParallelConfig.enable_eplb
379
    expert_placement_strategy: ExpertPlacementStrategy = (
380
        ParallelConfig.expert_placement_strategy
381
    )
382
383
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
384
385
386
387
    num_redundant_experts: int = EPLBConfig.num_redundant_experts
    eplb_window_size: int = EPLBConfig.window_size
    eplb_step_interval: int = EPLBConfig.step_interval
    eplb_log_balancedness: bool = EPLBConfig.log_balancedness
388
389
390
    max_parallel_loading_workers: Optional[int] = (
        ParallelConfig.max_parallel_loading_workers
    )
391
392
    block_size: Optional[BlockSize] = CacheConfig.block_size
    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
393
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
394
        CacheConfig.prefix_caching_hash_algo
395
    )
396
397
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
398
399
400
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
401
    kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes
402
    max_num_batched_tokens: Optional[int] = SchedulerConfig.max_num_batched_tokens
403
404
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
405
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
406
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
407
    max_logprobs: int = ModelConfig.max_logprobs
408
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
409
    disable_log_stats: bool = False
410
411
412
413
414
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
    rope_theta: Optional[float] = ModelConfig.rope_theta
    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
415
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
416
417
418
    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
    enforce_eager: bool = ModelConfig.enforce_eager
419
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
420
421
422
    limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = get_field(
        MultiModalConfig, "limit_per_prompt"
    )
423
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
424
425
426
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
427
    mm_processor_kwargs: Optional[dict[str, Any]] = MultiModalConfig.mm_processor_kwargs
428
    disable_mm_preprocessor_cache: bool = False  # DEPRECATED
429
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
430
    mm_processor_cache_type: Optional[MMCacheType] = (
431
        MultiModalConfig.mm_processor_cache_type
432
433
    )
    mm_shm_cache_max_object_size_mb: int = (
434
        MultiModalConfig.mm_shm_cache_max_object_size_mb
435
    )
436
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
437
    io_processor_plugin: Optional[str] = None
438
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
439
    video_pruning_rate: float = MultiModalConfig.video_pruning_rate
440
    # LoRA fields
441
    enable_lora: bool = False
442
443
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
444
    default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras
445
446
447
448
449
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size

450
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
451
    num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override
452
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
453
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
454
    ignore_patterns: Union[str, list[str]] = get_field(LoadConfig, "ignore_patterns")
455

456
    enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill
457
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
458

459
    disable_hybrid_kv_cache_manager: bool = (
460
461
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
462

463
    structured_outputs_config: StructuredOutputsConfig = get_field(
464
465
        VllmConfig, "structured_outputs_config"
    )
466
467
468
469
470
471
472
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
    # Deprecated guided decoding fields
    guided_decoding_backend: Optional[str] = None
    guided_decoding_disable_fallback: Optional[bool] = None
    guided_decoding_disable_any_whitespace: Optional[bool] = None
    guided_decoding_disable_additional_properties: Optional[bool] = None

473
    logits_processor_pattern: Optional[str] = ModelConfig.logits_processor_pattern
474

475
    speculative_config: Optional[dict[str, Any]] = None
476

477
    show_hidden_metrics_for_version: Optional[str] = (
478
        ObservabilityConfig.show_hidden_metrics_for_version
479
480
481
    )
    otlp_traces_endpoint: Optional[str] = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: Optional[list[DetailedTraceModules]] = (
482
        ObservabilityConfig.collect_detailed_traces
483
    )
484
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
485
    scheduler_cls: Union[str, type[object]] = SchedulerConfig.scheduler_cls
486

487
    pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config
488
    override_pooler_config: Optional[Union[dict, PoolerConfig]] = (
489
        ModelConfig.override_pooler_config
490
491
    )
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
492
493
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
494

495
    kv_transfer_config: Optional[KVTransferConfig] = None
496
    kv_events_config: Optional[KVEventsConfig] = None
497

498
499
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
500
501
502
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
503
    model_impl: str = ModelConfig.model_impl
504
    override_attention_dtype: str = ModelConfig.override_attention_dtype
505

506
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
507
508
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
509

510
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
511

512
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
513
    pt_load_map_location: str = LoadConfig.pt_load_map_location
514

515
516
    # DEPRECATED
    enable_multimodal_encoder_data_parallel: bool = False
517

518
519
520
    logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = (
        ModelConfig.logits_processors
    )
521
522
    """Custom logitproc types"""

523
524
    async_scheduling: bool = SchedulerConfig.async_scheduling

525
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
526

527
    def __post_init__(self):
528
529
530
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
531
        if isinstance(self.compilation_config, dict):
532
            self.compilation_config = CompilationConfig(**self.compilation_config)
533
        if isinstance(self.eplb_config, dict):
534
            self.eplb_config = EPLBConfig(**self.eplb_config)
535
        # Setup plugins
536
        from vllm.plugins import load_general_plugins
537

538
        load_general_plugins()
539
540
541
542
543
        # when use hf offline,replace model id to local model path
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
            logger.info(
544
545
546
547
                "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
                model_id,
                self.model,
            )
548
549

    @staticmethod
550
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
551
        """Shared CLI arguments for vLLM engine."""
552

553
        # Model arguments
554
555
556
557
558
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
559
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
560
            model_group.add_argument("--model", **model_kwargs["model"])
561
562
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
563
        model_group.add_argument("--task", **model_kwargs["task"], deprecated=True)
564
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
565
566
567
568
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
569
570
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
571
572
573
574
575
576
577
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
578
        model_group.add_argument("--revision", **model_kwargs["revision"])
579
580
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling", **model_kwargs["rope_scaling"])
581
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
606
607
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--override-pooler-config",
            **model_kwargs["override_pooler_config"],
            deprecated=True,
        )
        model_group.add_argument(
            "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
        )
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
635
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
636
637
638
639
640
641
642
643
644
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
        )
645

646
647
648
649
650
651
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
652
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
653
654
655
656
657
658
659
660
661
662
663
664
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )
665

666
667
668
669
670
        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
671
        )
672
        structured_outputs_group.add_argument(
673
            "--reasoning-parser",
674
            # This choice is a special case because it's not static
675
            choices=list(ReasoningParserManager.reasoning_parsers),
676
677
            **structured_outputs_kwargs["reasoning_parser"],
        )
678
679
680
681
682
683
684
685
686
687
688
        # Deprecated guided decoding arguments
        for arg, type in [
            ("--guided-decoding-backend", str),
            ("--guided-decoding-disable-fallback", bool),
            ("--guided-decoding-disable-any-whitespace", bool),
            ("--guided-decoding-disable-additional-properties", bool),
        ]:
            structured_outputs_group.add_argument(
                arg,
                type=type,
                help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
689
690
                deprecated=True,
            )
691

692
        # Parallel arguments
693
694
695
696
697
698
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
699
            "--distributed-executor-backend",
700
701
            **parallel_kwargs["distributed_executor_backend"],
        )
702
        parallel_group.add_argument(
703
704
705
706
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
707
        parallel_group.add_argument(
708
709
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
710
        parallel_group.add_argument(
711
712
713
714
715
716
717
718
719
720
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
        parallel_group.add_argument(
            "--data-parallel-rank",
            "-dpn",
721
            type=int,
722
723
724
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
725
        parallel_group.add_argument(
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
            "--data-parallel-start-rank",
            "-dpr",
            type=int,
            help="Starting data parallel rank for secondary nodes.",
        )
        parallel_group.add_argument(
            "--data-parallel-size-local",
            "-dpl",
            type=int,
            help="Number of data parallel replicas to run on this node.",
        )
        parallel_group.add_argument(
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
            type=int,
            help="Port for data parallel RPC communication.",
        )
        parallel_group.add_argument(
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
756
        parallel_group.add_argument(
757
758
759
760
761
762
            "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"]
        )
        parallel_group.add_argument(
            "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
        )
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
763
764
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
765
766
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
767
768
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
769
770
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
771
772
773
774
        parallel_group.add_argument(
            "--disable-nccl-for-dp-synchronization",
            **parallel_kwargs["disable_nccl_for_dp_synchronization"],
        )
775
776
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
777
778
        parallel_group.add_argument(
            "--expert-placement-strategy",
779
780
            **parallel_kwargs["expert_placement_strategy"],
        )
781
782
783
        parallel_group.add_argument(
            "--num-redundant-experts",
            type=int,
784
785
786
            help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
            deprecated=True,
        )
787
788
789
790
        parallel_group.add_argument(
            "--eplb-window-size",
            type=int,
            help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
791
792
            deprecated=True,
        )
793
794
795
        parallel_group.add_argument(
            "--eplb-step-interval",
            type=int,
796
797
798
            help="[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
            deprecated=True,
        )
799
800
801
        parallel_group.add_argument(
            "--eplb-log-balancedness",
            action=argparse.BooleanOptionalAction,
802
803
804
            help="[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
            deprecated=True,
        )
805

806
        parallel_group.add_argument(
807
            "--max-parallel-loading-workers",
808
809
            **parallel_kwargs["max_parallel_loading_workers"],
        )
810
        parallel_group.add_argument(
811
812
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
813
        parallel_group.add_argument(
814
            "--disable-custom-all-reduce",
815
816
817
818
819
820
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )
821
822
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
823
            action="store_true",
824
825
            deprecated=True,
        )
826

827
828
829
830
831
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
832
        )
833
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
834
835
836
837
838
839
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
840
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        cache_group.add_argument(
            "--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"]
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
864

865
        # Multimodal related configs
866
867
868
869
870
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
871
        multimodal_group.add_argument(
872
873
874
875
876
877
878
879
880
881
882
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
883
        multimodal_group.add_argument(
884
885
            "--disable-mm-preprocessor-cache", action="store_true", deprecated=True
        )
886
        multimodal_group.add_argument(
887
888
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
889
890
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
891
892
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
893
        multimodal_group.add_argument(
894
895
896
897
898
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
899
        multimodal_group.add_argument(
900
901
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
902

903
        multimodal_group.add_argument(
904
905
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )
906

907
        # LoRA related configs
908
909
910
911
912
913
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
914
            "--enable-lora",
915
            action=argparse.BooleanOptionalAction,
916
917
            help="If True, enable handling of LoRA adapters.",
        )
918
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
919
920
921
922
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
        lora_group.add_argument(
            "--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"]
        )
923
        lora_group.add_argument(
924
            "--lora-dtype",
925
926
            **lora_kwargs["lora_dtype"],
        )
927
928
929
930
931
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
932

933
934
935
936
937
938
939
940
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
941
942
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
943
        observability_group.add_argument(
944
945
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
946
947
948
949
950
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
951
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
952
953
954
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
955
956
            **observability_kwargs["collect_detailed_traces"],
        )
957

958
959
960
961
962
963
964
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
965
966
            "--max-num-batched-tokens", **scheduler_kwargs["max_num_batched_tokens"]
        )
967
        scheduler_group.add_argument(
968
969
970
971
972
            "--max-num-seqs", **scheduler_kwargs["max_num_seqs"]
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
973
974
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
975
976
977
978
979
            **scheduler_kwargs["max_long_partial_prefills"],
        )
        scheduler_group.add_argument(
            "--cuda-graph-sizes", **scheduler_kwargs["cuda_graph_sizes"]
        )
980
981
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
982
983
984
985
986
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
        scheduler_group.add_argument(
            "--num-lookahead-slots", **scheduler_kwargs["num_lookahead_slots"]
        )
987
988
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
989
        scheduler_group.add_argument(
990
991
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
992
        scheduler_group.add_argument(
993
994
995
996
997
998
999
1000
            "--enable-chunked-prefill", **scheduler_kwargs["enable_chunked_prefill"]
        )
        scheduler_group.add_argument(
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
1001
1002
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
1003
1004
1005
1006
1007
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
1008
1009

        # vLLM arguments
1010
        vllm_kwargs = get_kwargs(VllmConfig)
1011
1012
1013
1014
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
1015
1016
1017
1018
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
        vllm_group.add_argument(
            "--speculative-config", **vllm_kwargs["speculative_config"]
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument(
            "--compilation-config", "-O", **vllm_kwargs["compilation_config"]
        )
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
1035

1036
        # Other arguments
1037
1038
1039
1040
1041
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
1042

1043
        return parser
1044
1045

    @classmethod
1046
    def from_cli_args(cls, args: argparse.Namespace):
1047
1048
1049
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
1050
1051
1052
        engine_args = cls(
            **{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)}
        )
Zhuohan Li's avatar
Zhuohan Li committed
1053
        return engine_args
1054

1055
    def create_model_config(self) -> ModelConfig:
1056
1057
1058
1059
1060
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
1061
1062
1063
1064
1065
1066
        if (
            not isinstance(self, AsyncEngineArgs)
            and envs.VLLM_CI_USE_S3
            and self.model in MODELS_ON_S3
            and self.load_format == "auto"
        ):
1067
1068
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"

1069
1070
1071
1072
        if self.disable_mm_preprocessor_cache:
            logger.warning(
                "`--disable-mm-preprocessor-cache` is deprecated "
                "and will be removed in v0.13. "
1073
1074
                "Please use `--mm-processor-cache-gb 0` instead.",
            )
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086

            self.mm_processor_cache_gb = 0
        elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
            logger.warning(
                "VLLM_MM_INPUT_CACHE_GIB` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb %d` instead.",
                envs.VLLM_MM_INPUT_CACHE_GIB,
            )

            self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB

1087
1088
1089
1090
        if self.enable_multimodal_encoder_data_parallel:
            logger.warning(
                "--enable-multimodal-encoder-data-parallel` is deprecated "
                "and will be removed in v0.13. "
1091
1092
                "Please use `--mm-encoder-tp-mode data` instead."
            )
1093
1094
1095

            self.mm_encoder_tp_mode = "data"

1096
        return ModelConfig(
1097
            model=self.model,
1098
            hf_config_path=self.hf_config_path,
1099
1100
            runner=self.runner,
            convert=self.convert,
1101
            task=self.task,
1102
            tokenizer=self.tokenizer,
1103
1104
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
1105
            allowed_local_media_path=self.allowed_local_media_path,
1106
            allowed_media_domains=self.allowed_media_domains,
1107
1108
1109
1110
1111
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
1112
            rope_theta=self.rope_theta,
1113
            hf_token=self.hf_token,
1114
            hf_overrides=self.hf_overrides,
1115
1116
1117
1118
1119
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_logprobs=self.max_logprobs,
1120
            logprobs_mode=self.logprobs_mode,
1121
            disable_sliding_window=self.disable_sliding_window,
1122
            disable_cascade_attn=self.disable_cascade_attn,
1123
            skip_tokenizer_init=self.skip_tokenizer_init,
1124
            enable_prompt_embeds=self.enable_prompt_embeds,
1125
            served_model_name=self.served_model_name,
1126
            limit_mm_per_prompt=self.limit_mm_per_prompt,
1127
            interleave_mm_strings=self.interleave_mm_strings,
1128
            media_io_kwargs=self.media_io_kwargs,
1129
            skip_mm_profiling=self.skip_mm_profiling,
1130
            config_format=self.config_format,
1131
            mm_processor_kwargs=self.mm_processor_kwargs,
1132
            mm_processor_cache_gb=self.mm_processor_cache_gb,
1133
            mm_processor_cache_type=self.mm_processor_cache_type,
1134
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
1135
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
1136
            pooler_config=self.pooler_config,
1137
            override_pooler_config=self.override_pooler_config,
1138
            logits_processor_pattern=self.logits_processor_pattern,
1139
            generation_config=self.generation_config,
1140
            override_generation_config=self.override_generation_config,
1141
            enable_sleep_mode=self.enable_sleep_mode,
1142
            model_impl=self.model_impl,
1143
            override_attention_dtype=self.override_attention_dtype,
1144
            logits_processors=self.logits_processors,
1145
            video_pruning_rate=self.video_pruning_rate,
1146
            io_processor_plugin=self.io_processor_plugin,
1147
        )
1148

1149
    def validate_tensorizer_args(self):
1150
1151
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

1152
1153
        for key in self.model_loader_extra_config:
            if key in TensorizerConfig._fields:
1154
1155
1156
                self.model_loader_extra_config["tensorizer_config"][key] = (
                    self.model_loader_extra_config[key]
                )
1157

1158
    def create_load_config(self) -> LoadConfig:
1159
1160
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"
1161

1162
1163
1164
        if self.load_format == "tensorizer":
            if hasattr(self.model_loader_extra_config, "to_serializable"):
                self.model_loader_extra_config = (
1165
1166
                    self.model_loader_extra_config.to_serializable()
                )
1167
            self.model_loader_extra_config["tensorizer_config"] = {}
1168
1169
1170
            self.model_loader_extra_config["tensorizer_config"]["tensorizer_dir"] = (
                self.model
            )
1171
            self.validate_tensorizer_args()
1172

1173
1174
1175
        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
1176
            safetensors_load_strategy=self.safetensors_load_strategy,
1177
            device="cpu" if is_online_quantization(self.quantization) else None,
1178
1179
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
1180
            use_tqdm_on_load=self.use_tqdm_on_load,
1181
            pt_load_map_location=self.pt_load_map_location,
1182
        )
1183

1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        This function utilizes `speculative_config` to create a
        SpeculativeConfig object. The `speculative_config` can either be
        provided as a JSON string input via CLI arguments or directly as a
1197
        dictionary from the engine.
1198
1199
        """
        if self.speculative_config is None:
1200
            return None
1201

1202
1203
1204
        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
1205
1206
1207
1208
1209
1210
1211
1212
        self.speculative_config.update(
            {
                "target_model_config": target_model_config,
                "target_parallel_config": target_parallel_config,
                "enable_chunked_prefill": enable_chunked_prefill,
                "disable_log_stats": disable_log_stats,
            }
        )
1213
        return SpeculativeConfig(**self.speculative_config)
1214

1215
1216
1217
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
1218
        headless: bool = False,
1219
1220
1221
1222
1223
1224
1225
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1226

1227
1228
1229
1230
1231
1232
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1233
        current_platform.pre_register_and_update()
1234

1235
        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
1236

1237
1238
1239
1240
        model_config = self.create_model_config()
        self.model = model_config.model
        self.tokenizer = model_config.tokenizer

1241
1242
1243
1244
1245
1246
1247
1248
1249
        (self.model, self.tokenizer, self.speculative_config) = (
            maybe_override_with_speculators(
                model=self.model,
                tokenizer=self.tokenizer,
                revision=self.revision,
                trust_remote_code=self.trust_remote_code,
                vllm_speculative_config=self.speculative_config,
            )
        )
1250

1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

1268
1269
        # Set default arguments for V1 Engine.
        self._set_default_args(usage_context, model_config)
1270
        # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.S390X,
            CpuArchEnum.ARM,
            CpuArchEnum.RISCV,
        ):
            logger.info(
                "Chunked prefill is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
1282
            self.enable_chunked_prefill = False
1283
1284
        assert self.enable_chunked_prefill is not None

1285
1286
1287
1288
1289
1290
1291
        sliding_window: Optional[int] = None
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1292
1293
1294
        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
1295
        # reuses the GPUs of TP group, and split one TP group into
1296
        # tp_size//dcp_size DCP groups.
1297
        assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, (
1298
1299
1300
1301
            f"tp_size={self.tensor_parallel_size} must be divisible by"
            f"dcp_size={self.decode_context_parallel_size}."
        )

1302
        cache_config = CacheConfig(
1303
            block_size=self.block_size,
1304
            gpu_memory_utilization=self.gpu_memory_utilization,
1305
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1306
1307
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1308
            is_attention_free=model_config.is_attention_free,
1309
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1310
            sliding_window=sliding_window,
1311
            enable_prefix_caching=self.enable_prefix_caching,
1312
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1313
            cpu_offload_gb=self.cpu_offload_gb,
1314
            calculate_kv_scales=self.calculate_kv_scales,
1315
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1316
1317
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1318
        )
1319

1320
1321
1322
1323
1324
1325
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
1326

1327
            ray_runtime_env = ray.get_runtime_context().runtime_env
1328
1329
1330
1331
1332
1333
1334
            # Avoid logging sensitive environment variables
            sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {}
            if "env_vars" in sanitized_env:
                sanitized_env["env_vars"] = {
                    k: "***" for k in sanitized_env["env_vars"]
                }
            logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env)
1335

1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1347
        assert not headless or not self.data_parallel_hybrid_lb, (
1348
1349
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
1350

1351
        data_parallel_external_lb = self.data_parallel_rank is not None
1352
        # Local DP rank = 1, use pure-external LB.
1353
1354
        if data_parallel_external_lb:
            assert self.data_parallel_size_local in (1, None), (
1355
1356
                "data_parallel_size_local must be 1 when data_parallel_rank is set"
            )
1357
            data_parallel_size_local = 1
1358
1359
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1360
1361
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

            self.data_parallel_rank = self.data_parallel_start_rank or 0
1377
        else:
1378
            assert not self.data_parallel_hybrid_lb, (
1379
1380
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )
1381

1382
1383
            # Local DP size defaults to global DP size if not set.
            data_parallel_size_local = self.data_parallel_size
1384
1385
1386

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1387
1388
1389
1390
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
1391
1392
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1393
1394
1395
1396
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
1397
1398
                    self.data_parallel_backend,
                )
Rui Qiao's avatar
Rui Qiao committed
1399
1400
1401
                data_parallel_address = ParallelConfig.data_parallel_master_ip
        else:
            data_parallel_address = self.data_parallel_address
1402
1403
1404

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
1405
        data_parallel_rpc_port = (
1406
            self.data_parallel_rpc_port
1407
1408
1409
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )
1410

1411
1412
1413
1414
        if self.async_scheduling:
            # Async scheduling does not work with the uniprocess backend.
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "mp"
1415
1416
1417
1418
                logger.info(
                    "Defaulting to mp-based distributed executor "
                    "backend for async scheduling."
                )
1419
            if self.pipeline_parallel_size > 1:
1420
1421
1422
                raise ValueError(
                    "Async scheduling is not supported with pipeline-parallel-size > 1."
                )
1423
1424
1425
1426
1427
1428

            # Currently, async scheduling does not support speculative decoding.
            # TODO(woosuk): Support it.
            if self.speculative_config is not None:
                raise ValueError(
                    "Currently, speculative decoding is not supported with "
1429
1430
                    "async scheduling."
                )
1431

1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
        # Forward the deprecated CLI args to the EPLB config.
        if self.num_redundant_experts is not None:
            self.eplb_config.num_redundant_experts = self.num_redundant_experts
        if self.eplb_window_size is not None:
            self.eplb_config.window_size = self.eplb_window_size
        if self.eplb_step_interval is not None:
            self.eplb_config.step_interval = self.eplb_step_interval
        if self.eplb_log_balancedness is not None:
            self.eplb_config.log_balancedness = self.eplb_log_balancedness

1442
        parallel_config = ParallelConfig(
1443
1444
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1445
            data_parallel_size=self.data_parallel_size,
1446
1447
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1448
1449
1450
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1451
            data_parallel_backend=self.data_parallel_backend,
1452
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1453
            enable_expert_parallel=self.enable_expert_parallel,
1454
1455
            enable_dbo=self.enable_dbo,
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
1456
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
1457
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
1458
            enable_eplb=self.enable_eplb,
1459
            eplb_config=self.eplb_config,
1460
            expert_placement_strategy=self.expert_placement_strategy,
1461
1462
1463
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1464
            ray_runtime_env=ray_runtime_env,
1465
            placement_group=placement_group,
1466
1467
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1468
            worker_extension_cls=self.worker_extension_cls,
1469
            decode_context_parallel_size=self.decode_context_parallel_size,
1470
1471
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
1472
        )
1473

1474
        speculative_config = self.create_speculative_config(
1475
1476
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1477
            enable_chunked_prefill=self.enable_chunked_prefill,
1478
            disable_log_stats=self.disable_log_stats,
1479
1480
        )

1481
1482
1483
1484
1485
        # make sure num_lookahead_slots is set appropriately depending on
        # whether speculative decoding is enabled
        num_lookahead_slots = self.num_lookahead_slots
        if speculative_config is not None:
            num_lookahead_slots = speculative_config.num_lookahead_slots
1486

1487
        scheduler_config = SchedulerConfig(
1488
            runner_type=model_config.runner_type,
1489
1490
1491
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1492
            cuda_graph_sizes=self.cuda_graph_sizes,
1493
            num_lookahead_slots=num_lookahead_slots,
1494
            enable_chunked_prefill=self.enable_chunked_prefill,
1495
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1496
            is_multimodal_model=model_config.is_multimodal_model,
1497
            is_encoder_decoder=model_config.is_encoder_decoder,
1498
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray),
1499
            policy=self.scheduling_policy,
1500
            scheduler_cls=self.scheduler_cls,
1501
1502
1503
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1504
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
1505
            async_scheduling=self.async_scheduling,
1506
        )
1507

1508
1509
1510
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )
1529

1530
1531
1532
1533
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1534
        load_config = self.create_load_config()
1535

1536
1537
        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
1538
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser
1539
1540
1541
1542

        # Forward the deprecated CLI args to the StructuredOutputsConfig
        so_config = self.structured_outputs_config
        if self.guided_decoding_backend is not None:
1543
            so_config.guided_decoding_backend = self.guided_decoding_backend
1544
        if self.guided_decoding_disable_fallback is not None:
1545
1546
1547
            so_config.guided_decoding_disable_fallback = (
                self.guided_decoding_disable_fallback
            )
1548
        if self.guided_decoding_disable_any_whitespace is not None:
1549
1550
1551
            so_config.guided_decoding_disable_any_whitespace = (
                self.guided_decoding_disable_any_whitespace
            )
1552
        if self.guided_decoding_disable_additional_properties is not None:
1553
1554
1555
            so_config.guided_decoding_disable_additional_properties = (
                self.guided_decoding_disable_additional_properties
            )
1556

1557
        observability_config = ObservabilityConfig(
1558
            show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version),
1559
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1560
            collect_detailed_traces=self.collect_detailed_traces,
1561
        )
1562

1563
        config = VllmConfig(
1564
1565
1566
1567
1568
1569
1570
1571
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
1572
            structured_outputs_config=self.structured_outputs_config,
1573
            observability_config=observability_config,
1574
            compilation_config=self.compilation_config,
1575
            kv_transfer_config=self.kv_transfer_config,
1576
            kv_events_config=self.kv_events_config,
1577
            additional_config=self.additional_config,
1578
        )
1579

1580
1581
        return config

1582
1583
1584
1585
1586
1587
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default."""

        #############################################################
        # Unsupported Feature Flags on V1.

1588
1589
1590
1591
        if self.logits_processor_pattern != EngineArgs.logits_processor_pattern:
            _raise_or_fallback(
                feature_name="--logits-processor-pattern", recommend_to_remove=False
            )
1592
1593
            return False

1594
        # No Mamba or Encoder-Decoder so far.
1595
        if not model_config.is_v1_compatible:
1596
1597
1598
            _raise_or_fallback(
                feature_name=model_config.architectures, recommend_to_remove=False
            )
1599
1600
1601
            return False

        # No Concurrent Partial Prefills so far.
1602
1603
1604
1605
1606
1607
1608
1609
        if (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        ):
            _raise_or_fallback(
                feature_name="Concurrent Partial Prefill", recommend_to_remove=False
            )
1610
1611
            return False

1612
        # V1 supports N-gram, Medusa, and Eagle speculative decoding.
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
        if self.speculative_config is not None:
            # speculative_config could still be a dict at this point
            if isinstance(self.speculative_config, dict):
                method = self.speculative_config.get("method", None)
            else:
                method = self.speculative_config.method

            if method == "draft_model":
                raise NotImplementedError(
                    "Draft model speculative decoding is not supported yet. "
                    "Please consider using other speculative decoding methods "
1624
1625
                    "such as ngram, medusa, eagle, or mtp."
                )
1626
1627

        V1_BACKENDS = [
1628
1629
            "FLASH_ATTN",
            "PALLAS",
1630
            "TRITON_ATTN",
1631
            "TRITON_MLA",
1632
            "CUTLASS_MLA",
1633
            "FLASHMLA",
1634
            "FLASH_ATTN_MLA",
1635
            "FLASHINFER",
1636
            "FLASHINFER_MLA",
1637
            "ROCM_AITER_MLA",
1638
            "TORCH_SDPA",
1639
            "FLEX_ATTENTION",
1640
            "TREE_ATTN",
1641
1642
            "XFORMERS",
            "ROCM_ATTN",
1643
            "ROCM_AITER_UNIFIED_ATTN",
1644
        ]
1645
1646
1647
1648
        if (
            envs.is_set("VLLM_ATTENTION_BACKEND")
            and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS
        ):
1649
1650
1651
1652
1653
1654
1655
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

1656
        if self.pipeline_parallel_size > 1:
1657
1658
1659
            supports_pp = getattr(
                self.distributed_executor_backend, "supports_pp", False
            )
1660
            if not supports_pp and self.distributed_executor_backend not in (
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
                ParallelConfig.distributed_executor_backend,
                "ray",
                "mp",
                "external_launcher",
            ):
                name = (
                    "Pipeline Parallelism without Ray distributed "
                    "executor or multiprocessing executor or external "
                    "launcher"
                )
                _raise_or_fallback(feature_name=name, recommend_to_remove=False)
1672
                return False
1673

1674
1675
1676
1677
        if current_platform.is_cpu() and model_config.get_sliding_window() is not None:
            _raise_or_fallback(
                feature_name="sliding window (CPU backend)", recommend_to_remove=False
            )
1678
1679
            return False

1680
1681
1682
1683
        #############################################################

        return True

1684
1685
1686
    def _set_default_args(
        self, usage_context: UsageContext, model_config: ModelConfig
    ) -> None:
1687
        """Set Default Arguments for V1 Engine."""
1688

1689
1690
1691
1692
1693
        # V1 always uses chunked prefills and prefix caching
        # for non-pooling tasks.
        # For pooling tasks the default is False
        if model_config.runner_type != "pooling":
            self.enable_chunked_prefill = True
1694
1695
1696

            # TODO: When prefix caching supports prompt embeds inputs, this
            # check can be removed.
1697
            if self.enable_prompt_embeds and self.enable_prefix_caching is not False:
1698
1699
1700
                logger.warning(
                    "--enable-prompt-embeds and --enable-prefix-caching "
                    "are not supported together in V1. Prefix caching has "
1701
1702
                    "been disabled."
                )
1703
1704
                self.enable_prefix_caching = False

1705
            if self.enable_prefix_caching is None:
1706
1707
1708
1709
1710
1711
                # Disable prefix caching default for hybrid models
                # since the feature is still experimental.
                if model_config.is_hybrid:
                    self.enable_prefix_caching = False
                else:
                    self.enable_prefix_caching = True
1712
1713
        else:
            pooling_type = model_config.pooler_config.pooling_type
1714
            is_causal = getattr(model_config.hf_config, "is_causal", True)
1715
1716
1717
1718
1719
            incremental_prefill_supported = (
                pooling_type is not None
                and pooling_type.lower() == "last"
                and is_causal
            )
1720

1721
            action = "Enabling" if incremental_prefill_supported else "Disabling"
1722
1723
1724
1725
1726
1727
1728
1729

            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = incremental_prefill_supported
                logger.info("(%s) chunked prefill by default", action)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = incremental_prefill_supported
                logger.info("(%s) prefix caching by default", action)

1730
1731
1732
        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default
        if self.scheduler_cls == EngineArgs.scheduler_cls:
1733
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
1734

1735
1736
        # When no user override, set the default values based on the usage
        # context.
1737
        # Use different default values for different hardware.
1738
1739
1740
1741
1742
1743
1744

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
1745
            device_memory = current_platform.get_device_total_memory()
1746
            device_name = current_platform.get_device_name().lower()
1747
1748
        except Exception:
            # This is only used to set default_max_num_batched_tokens
1749
            device_memory = 0
1750

1751
1752
1753
        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
1754
        from vllm.usage.usage_lib import UsageContext
1755

1756
        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
1757
            # For GPUs like H100 and MI300x, use larger default values.
1758
1759
1760
1761
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
1762
1763
1764
1765
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 1024,
                UsageContext.OPENAI_API_SERVER: 1024,
            }
1766
1767
1768
1769
1770
1771
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
1772
1773
1774
1775
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256,
                UsageContext.OPENAI_API_SERVER: 256,
            }
1776

1777
1778
1779
1780
        # tpu specific default values.
        if current_platform.is_tpu():
            default_max_num_batched_tokens_tpu = {
                UsageContext.LLM_CLASS: {
1781
1782
1783
                    "V6E": 2048,
                    "V5E": 1024,
                    "V5P": 512,
1784
1785
                },
                UsageContext.OPENAI_API_SERVER: {
1786
1787
1788
1789
                    "V6E": 1024,
                    "V5E": 512,
                    "V5P": 256,
                },
1790
1791
            }

1792
1793
        # cpu specific default values.
        if current_platform.is_cpu():
1794
            world_size = self.pipeline_parallel_size * self.tensor_parallel_size
1795
            default_max_num_batched_tokens = {
1796
1797
                UsageContext.LLM_CLASS: 4096 * world_size,
                UsageContext.OPENAI_API_SERVER: 2048 * world_size,
1798
1799
            }
            default_max_num_seqs = {
1800
1801
                UsageContext.LLM_CLASS: 256 * world_size,
                UsageContext.OPENAI_API_SERVER: 128 * world_size,
1802
1803
            }

1804
        use_context_value = usage_context.value if usage_context else None
1805
1806
1807
1808
        if (
            self.max_num_batched_tokens is None
            and usage_context in default_max_num_batched_tokens
        ):
1809
1810
            if current_platform.is_tpu():
                chip_name = current_platform.get_device_name()
1811
1812
1813
1814
                if chip_name in default_max_num_batched_tokens_tpu[usage_context]:
                    self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[
                        usage_context
                    ][chip_name]
1815
                else:
1816
1817
1818
                    self.max_num_batched_tokens = default_max_num_batched_tokens[
                        usage_context
                    ]
1819
            else:
1820
1821
1822
                if not self.enable_chunked_prefill:
                    self.max_num_batched_tokens = model_config.max_model_len
                else:
1823
1824
1825
                    self.max_num_batched_tokens = default_max_num_batched_tokens[
                        usage_context
                    ]
1826
            logger.debug(
1827
                "Setting max_num_batched_tokens to %d for %s usage context.",
1828
1829
1830
                self.max_num_batched_tokens,
                use_context_value,
            )
1831

1832
1833
1834
1835
1836
        if self.max_num_seqs is None and usage_context in default_max_num_seqs:
            self.max_num_seqs = min(
                default_max_num_seqs[usage_context],
                self.max_num_batched_tokens or sys.maxsize,
            )
1837

1838
1839
1840
1841
1842
            logger.debug(
                "Setting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                use_context_value,
            )
1843

1844

1845
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
1846
class AsyncEngineArgs(EngineArgs):
Woosuk Kwon's avatar
Woosuk Kwon committed
1847
    """Arguments for asynchronous vLLM engine."""
1848

1849
1850
1851
1852
1853
1854
    enable_log_requests: bool = False

    @property
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
1855
1856
        "`enable_log_requests` instead."
    )
1857
1858
1859
1860
1861
1862
1863
    def disable_log_requests(self) -> bool:
        return not self.enable_log_requests

    @disable_log_requests.setter
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
1864
1865
        "`enable_log_requests` instead."
    )
1866
1867
    def disable_log_requests(self, value: bool):
        self.enable_log_requests = not value
1868
1869

    @staticmethod
1870
1871
1872
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
1873
        # Initialize plugin to update the parser, for example, The plugin may
1874
        # add a new kind of quantization method to --quantization argument or
1875
1876
        # a new device to --device argument.
        load_general_plugins()
1877
1878
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
        parser.add_argument(
            "--enable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=AsyncEngineArgs.enable_log_requests,
            help="Enable logging requests.",
        )
        parser.add_argument(
            "--disable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=not AsyncEngineArgs.enable_log_requests,
            help="[DEPRECATED] Disable logging requests.",
            deprecated=True,
        )
1892
        current_platform.pre_register_and_update(parser)
1893
        return parser
1894
1895


1896
1897
1898
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
        raise NotImplementedError(
1899
1900
            f"VLLM_USE_V1=1 is not supported with {feature_name}."
        )
1901
1902
1903
1904
1905
1906
1907
1908
    msg = f"{feature_name} is not supported by the V1 Engine. "
    msg += "Falling back to V0. "
    if recommend_to_remove:
        msg += f"We recommend to remove {feature_name} from your config "
        msg += "in favor of the V1 Engine."
    logger.warning(msg)


1909
1910
1911
def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.
1912

1913
1914
1915
1916
1917
1918
    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    """
    value = value.strip()
1919
    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
1920
1921
    if match:
        decimal_multiplier = {
1922
1923
1924
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
1925
1926
        }
        binary_multiplier = {
1927
1928
1929
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            return int(float(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
1942
1943
1944
1945
1946
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?"
                ) from e
1947
1948
1949

    # Regular plain number.
    return int(value)