arg_utils.py 79.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import copy
6
import dataclasses
7
import functools
8
import json
9
import sys
10
from dataclasses import MISSING, dataclass, fields, is_dataclass
11
from itertools import permutations
12
13
14
15
16
17
18
19
20
21
22
23
24
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Callable,
    Literal,
    Optional,
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
25

26
import huggingface_hub
27
import regex as re
28
import torch
29
from pydantic import TypeAdapter, ValidationError
30
from typing_extensions import TypeIs, deprecated
31

32
import vllm.envs as envs
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from vllm.config import (
    BlockSize,
    CacheConfig,
    CacheDType,
    CompilationConfig,
    ConfigType,
    ConvertOption,
    DetailedTraceModules,
    Device,
    DeviceConfig,
    DistributedExecutorBackend,
    EPLBConfig,
    HfOverrides,
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LogprobsMode,
    LoRAConfig,
    MambaDType,
    MMEncoderTPMode,
    ModelConfig,
    ModelDType,
    ObservabilityConfig,
    ParallelConfig,
    PoolerConfig,
    PrefixCachingHashAlgo,
    RunnerOption,
    SchedulerConfig,
    SchedulerPolicy,
    SpeculativeConfig,
    StructuredOutputsConfig,
    TaskOption,
    TokenizerMode,
    VllmConfig,
    get_attr_docs,
)
69
from vllm.config.multimodal import MMCacheType, MultiModalConfig
70
from vllm.config.parallel import ExpertPlacementStrategy
71
from vllm.config.utils import get_field
72
from vllm.logger import init_logger
73
from vllm.platforms import CpuArchEnum, current_platform
74
from vllm.plugins import load_general_plugins
75
from vllm.ray.lazy_utils import is_ray_initialized
76
from vllm.reasoning import ReasoningParserManager
77
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
78
79
80
81
82
from vllm.transformers_utils.config import (
    get_model_path,
    is_interleaved,
    maybe_override_with_speculators,
)
83
from vllm.transformers_utils.utils import check_gguf_file
84
from vllm.utils import FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor
85
from vllm.v1.sample.logits_processor import LogitsProcessor
86

87
88
89
if TYPE_CHECKING:
    from vllm.executor.executor_base import ExecutorBase
    from vllm.model_executor.layers.quantization import QuantizationMethods
90
    from vllm.model_executor.model_loader import LoadFormats
91
92
93
94
    from vllm.usage.usage_lib import UsageContext
else:
    ExecutorBase = Any
    QuantizationMethods = Any
95
    LoadFormats = Any
96
97
    UsageContext = Any

98
99
logger = init_logger(__name__)

100
101
102
103
104
# Aliases describing what the helpers in this module accept as a "type hint".
# `object` is part of each union to allow for special typing forms (for
# example `Literal`, which is passed to `get_type` below) that are not
# instances of `type`.
T = TypeVar("T")
# Any type hint: a class or a special typing form.
TypeHint = Union[type[Any], object]
# A type hint standing for one specific type `T`.
TypeHintT = Union[type[T], object]

105

106
107
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap a converter so conversion failures surface as argparse errors.

    Args:
        return_type: Callable that turns a command-line string into the
            desired value (for example `int` or `json.loads`).

    Returns:
        A callable that applies `return_type` to its argument and re-raises
        any `ValueError` as `argparse.ArgumentTypeError`, chaining the
        original exception.
    """

    def _parse_type(val: str) -> T:
        try:
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}."
            ) from e

    return _parse_type


118
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Wrap a converter so that an empty or `"None"` argument yields `None`.

    Args:
        return_type: Callable that turns a command-line string into the
            desired value.

    Returns:
        A callable that returns `None` for `""` and `"None"`, and otherwise
        delegates to `parse_type(return_type)`, so conversion failures are
        raised as `argparse.ArgumentTypeError`.
    """
    # Build the wrapped converter once instead of on every parsed value.
    convert = parse_type(return_type)

    def _optional_type(val: str) -> Optional[T]:
        if val in ("", "None"):
            return None
        return convert(val)

    return _optional_type
125
126


127
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
    """Parse `val` as JSON if it looks like a JSON object, else keep it as text.

    A value is treated as JSON when, ignoring surrounding whitespace, it
    starts with `{` and ends with `}`. Such values are decoded with
    `json.loads` through `optional_type`, so malformed JSON is reported as
    `argparse.ArgumentTypeError`. Every other value is returned as a string.
    """
    if not re.match(r"(?s)^\s*{.*}\s*$", val):
        return str(val)
    return optional_type(json.loads)(val)
131
132


133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return whether `type_hint` is `type` itself or a parametrisation of it.

    The hint matches either by identity or because its `get_origin` is
    `type` (for example `list[int]` matches `list`).
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return whether at least one hint in `type_hints` matches `type`."""
    for candidate in type_hints:
        if is_type(candidate, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return the first hint in `type_hints` matching `type`, or `None`."""
    for candidate in type_hints:
        if is_type(candidate, type):
            return candidate
    return None


148
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Get the `type` and `choices` from a `Literal` type hint in `type_hints`.

    If `type_hints` also contains `str`, we use `metavar` instead of `choices`.

    Args:
        type_hints: Set of type hints that includes a `Literal[...]` hint.

    Returns:
        A dict with `type` set to the Python type of the literal options and
        either `choices` or `metavar` set to the sorted options.

    Raises:
        ValueError: If no `Literal` options are found, or if the options are
            not all of the same type.
    """
    type_hint = get_type(type_hints, Literal)
    options = get_args(type_hint)
    # Fail with a clear message instead of an IndexError on `options[0]`.
    if not options:
        raise ValueError(f"No Literal options found in type hints {type_hints}.")
    option_type = type(options[0])
    if not all(isinstance(option, option_type) for option in options):
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}"
        )
    kwarg = "metavar" if contains_type(type_hints, str) else "choices"
    return {"type": option_type, kwarg: sorted(options)}
163
164


165
166
167
168
169
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return `True` when `type_hint` is defined outside the `builtins` module."""
    defining_module = type_hint.__module__
    return not defining_module == "builtins"


170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints."""
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
    elif origin is Union:
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


187
188
189
190
def is_online_quantization(quantization: Any) -> bool:
    """Return whether `quantization` is one of the online quantization methods."""
    online_methods = ("inc",)
    return quantization in online_methods


191
# True when help text is being generated: either some argument contains
# `--help`, or the program name (`sys.argv[0]`) points at mkdocs. Used to skip
# collecting attribute docs when they would not be displayed.
NEEDS_HELP = (
    any("--help" in arg for arg in sys.argv)  # vllm SUBCOMMAND --help
    # Guard against an empty `sys.argv` so importing never raises IndexError.
    or (argv0 := sys.argv[0] if sys.argv else "").endswith(
        "mkdocs"
    )  # mkdocs SUBCOMMAND
    or argv0.endswith("mkdocs/__main__.py")  # python -m mkdocs SUBCOMMAND
)


198
199
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Build the argparse keyword arguments for every field of a config dataclass.

    For each dataclass field the result maps the field name to a dict that
    always holds `default` and `help`, plus `type`, `action`, `choices`,
    `metavar` and/or `nargs` derived from the field's type hints.

    The result is cached per class, so callers must not mutate it directly
    (see `get_kwargs`, which returns a deep copy).

    Args:
        cls: The config dataclass to inspect.

    Returns:
        Mapping of field name to the kwargs for `add_argument`.

    Raises:
        ValueError: If a field's type hints are not supported.
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}

    # Loop-invariant values, built once rather than per field.
    json_tip = (
        "Should either be a valid JSON string or JSON keys passed individually."
    )
    # Integer fields that are parsed with `human_readable_int`
    human_readable_ints = {
        "max_model_len",
        "max_num_batched_tokens",
        "kv_cache_memory_bytes",
    }

    kwargs = {}
    for field in fields(cls):
        name = field.name

        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, it is parsed from JSON (see below)
        dataclass_cls = next((th for th in type_hints if is_dataclass(th)), None)

        # Get the default value of the field
        if field.default is not MISSING:
            default = field.default
        elif field.default_factory is not MISSING:
            default = field.default_factory()
        else:
            # No default declared: use None instead of leaving `default`
            # unbound or carrying over the previous field's value.
            default = None

        # Get the help text for the field
        help_text = cls_docs.get(name, "").strip()
        # Escape % for argparse
        help_text = help_text.replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        kwargs[name] = {"default": default, "help": help_text}

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:

            # `cls` is bound as a default so each closure keeps its own class.
            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            kwargs[name]["type"] = parse_dataclass
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            kwargs[name]["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            kwargs[name].update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            types = get_args(type_hint)
            tuple_type = types[0]
            assert all(t is tuple_type for t in types if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {types}."
            )
            kwargs[name]["type"] = tuple_type
            kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            types = get_args(type_hint)
            list_type = types[0]
            if get_origin(list_type) is Union:
                msg = "List type must contain str if it is a Union."
                assert str in get_args(list_type), msg
                list_type = str
            kwargs[name]["type"] = list_type
            kwargs[name]["nargs"] = "+"
        elif contains_type(type_hints, int):
            kwargs[name]["type"] = int
            # Special case for large integers
            if name in human_readable_ints:
                kwargs[name]["type"] = human_readable_int
                kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
            kwargs[name]["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            kwargs[name]["type"] = parse_type(json.loads)
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
            kwargs[name]["type"] = str
        else:
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(kwargs[name].get("type")) is Literal:
            kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            kwargs[name]["type"] = optional_type(kwargs[name]["type"])
            if kwargs[name].get("choices"):
                kwargs[name]["choices"].append("None")
    return kwargs
305
306


307
308
309
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Return argparse kwargs for the given Config dataclass.

    Attribute documentation is only included in the help output when
    `--help` or `mkdocs` is present in the command line command.

    The expensive work happens in `_compute_kwargs`, which is memoised with
    `functools.lru_cache`. A deep copy of the cached result is returned, so
    callers are free to mutate the dictionary without affecting the cache.
    """
    cached_kwargs = _compute_kwargs(cls)
    return copy.deepcopy(cached_kwargs)


320
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
321
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
322
    """Arguments for vLLM engine."""
323

324
    model: str = ModelConfig.model
325
    served_model_name: Optional[Union[str, list[str]]] = ModelConfig.served_model_name
326
327
    tokenizer: Optional[str] = ModelConfig.tokenizer
    hf_config_path: Optional[str] = ModelConfig.hf_config_path
328
329
330
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
    task: Optional[TaskOption] = ModelConfig.task
331
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
332
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
333
334
335
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
336
    allowed_media_domains: Optional[list[str]] = ModelConfig.allowed_media_domains
337
    download_dir: Optional[str] = LoadConfig.download_dir
338
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
339
    load_format: Union[str, LoadFormats] = LoadConfig.load_format
340
341
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
342
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
343
344
    seed: Optional[int] = ModelConfig.seed
    max_model_len: Optional[int] = ModelConfig.max_model_len
345
    cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes")
346
347
348
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
349
    distributed_executor_backend: Optional[
350
        Union[str, DistributedExecutorBackend, type[ExecutorBase]]
351
    ] = ParallelConfig.distributed_executor_backend
352
    # number of P/D disaggregation (or other disaggregation) workers
353
354
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
355
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
356
    data_parallel_size: int = ParallelConfig.data_parallel_size
357
    data_parallel_rank: Optional[int] = None
358
    data_parallel_start_rank: Optional[int] = None
359
360
361
    data_parallel_size_local: Optional[int] = None
    data_parallel_address: Optional[str] = None
    data_parallel_rpc_port: Optional[int] = None
362
    data_parallel_hybrid_lb: bool = False
Rui Qiao's avatar
Rui Qiao committed
363
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
364
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
365
    enable_dbo: bool = ParallelConfig.enable_dbo
366
367
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
368
369
370
    disable_nccl_for_dp_synchronization: bool = (
        ParallelConfig.disable_nccl_for_dp_synchronization
    )
371
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
372
    enable_eplb: bool = ParallelConfig.enable_eplb
373
    expert_placement_strategy: ExpertPlacementStrategy = (
374
        ParallelConfig.expert_placement_strategy
375
    )
376
377
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
378
379
380
381
    num_redundant_experts: int = EPLBConfig.num_redundant_experts
    eplb_window_size: int = EPLBConfig.window_size
    eplb_step_interval: int = EPLBConfig.step_interval
    eplb_log_balancedness: bool = EPLBConfig.log_balancedness
382
383
384
    max_parallel_loading_workers: Optional[int] = (
        ParallelConfig.max_parallel_loading_workers
    )
385
386
    block_size: Optional[BlockSize] = CacheConfig.block_size
    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
387
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
388
        CacheConfig.prefix_caching_hash_algo
389
    )
390
391
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
392
393
394
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
395
    kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes
396
    max_num_batched_tokens: Optional[int] = SchedulerConfig.max_num_batched_tokens
397
398
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
399
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
400
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
401
    max_logprobs: int = ModelConfig.max_logprobs
402
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
403
    disable_log_stats: bool = False
404
405
406
407
408
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
    rope_theta: Optional[float] = ModelConfig.rope_theta
    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
409
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
410
411
412
    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
    enforce_eager: bool = ModelConfig.enforce_eager
413
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
414
415
416
    limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = get_field(
        MultiModalConfig, "limit_per_prompt"
    )
417
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
418
419
420
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
421
    mm_processor_kwargs: Optional[dict[str, Any]] = MultiModalConfig.mm_processor_kwargs
422
    disable_mm_preprocessor_cache: bool = False  # DEPRECATED
423
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
424
    mm_processor_cache_type: Optional[MMCacheType] = (
425
        MultiModalConfig.mm_processor_cache_type
426
427
    )
    mm_shm_cache_max_object_size_mb: int = (
428
        MultiModalConfig.mm_shm_cache_max_object_size_mb
429
    )
430
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
431
    io_processor_plugin: Optional[str] = None
432
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
433
    video_pruning_rate: float = MultiModalConfig.video_pruning_rate
434
    # LoRA fields
435
    enable_lora: bool = False
436
437
438
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
439
    default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras
440
441
442
443
444
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size

445
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
446
    num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override
447
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
448
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
449
    ignore_patterns: Optional[Union[str, list[str]]] = LoadConfig.ignore_patterns
450

451
    enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill
452
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
453

454
    disable_hybrid_kv_cache_manager: bool = (
455
456
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
457

458
    structured_outputs_config: StructuredOutputsConfig = get_field(
459
460
        VllmConfig, "structured_outputs_config"
    )
461
462
463
464
465
466
467
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
    # Deprecated guided decoding fields
    guided_decoding_backend: Optional[str] = None
    guided_decoding_disable_fallback: Optional[bool] = None
    guided_decoding_disable_any_whitespace: Optional[bool] = None
    guided_decoding_disable_additional_properties: Optional[bool] = None

468
    logits_processor_pattern: Optional[str] = ModelConfig.logits_processor_pattern
469

470
    speculative_config: Optional[dict[str, Any]] = None
471

472
    show_hidden_metrics_for_version: Optional[str] = (
473
        ObservabilityConfig.show_hidden_metrics_for_version
474
475
476
    )
    otlp_traces_endpoint: Optional[str] = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: Optional[list[DetailedTraceModules]] = (
477
        ObservabilityConfig.collect_detailed_traces
478
    )
479
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
480
    scheduler_cls: Union[str, type[object]] = SchedulerConfig.scheduler_cls
481

482
    pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config
483
    override_pooler_config: Optional[Union[dict, PoolerConfig]] = (
484
        ModelConfig.override_pooler_config
485
486
    )
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
487
488
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
489

490
    kv_transfer_config: Optional[KVTransferConfig] = None
491
    kv_events_config: Optional[KVEventsConfig] = None
492

493
494
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
495
496
497
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
498
    model_impl: str = ModelConfig.model_impl
499
    override_attention_dtype: str = ModelConfig.override_attention_dtype
500

501
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
502
503
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
504

505
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
506

507
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
508
    pt_load_map_location: str = LoadConfig.pt_load_map_location
509

510
511
    # DEPRECATED
    enable_multimodal_encoder_data_parallel: bool = False
512

513
514
515
    logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = (
        ModelConfig.logits_processors
    )
516
517
    """Custom logitproc types"""

518
519
    async_scheduling: bool = SchedulerConfig.async_scheduling

520
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
521

522
    def __post_init__(self):
        """Normalise dict-valued configs, load plugins and resolve offline paths.

        - `compilation_config` and `eplb_config` given as plain dicts are
          converted into `CompilationConfig` / `EPLBConfig` instances.
        - General plugins are loaded via `load_general_plugins()`.
        - When `huggingface_hub.constants.HF_HUB_OFFLINE` is set, `self.model`
          is replaced by the result of `get_model_path(model, revision)`.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(**self.compilation_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)
        # Setup plugins
        from vllm.plugins import load_general_plugins

        load_general_plugins()
        # When HF Hub is offline, replace the model id with the local model path
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
            logger.info(
                "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
                model_id,
                self.model,
            )
543
544

    @staticmethod
545
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
546
        """Shared CLI arguments for vLLM engine."""
547

548
        # Model arguments
549
550
551
552
553
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
554
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
555
            model_group.add_argument("--model", **model_kwargs["model"])
556
557
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
558
        model_group.add_argument("--task", **model_kwargs["task"], deprecated=True)
559
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
560
561
562
563
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
564
565
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
566
567
568
569
570
571
572
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
573
        model_group.add_argument("--revision", **model_kwargs["revision"])
574
575
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling", **model_kwargs["rope_scaling"])
576
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
601
602
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--override-pooler-config",
            **model_kwargs["override_pooler_config"],
            deprecated=True,
        )
        model_group.add_argument(
            "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
        )
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
630
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
631
632
633
634
635
636
637
638
639
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
        )
640

641
642
643
644
645
646
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
647
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
648
649
650
651
652
653
654
655
656
657
658
659
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )
660

661
662
663
664
665
        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
666
        )
667
        structured_outputs_group.add_argument(
668
            "--reasoning-parser",
669
            # This choice is a special case because it's not static
670
            choices=list(ReasoningParserManager.reasoning_parsers),
671
672
            **structured_outputs_kwargs["reasoning_parser"],
        )
673
674
675
676
677
678
679
680
681
682
683
        # Deprecated guided decoding arguments
        for arg, type in [
            ("--guided-decoding-backend", str),
            ("--guided-decoding-disable-fallback", bool),
            ("--guided-decoding-disable-any-whitespace", bool),
            ("--guided-decoding-disable-additional-properties", bool),
        ]:
            structured_outputs_group.add_argument(
                arg,
                type=type,
                help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
684
685
                deprecated=True,
            )
686

687
        # Parallel arguments
688
689
690
691
692
693
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
694
            "--distributed-executor-backend",
695
696
            **parallel_kwargs["distributed_executor_backend"],
        )
697
        parallel_group.add_argument(
698
699
700
701
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
702
        parallel_group.add_argument(
703
704
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
705
        parallel_group.add_argument(
706
707
708
709
710
711
712
713
714
715
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
        parallel_group.add_argument(
            "--data-parallel-rank",
            "-dpn",
716
            type=int,
717
718
719
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
720
        parallel_group.add_argument(
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
            "--data-parallel-start-rank",
            "-dpr",
            type=int,
            help="Starting data parallel rank for secondary nodes.",
        )
        parallel_group.add_argument(
            "--data-parallel-size-local",
            "-dpl",
            type=int,
            help="Number of data parallel replicas to run on this node.",
        )
        parallel_group.add_argument(
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
            type=int,
            help="Port for data parallel RPC communication.",
        )
        parallel_group.add_argument(
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
751
        parallel_group.add_argument(
752
753
754
755
756
757
            "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"]
        )
        parallel_group.add_argument(
            "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
        )
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
758
759
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
760
761
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
762
763
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
764
765
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
766
767
768
769
        parallel_group.add_argument(
            "--disable-nccl-for-dp-synchronization",
            **parallel_kwargs["disable_nccl_for_dp_synchronization"],
        )
770
771
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
772
773
        parallel_group.add_argument(
            "--expert-placement-strategy",
774
775
            **parallel_kwargs["expert_placement_strategy"],
        )
776
777
778
        parallel_group.add_argument(
            "--num-redundant-experts",
            type=int,
779
780
781
            help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
            deprecated=True,
        )
782
783
784
785
        parallel_group.add_argument(
            "--eplb-window-size",
            type=int,
            help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
786
787
            deprecated=True,
        )
788
789
790
        parallel_group.add_argument(
            "--eplb-step-interval",
            type=int,
791
792
793
            help="[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
            deprecated=True,
        )
794
795
796
        parallel_group.add_argument(
            "--eplb-log-balancedness",
            action=argparse.BooleanOptionalAction,
797
798
799
            help="[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
            deprecated=True,
        )
800

801
        parallel_group.add_argument(
802
            "--max-parallel-loading-workers",
803
804
            **parallel_kwargs["max_parallel_loading_workers"],
        )
805
        parallel_group.add_argument(
806
807
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
808
        parallel_group.add_argument(
809
            "--disable-custom-all-reduce",
810
811
812
813
814
815
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )
816
817
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
818
            action="store_true",
819
820
            deprecated=True,
        )
821

822
823
824
825
826
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
827
        )
828
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
829
830
831
832
833
834
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
835
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        cache_group.add_argument(
            "--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"]
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
859

860
        # Multimodal related configs
861
862
863
864
865
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
866
        multimodal_group.add_argument(
867
868
869
870
871
872
873
874
875
876
877
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
878
        multimodal_group.add_argument(
879
880
            "--disable-mm-preprocessor-cache", action="store_true", deprecated=True
        )
881
        multimodal_group.add_argument(
882
883
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
884
885
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
886
887
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
888
        multimodal_group.add_argument(
889
890
891
892
893
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
894
        multimodal_group.add_argument(
895
896
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
897

898
        multimodal_group.add_argument(
899
900
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )
901

902
        # LoRA related configs
903
904
905
906
907
908
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
909
            "--enable-lora",
910
            action=argparse.BooleanOptionalAction,
911
912
913
            help="If True, enable handling of LoRA adapters.",
        )
        lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"])
914
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
915
916
917
918
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
        lora_group.add_argument(
            "--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"]
        )
919
        lora_group.add_argument(
920
            "--lora-dtype",
921
922
            **lora_kwargs["lora_dtype"],
        )
923
924
925
926
927
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
928

929
930
931
932
933
934
935
936
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
937
938
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
939
        observability_group.add_argument(
940
941
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
942
943
944
945
946
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
947
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
948
949
950
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
951
952
            **observability_kwargs["collect_detailed_traces"],
        )
953

954
955
956
957
958
959
960
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
961
962
            "--max-num-batched-tokens", **scheduler_kwargs["max_num_batched_tokens"]
        )
963
        scheduler_group.add_argument(
964
965
966
967
968
            "--max-num-seqs", **scheduler_kwargs["max_num_seqs"]
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
969
970
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
971
972
973
974
975
            **scheduler_kwargs["max_long_partial_prefills"],
        )
        scheduler_group.add_argument(
            "--cuda-graph-sizes", **scheduler_kwargs["cuda_graph_sizes"]
        )
976
977
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
978
979
980
981
982
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
        scheduler_group.add_argument(
            "--num-lookahead-slots", **scheduler_kwargs["num_lookahead_slots"]
        )
983
984
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
985
        scheduler_group.add_argument(
986
987
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
988
        scheduler_group.add_argument(
989
990
991
992
993
994
995
996
            "--enable-chunked-prefill", **scheduler_kwargs["enable_chunked_prefill"]
        )
        scheduler_group.add_argument(
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
997
998
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
999
1000
1001
1002
1003
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
1004
1005

        # vLLM arguments
1006
        vllm_kwargs = get_kwargs(VllmConfig)
1007
1008
1009
1010
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
1011
1012
1013
1014
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
        vllm_group.add_argument(
            "--speculative-config", **vllm_kwargs["speculative_config"]
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument(
            "--compilation-config", "-O", **vllm_kwargs["compilation_config"]
        )
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
1031

1032
        # Other arguments
1033
1034
1035
1036
1037
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
1038

1039
        return parser
1040
1041

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of this dataclass from a parsed CLI namespace.

        Only attributes of ``args`` whose names match a dataclass field of
        ``cls`` are forwarded to the constructor. Fields that are absent
        from ``args`` are left out of the call, and attributes of ``args``
        that are not fields of ``cls`` are ignored.

        Args:
            args: Namespace produced by argument parsing.

        Returns:
            A new instance of ``cls``.
        """
        init_kwargs = {}
        for field in dataclasses.fields(cls):
            # Skip fields the namespace does not carry at all.
            if not hasattr(args, field.name):
                continue
            init_kwargs[field.name] = getattr(args, field.name)
        return cls(**init_kwargs)
1050

1051
    def create_model_config(self) -> ModelConfig:
        """Create the ModelConfig from the current engine arguments.

        Before constructing the config, this method may rewrite some
        arguments in place on ``self``:

        * ``quantization`` and ``load_format`` are both set to ``"gguf"``
          when ``check_gguf_file(self.model)`` is true.
        * ``model`` is prefixed with ``MODEL_WEIGHTS_S3_BUCKET`` when the CI
          S3 conditions below hold.
        * ``mm_processor_cache_gb`` is overridden by the deprecated
          ``--disable-mm-preprocessor-cache`` flag or by the deprecated
          ``VLLM_MM_INPUT_CACHE_GIB`` environment variable.
        * ``mm_encoder_tp_mode`` is set to ``"data"`` by the deprecated
          ``--enable-multimodal-encoder-data-parallel`` flag.

        Returns:
            The ModelConfig built from the (possibly adjusted) arguments.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (
            not isinstance(self, AsyncEngineArgs)
            and envs.VLLM_CI_USE_S3
            and self.model in MODELS_ON_S3
            and self.load_format == "auto"
        ):
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"

        # Deprecated ways of configuring the multimodal processor cache.
        # The CLI flag takes precedence over the environment variable.
        if self.disable_mm_preprocessor_cache:
            logger.warning(
                "`--disable-mm-preprocessor-cache` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb 0` instead.",
            )

            self.mm_processor_cache_gb = 0
        # NOTE(review): 4 is presumably the default value of
        # VLLM_MM_INPUT_CACHE_GIB, so this detects a user override --
        # confirm against vllm.envs.
        elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
            logger.warning(
                "`VLLM_MM_INPUT_CACHE_GIB` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb %d` instead.",
                envs.VLLM_MM_INPUT_CACHE_GIB,
            )

            self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB

        # Deprecated flag that maps onto the newer mm_encoder_tp_mode option.
        if self.enable_multimodal_encoder_data_parallel:
            logger.warning(
                "`--enable-multimodal-encoder-data-parallel` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-encoder-tp-mode data` instead."
            )

            self.mm_encoder_tp_mode = "data"

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            task=self.task,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            pooler_config=self.pooler_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            video_pruning_rate=self.video_pruning_rate,
            io_processor_plugin=self.io_processor_plugin,
        )
1144

1145
    def validate_tensorizer_args(self):
        """Copy tensorizer options into the nested ``tensorizer_config``.

        Every top-level key of ``self.model_loader_extra_config`` that is
        also listed in ``TensorizerConfig._fields`` has its value copied to
        ``self.model_loader_extra_config["tensorizer_config"]`` under the
        same key. The top-level entry itself is left in place.

        The ``"tensorizer_config"`` entry must already exist when a
        matching key is found; this method does not create it.
        """
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

        extra_config = self.model_loader_extra_config
        for option in extra_config:
            if option not in TensorizerConfig._fields:
                continue
            # Only the nested dict is written to, so the mapping being
            # iterated keeps its size.
            extra_config["tensorizer_config"][option] = extra_config[option]
1153

1154
    def create_load_config(self) -> LoadConfig:
        """Create the LoadConfig from the current engine arguments.

        Side effects on ``self``:

        * ``load_format`` is forced to ``"bitsandbytes"`` when
          ``quantization`` is ``"bitsandbytes"``.
        * For the ``"tensorizer"`` load format,
          ``model_loader_extra_config`` is converted with its
          ``to_serializable()`` method if it has one, its
          ``"tensorizer_config"`` entry is reset to a dict whose
          ``"tensorizer_dir"`` is ``self.model``, and
          ``validate_tensorizer_args()`` is then called.

        Returns:
            The LoadConfig. Its ``device`` is ``"cpu"`` when
            ``is_online_quantization(self.quantization)`` is truthy and
            ``None`` otherwise.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            extra_config = self.model_loader_extra_config
            if hasattr(extra_config, "to_serializable"):
                extra_config = extra_config.to_serializable()
                self.model_loader_extra_config = extra_config
            # Start from a fresh nested config pointing at the model path.
            extra_config["tensorizer_config"] = {}
            extra_config["tensorizer_config"]["tensorizer_dir"] = self.model
            self.validate_tensorizer_args()

        if is_online_quantization(self.quantization):
            load_device = "cpu"
        else:
            load_device = None

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            device=load_device,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )
1179

1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Build a SpeculativeConfig from ``self.speculative_config``.

        ``self.speculative_config`` can come from the CLI as a JSON string
        or be passed directly as a dictionary by the engine. When it is
        ``None``, no speculative config is created.

        The four arguments of this method are merged into
        ``self.speculative_config`` in place (overwriting any entries with
        the same keys) before the SpeculativeConfig is constructed.

        Args:
            target_model_config: Stored under ``"target_model_config"``.
            target_parallel_config: Stored under
                ``"target_parallel_config"``.
            enable_chunked_prefill: Stored under
                ``"enable_chunked_prefill"``.
            disable_log_stats: Stored under ``"disable_log_stats"``.

        Returns:
            The SpeculativeConfig, or ``None`` if ``self.speculative_config``
            is ``None``.
        """
        spec_kwargs = self.speculative_config
        if spec_kwargs is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        engine_supplied = {
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        }
        spec_kwargs.update(engine_supplied)
        return SpeculativeConfig(**spec_kwargs)
1210

1211
1212
1213
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
1214
        headless: bool = False,
1215
1216
1217
1218
1219
1220
1221
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1222

1223
1224
1225
1226
1227
1228
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1229
        current_platform.pre_register_and_update()
1230

1231
        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
1232

1233
1234
1235
1236
        model_config = self.create_model_config()
        self.model = model_config.model
        self.tokenizer = model_config.tokenizer

1237
1238
1239
1240
1241
1242
1243
1244
1245
        (self.model, self.tokenizer, self.speculative_config) = (
            maybe_override_with_speculators(
                model=self.model,
                tokenizer=self.tokenizer,
                revision=self.revision,
                trust_remote_code=self.trust_remote_code,
                vllm_speculative_config=self.speculative_config,
            )
        )
1246

1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

1264
1265
        # Set default arguments for V1 Engine.
        self._set_default_args(usage_context, model_config)
1266
        # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.S390X,
            CpuArchEnum.ARM,
            CpuArchEnum.RISCV,
        ):
            logger.info(
                "Chunked prefill is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
1278
            self.enable_chunked_prefill = False
1279
1280
        assert self.enable_chunked_prefill is not None

1281
1282
1283
1284
1285
1286
1287
        sliding_window: Optional[int] = None
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1288
1289
1290
        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
1291
        # reuses the GPUs of TP group, and split one TP group into
1292
        # tp_size//dcp_size DCP groups.
1293
        assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, (
1294
1295
1296
1297
            f"tp_size={self.tensor_parallel_size} must be divisible by"
            f"dcp_size={self.decode_context_parallel_size}."
        )

1298
        cache_config = CacheConfig(
1299
            block_size=self.block_size,
1300
            gpu_memory_utilization=self.gpu_memory_utilization,
1301
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1302
1303
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1304
            is_attention_free=model_config.is_attention_free,
1305
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1306
            sliding_window=sliding_window,
1307
            enable_prefix_caching=self.enable_prefix_caching,
1308
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1309
            cpu_offload_gb=self.cpu_offload_gb,
1310
            calculate_kv_scales=self.calculate_kv_scales,
1311
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1312
1313
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1314
        )
1315

1316
1317
1318
1319
1320
1321
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
1322

1323
1324
1325
            ray_runtime_env = ray.get_runtime_context().runtime_env
            logger.info("Using ray runtime env: %s", ray_runtime_env)

1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1337
        assert not headless or not self.data_parallel_hybrid_lb, (
1338
1339
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
1340

1341
        data_parallel_external_lb = self.data_parallel_rank is not None
1342
        # Local DP rank = 1, use pure-external LB.
1343
1344
        if data_parallel_external_lb:
            assert self.data_parallel_size_local in (1, None), (
1345
1346
                "data_parallel_size_local must be 1 when data_parallel_rank is set"
            )
1347
            data_parallel_size_local = 1
1348
1349
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1350
1351
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

            self.data_parallel_rank = self.data_parallel_start_rank or 0
1367
        else:
1368
            assert not self.data_parallel_hybrid_lb, (
1369
1370
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )
1371

1372
1373
            # Local DP size defaults to global DP size if not set.
            data_parallel_size_local = self.data_parallel_size
1374
1375
1376

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1377
1378
1379
1380
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
1381
1382
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1383
1384
1385
1386
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
1387
1388
                    self.data_parallel_backend,
                )
Rui Qiao's avatar
Rui Qiao committed
1389
1390
1391
                data_parallel_address = ParallelConfig.data_parallel_master_ip
        else:
            data_parallel_address = self.data_parallel_address
1392
1393
1394

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
1395
        data_parallel_rpc_port = (
1396
            self.data_parallel_rpc_port
1397
1398
1399
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )
1400

1401
1402
1403
1404
        if self.async_scheduling:
            # Async scheduling does not work with the uniprocess backend.
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "mp"
1405
1406
1407
1408
                logger.info(
                    "Defaulting to mp-based distributed executor "
                    "backend for async scheduling."
                )
1409
            if self.pipeline_parallel_size > 1:
1410
1411
1412
                raise ValueError(
                    "Async scheduling is not supported with pipeline-parallel-size > 1."
                )
1413
1414
1415
1416
1417
1418

            # Currently, async scheduling does not support speculative decoding.
            # TODO(woosuk): Support it.
            if self.speculative_config is not None:
                raise ValueError(
                    "Currently, speculative decoding is not supported with "
1419
1420
                    "async scheduling."
                )
1421

1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
        # Forward the deprecated CLI args to the EPLB config.
        if self.num_redundant_experts is not None:
            self.eplb_config.num_redundant_experts = self.num_redundant_experts
        if self.eplb_window_size is not None:
            self.eplb_config.window_size = self.eplb_window_size
        if self.eplb_step_interval is not None:
            self.eplb_config.step_interval = self.eplb_step_interval
        if self.eplb_log_balancedness is not None:
            self.eplb_config.log_balancedness = self.eplb_log_balancedness

1432
        parallel_config = ParallelConfig(
1433
1434
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1435
            data_parallel_size=self.data_parallel_size,
1436
1437
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1438
1439
1440
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1441
            data_parallel_backend=self.data_parallel_backend,
1442
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1443
            enable_expert_parallel=self.enable_expert_parallel,
1444
1445
            enable_dbo=self.enable_dbo,
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
1446
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
1447
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
1448
            enable_eplb=self.enable_eplb,
1449
            eplb_config=self.eplb_config,
1450
            expert_placement_strategy=self.expert_placement_strategy,
1451
1452
1453
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1454
            ray_runtime_env=ray_runtime_env,
1455
            placement_group=placement_group,
1456
1457
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1458
            worker_extension_cls=self.worker_extension_cls,
1459
            decode_context_parallel_size=self.decode_context_parallel_size,
1460
1461
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
1462
        )
1463

1464
        speculative_config = self.create_speculative_config(
1465
1466
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1467
            enable_chunked_prefill=self.enable_chunked_prefill,
1468
            disable_log_stats=self.disable_log_stats,
1469
1470
        )

1471
1472
1473
1474
1475
        # make sure num_lookahead_slots is set appropriately depending on
        # whether speculative decoding is enabled
        num_lookahead_slots = self.num_lookahead_slots
        if speculative_config is not None:
            num_lookahead_slots = speculative_config.num_lookahead_slots
1476

1477
        scheduler_config = SchedulerConfig(
1478
            runner_type=model_config.runner_type,
1479
1480
1481
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1482
            cuda_graph_sizes=self.cuda_graph_sizes,
1483
            num_lookahead_slots=num_lookahead_slots,
1484
            enable_chunked_prefill=self.enable_chunked_prefill,
1485
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1486
            is_multimodal_model=model_config.is_multimodal_model,
1487
            is_encoder_decoder=model_config.is_encoder_decoder,
1488
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray),
1489
            policy=self.scheduling_policy,
1490
            scheduler_cls=self.scheduler_cls,
1491
1492
1493
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1494
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
1495
            async_scheduling=self.async_scheduling,
1496
        )
1497

1498
1499
1500
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                bias_enabled=self.enable_lora_bias,
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )
1520

1521
1522
1523
1524
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1525
        load_config = self.create_load_config()
1526

1527
1528
        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
1529
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser
1530
1531
1532
1533

        # Forward the deprecated CLI args to the StructuredOutputsConfig
        so_config = self.structured_outputs_config
        if self.guided_decoding_backend is not None:
1534
            so_config.guided_decoding_backend = self.guided_decoding_backend
1535
        if self.guided_decoding_disable_fallback is not None:
1536
1537
1538
            so_config.guided_decoding_disable_fallback = (
                self.guided_decoding_disable_fallback
            )
1539
        if self.guided_decoding_disable_any_whitespace is not None:
1540
1541
1542
            so_config.guided_decoding_disable_any_whitespace = (
                self.guided_decoding_disable_any_whitespace
            )
1543
        if self.guided_decoding_disable_additional_properties is not None:
1544
1545
1546
            so_config.guided_decoding_disable_additional_properties = (
                self.guided_decoding_disable_additional_properties
            )
1547

1548
        observability_config = ObservabilityConfig(
1549
            show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version),
1550
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1551
            collect_detailed_traces=self.collect_detailed_traces,
1552
        )
1553

1554
        config = VllmConfig(
1555
1556
1557
1558
1559
1560
1561
1562
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
1563
            structured_outputs_config=self.structured_outputs_config,
1564
            observability_config=observability_config,
1565
            compilation_config=self.compilation_config,
1566
            kv_transfer_config=self.kv_transfer_config,
1567
            kv_events_config=self.kv_events_config,
1568
            additional_config=self.additional_config,
1569
        )
1570

1571
1572
        return config

1573
1574
1575
1576
1577
1578
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Decide whether these engine arguments can run on the V1 engine.

        Every rejected setting is reported through `_raise_or_fallback`
        before this method returns `False`.

        Args:
            model_config: Model configuration; its `is_v1_compatible`,
                `architectures` and `get_sliding_window()` are consulted.

        Returns:
            `True` when none of the checks below rejects the configuration,
            otherwise `False`.

        Raises:
            NotImplementedError: If the speculative decoding method is
                `"draft_model"`.
        """

        # ------------------------------------------------------------
        # Feature flags that V1 does not support.

        pattern_overridden = (
            self.logits_processor_pattern != EngineArgs.logits_processor_pattern
        )
        if pattern_overridden:
            _raise_or_fallback(
                feature_name="--logits-processor-pattern",
                recommend_to_remove=False,
            )
            return False

        # Mamba and encoder-decoder models are not handled yet.
        if not model_config.is_v1_compatible:
            _raise_or_fallback(
                feature_name=model_config.architectures,
                recommend_to_remove=False,
            )
            return False

        # Concurrent partial prefills are not handled yet: any deviation
        # from the scheduler defaults counts as opting into them.
        partial_prefills_customized = (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        )
        if partial_prefills_customized:
            _raise_or_fallback(
                feature_name="Concurrent Partial Prefill",
                recommend_to_remove=False,
            )
            return False

        # N-gram, Medusa and Eagle speculative decoding work on V1;
        # draft-model speculation does not.
        spec_config = self.speculative_config
        if spec_config is not None:
            # The config may not have been converted from its dict form yet.
            method = (
                spec_config.get("method", None)
                if isinstance(spec_config, dict)
                else spec_config.method
            )
            if method == "draft_model":
                raise NotImplementedError(
                    "Draft model speculative decoding is not supported yet. "
                    "Please consider using other speculative decoding methods "
                    "such as ngram, medusa, eagle, or mtp."
                )

        # Attention backends accepted when VLLM_ATTENTION_BACKEND is set.
        supported_backends = (
            "FLASH_ATTN",
            "PALLAS",
            "TRITON_ATTN",
            "TRITON_MLA",
            "CUTLASS_MLA",
            "FLASHMLA",
            "FLASH_ATTN_MLA",
            "FLASHINFER",
            "FLASHINFER_MLA",
            "ROCM_AITER_MLA",
            "TORCH_SDPA",
            "FLEX_ATTENTION",
            "TREE_ATTN",
            "XFORMERS",
            "ROCM_ATTN",
            "ROCM_AITER_UNIFIED_ATTN",
        )
        if envs.is_set("VLLM_ATTENTION_BACKEND"):
            requested_backend = envs.VLLM_ATTENTION_BACKEND
            if requested_backend not in supported_backends:
                _raise_or_fallback(
                    feature_name=f"VLLM_ATTENTION_BACKEND={requested_backend}",
                    recommend_to_remove=True,
                )
                return False

        # ------------------------------------------------------------
        # Experimental features - users may opt in.

        if self.pipeline_parallel_size > 1:
            executor_backend = self.distributed_executor_backend
            # A custom executor can advertise PP support via `supports_pp`.
            pp_capable = getattr(executor_backend, "supports_pp", False)
            if not pp_capable and executor_backend not in (
                ParallelConfig.distributed_executor_backend,
                "ray",
                "mp",
                "external_launcher",
            ):
                _raise_or_fallback(
                    feature_name=(
                        "Pipeline Parallelism without Ray distributed "
                        "executor or multiprocessing executor or external "
                        "launcher"
                    ),
                    recommend_to_remove=False,
                )
                return False

        # The sliding window is only queried when running on the CPU platform.
        cpu_with_sliding_window = (
            current_platform.is_cpu()
            and model_config.get_sliding_window() is not None
        )
        if cpu_with_sliding_window:
            _raise_or_fallback(
                feature_name="sliding window (CPU backend)",
                recommend_to_remove=False,
            )
            return False

        # ------------------------------------------------------------

        return True

1675
1676
1677
    def _set_default_args(
        self, usage_context: UsageContext, model_config: ModelConfig
    ) -> None:
        """Fill in V1 engine defaults for arguments the user left unset.

        Mutates `self` in place: chunked prefill, prefix caching, the
        scheduler class, `max_num_batched_tokens` and `max_num_seqs`.

        Args:
            usage_context: Used as the key into the per-context default
                tables; a value missing from the tables leaves the batching
                limits untouched.
            model_config: Supplies the runner type, hybrid flag, pooler
                configuration, HF config and `max_model_len`.
        """

        if model_config.runner_type == "pooling":
            # Pooling models only default to chunked prefill / prefix caching
            # when the pooler is "last" and the model is causal.
            pooling_type = model_config.pooler_config.pooling_type
            is_causal = getattr(model_config.hf_config, "is_causal", True)
            incremental_prefill_supported = (
                pooling_type is not None
                and pooling_type.lower() == "last"
                and is_causal
            )

            if incremental_prefill_supported:
                action = "Enabling"
            else:
                action = "Disabling"

            # Explicit user choices (anything but None) are left untouched.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = incremental_prefill_supported
                logger.info("(%s) chunked prefill by default", action)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = incremental_prefill_supported
                logger.info("(%s) prefix caching by default", action)
        else:
            # Non-pooling runners always use chunked prefill on V1.
            self.enable_chunked_prefill = True

            # TODO: When prefix caching supports prompt embeds inputs, this
            # check can be removed.
            if self.enable_prompt_embeds and self.enable_prefix_caching is not False:
                logger.warning(
                    "--enable-prompt-embeds and --enable-prefix-caching "
                    "are not supported together in V1. Prefix caching has "
                    "been disabled."
                )
                self.enable_prefix_caching = False

            if self.enable_prefix_caching is None:
                # Prefix caching is still experimental for hybrid models,
                # so it defaults to off for them and on for everything else.
                self.enable_prefix_caching = not model_config.is_hybrid

        # Only swap in the V1 scheduler when the argument still holds the
        # original V0 default.
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # Querying the device can fail when the importing platform differs
        # from the one vLLM runs on (e.g. scaling with Ray from a host with no
        # GPUs). In that case fall back to the non-H100/H200 defaults.
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # A zero size short-circuits the check below, so `device_name`
            # is never read on this path.
            device_memory = 0

        from vllm.usage.usage_lib import UsageContext

        llm_ctx = UsageContext.LLM_CLASS
        server_ctx = UsageContext.OPENAI_API_SERVER

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details, hence the extra device
        # name check.
        large_non_a100 = device_memory >= 70 * GiB_bytes and "a100" not in device_name
        if large_non_a100:
            # For GPUs like H100 and MI300x, use larger default values.
            batched_token_defaults = {llm_ctx: 16384, server_ctx: 8192}
            seq_defaults = {llm_ctx: 1024, server_ctx: 1024}
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            batched_token_defaults = {llm_ctx: 8192, server_ctx: 2048}
            seq_defaults = {llm_ctx: 256, server_ctx: 256}

        # TPU-specific defaults, keyed by usage context and then chip name.
        on_tpu = current_platform.is_tpu()
        if on_tpu:
            tpu_batched_token_defaults = {
                llm_ctx: {"V6E": 2048, "V5E": 1024, "V5P": 512},
                server_ctx: {"V6E": 1024, "V5E": 512, "V5P": 256},
            }

        # CPU-specific defaults scale with the PP x TP world size.
        if current_platform.is_cpu():
            world_size = self.pipeline_parallel_size * self.tensor_parallel_size
            batched_token_defaults = {
                llm_ctx: 4096 * world_size,
                server_ctx: 2048 * world_size,
            }
            seq_defaults = {
                llm_ctx: 256 * world_size,
                server_ctx: 128 * world_size,
            }

        use_context_value = usage_context.value if usage_context else None

        needs_token_default = (
            self.max_num_batched_tokens is None
            and usage_context in batched_token_defaults
        )
        if needs_token_default:
            generic_default = batched_token_defaults[usage_context]
            if on_tpu:
                # Prefer the per-chip value; unknown chips use the generic one.
                chip_name = current_platform.get_device_name()
                self.max_num_batched_tokens = tpu_batched_token_defaults[
                    usage_context
                ].get(chip_name, generic_default)
            else:
                # Without chunked prefill, the limit is the model length.
                self.max_num_batched_tokens = (
                    generic_default
                    if self.enable_chunked_prefill
                    else model_config.max_model_len
                )
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                use_context_value,
            )

        if self.max_num_seqs is None and usage_context in seq_defaults:
            # Never allow more sequences than batched tokens.
            token_cap = self.max_num_batched_tokens or sys.maxsize
            self.max_num_seqs = min(seq_defaults[usage_context], token_cap)

            logger.debug(
                "Setting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                use_context_value,
            )
1834

1835

1836
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments extended with the asynchronous engine's options."""

    # Whether requests are logged; off by default.
    enable_log_requests: bool = False

    @property
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
        "`enable_log_requests` instead."
    )
    def disable_log_requests(self) -> bool:
        """Deprecated: the negation of `enable_log_requests`."""
        logging_enabled = self.enable_log_requests
        return not logging_enabled

    @disable_log_requests.setter
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
        "`enable_log_requests` instead."
    )
    def disable_log_requests(self, value: bool):
        """Deprecated: store the negation of `value` in `enable_log_requests`."""
        self.enable_log_requests = not value

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
        """Register the async engine's CLI arguments on `parser`.

        Args:
            parser: Parser to extend.
            async_args_only: When true, the base `EngineArgs` arguments are
                not added and only the async-specific flags are registered.

        Returns:
            The parser with the arguments added.
        """
        # Plugins are loaded first because they may extend the parser, e.g.
        # add a new quantization method to --quantization or a new device
        # to --device.
        load_general_plugins()

        include_engine_args = not async_args_only
        if include_engine_args:
            parser = EngineArgs.add_cli_args(parser)

        toggle_action = argparse.BooleanOptionalAction
        log_requests_default = AsyncEngineArgs.enable_log_requests

        parser.add_argument(
            "--enable-log-requests",
            action=toggle_action,
            default=log_requests_default,
            help="Enable logging requests.",
        )
        # Deprecated inverse flag; its default mirrors the new flag.
        parser.add_argument(
            "--disable-log-requests",
            action=toggle_action,
            default=not log_requests_default,
            help="[DEPRECATED] Disable logging requests.",
            deprecated=True,
        )

        current_platform.pre_register_and_update(parser)
        return parser
1885
1886


1887
1888
1889
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Report a feature that the V1 engine cannot handle.

    Args:
        feature_name: Name of the unsupported feature, interpolated into
            the error or warning text.
        recommend_to_remove: When true, the warning also advises removing
            the feature from the configuration.

    Raises:
        NotImplementedError: If `VLLM_USE_V1` is explicitly set and truthy.
    """
    v1_forced = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_forced:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}."
        )

    # V1 was not forced, so only warn about the fallback to V0.
    parts = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        parts.append(f"We recommend to remove {feature_name} from your config ")
        parts.append("in favor of the V1 Engine.")
    logger.warning("".join(parts))


1900
1901
1902
def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.

    Lowercase suffixes (k, m, g, t) are powers of 1000 and accept a decimal
    number; uppercase suffixes (K, M, G, T) are powers of 1024 and accept
    whole numbers only. A value without a suffix is parsed as a plain int.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '2t' -> 2,000,000,000,000
    - '1T' -> 1,099,511,627,776

    Args:
        value: The string to parse; surrounding whitespace is ignored.

    Returns:
        The parsed integer. Fractional results of a decimal multiplier are
        truncated toward zero.

    Raises:
        argparse.ArgumentTypeError: If a decimal number is combined with a
            binary (uppercase) suffix.
        ValueError: If the value is neither a suffixed number nor a plain
            integer.
    """
    # Local import keeps this function self-contained; Decimal gives exact
    # scaling (float would turn '1.005k' into 1004).
    from decimal import Decimal

    value = value.strip()
    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
    if match:
        # Every suffix accepted by the pattern above must have an entry in
        # exactly one of these tables.
        decimal_multiplier = {
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
            "t": 10**12,
        }
        binary_multiplier = {
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
            "T": 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            return int(Decimal(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?"
                ) from e

    # Regular plain number.
    return int(value)