arg_utils.py 85.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# yapf: disable
5
import os
6
import argparse
7
import copy
8
import dataclasses
9
import functools
10
import json
11
import sys
12
from dataclasses import MISSING, dataclass, fields, is_dataclass
13
from itertools import permutations
14
15
16
from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List,
                    Literal, Optional, Type, TypeVar, Union, cast, get_args,
                    get_origin)
17

18
import huggingface_hub
19
import regex as re
20
import torch
21
from pydantic import TypeAdapter, ValidationError
22
from typing_extensions import TypeIs, deprecated
23

24
import vllm.envs as envs
25
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
26
27
28
                         ConfigType, ConvertOption, DecodingConfig,
                         DetailedTraceModules, Device, DeviceConfig,
                         DistributedExecutorBackend, EPLBConfig,
29
30
                         GuidedDecodingBackend, HfOverrides, KVEventsConfig,
                         KVTransferConfig, LoadConfig, LogprobsMode,
31
32
33
34
35
36
                         LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
                         ModelDType, ModelImpl, MultiModalConfig,
                         ObservabilityConfig, ParallelConfig, PoolerConfig,
                         PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
                         SchedulerPolicy, SpeculativeConfig, TaskOption,
                         TokenizerMode, VllmConfig, get_attr_docs, get_field)
37
from vllm.logger import init_logger
38
from vllm.platforms import CpuArchEnum, current_platform
39
from vllm.plugins import load_general_plugins
40
from vllm.ray.lazy_utils import is_ray_initialized
41
from vllm.reasoning import ReasoningParserManager
42
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
43
from vllm.transformers_utils.config import get_model_path, is_interleaved
44
from vllm.transformers_utils.utils import check_gguf_file
45
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
Rui Qiao's avatar
Rui Qiao committed
46
                        GiB_bytes, get_ip, is_in_ray_actor)
47
from vllm.v1.sample.logits_processor import LogitsProcessor
48
49

# yapf: enable
50

51
52
53
if TYPE_CHECKING:
    from vllm.executor.executor_base import ExecutorBase
    from vllm.model_executor.layers.quantization import QuantizationMethods
54
    from vllm.model_executor.model_loader import LoadFormats
55
56
57
58
    from vllm.usage.usage_lib import UsageContext
else:
    ExecutorBase = Any
    QuantizationMethods = Any
59
    LoadFormats = Any
60
61
    UsageContext = Any

62
63
logger = init_logger(__name__)

64
65
66
67
# Generic type variable shared by the parsing helpers in this module.
T = TypeVar("T")
# A type hint is either a concrete class or a special typing form;
# `object` is used to allow for special typing forms.
TypeHint = Union[type[Any], object]
# Same shape as `TypeHint`, but parameterised on `T`.
TypeHintT = Union[type[T], object]
68

69

70
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap a converter so it can be used as an argparse `type` callable.

    Args:
        return_type: Callable that converts the raw command-line string.

    Returns:
        A callable that forwards to `return_type` and re-raises any
        `ValueError` it produces as `argparse.ArgumentTypeError`.
    """

    def _parse_type(val: str) -> T:
        try:
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}.") from e

    return _parse_type


def optional_type(
        return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Wrap a converter so that empty or `"None"` input maps to `None`.

    Args:
        return_type: Callable that converts the raw command-line string.

    Returns:
        A callable that returns `None` for `""` and `"None"`, and otherwise
        delegates to `parse_type(return_type)`.
    """

    def _optional_type(val: str) -> Optional[T]:
        if val == "" or val == "None":
            return None
        return parse_type(return_type)(val)

    return _optional_type
91
92


93
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
    """Parse `val` as JSON when it looks like an object, else return it as is.

    A value counts as an object when, ignoring surrounding whitespace, it
    starts with `{` and ends with `}`. Such values are decoded through
    `optional_type(json.loads)`; anything else is returned as a string.
    """
    if not re.match(r"(?s)^\s*{.*}\s*$", val):
        return str(val)
    return optional_type(json.loads)(val)
97
98


99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Tell whether `type_hint` is `type` itself or a parameterised `type`."""
    if type_hint is type:
        return True
    # Parameterised hints such as `list[int]` expose the bare type as origin.
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return True when at least one hint in `type_hints` matches `type`."""
    for candidate in type_hints:
        if is_type(candidate, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return the first hint in `type_hints` matching `type`, else None."""
    for candidate in type_hints:
        if is_type(candidate, type):
            return candidate
    return None


114
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Get the `type` and `choices` from a `Literal` type hint in `type_hints`.

    If `type_hints` also contains `str`, we use `metavar` instead of `choices`.

    Raises:
        ValueError: If `type_hints` holds no `Literal` with at least one
            option, or if the options are not all of the same type.
    """
    type_hint = get_type(type_hints, Literal)
    options = get_args(type_hint)
    # `get_type` yields None when no Literal is present, which makes
    # `get_args` return an empty tuple; fail with a clear message instead
    # of an IndexError on `options[0]`.
    if not options:
        raise ValueError(
            f"Expected a non-empty Literal type hint in {type_hints}.")
    option_type = type(options[0])
    if not all(isinstance(option, option_type) for option in options):
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}")
    kwarg = "metavar" if contains_type(type_hints, str) else "choices"
    return {"type": option_type, kwarg: sorted(options)}
128
129


130
131
132
133
134
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Report whether `type_hint` is defined outside the `builtins` module."""
    defining_module = type_hint.__module__
    return defining_module != "builtins"


135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Flatten `Annotated` and `Union` wrappers into a set of leaf hints."""
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    # Annotated[X, ...]: only the wrapped type X is relevant.
    if origin is Annotated:
        return set(get_type_hints(args[0]))

    # Union[A, B, ...]: merge the leaves of every member.
    if origin is Union:
        collected: set[TypeHint] = set()
        for member in args:
            collected |= get_type_hints(member)
        return collected

    # Anything else is already a leaf.
    return {type_hint}


152
153
154
155
def is_online_quantization(quantization: Any) -> bool:
    """Return True when `quantization` is one of the listed online methods."""
    online_methods = ("inc", )
    return quantization in online_methods


156
157
158
159
160
161
162
# True when `--help` appears on the command line or the program name
# (`sys.argv[0]`) ends with `mkdocs` / `mkdocs/__main__.py`.
# The walrus assignments also bind `argv` at module level, and `argv0` too
# whenever `--help` is absent (the `or` short-circuits otherwise).
NEEDS_HELP = (
    "--help" in (argv := sys.argv)  # vllm SUBCOMMAND --help
    or (argv0 := argv[0]).endswith("mkdocs")  # mkdocs SUBCOMMAND
    or argv0.endswith("mkdocs/__main__.py")  # python -m mkdocs SUBCOMMAND
)


163
164
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Build the argparse keyword arguments for every field of `cls`.

    Args:
        cls: A config dataclass whose fields are turned into CLI options.

    Returns:
        A mapping from field name to the kwargs for `add_argument`
        (`default`, `help` and, depending on the field's type hints,
        `type`, `action`, `choices`, `metavar` or `nargs`). The result is
        memoised by `lru_cache`, so the same dict object is handed out on
        every call for a given `cls`.

    Raises:
        ValueError: If a field's type hints match none of the supported
            categories.
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}
    json_tip = ("Should either be a valid JSON string or JSON keys passed "
                "individually.")
    kwargs = {}
    for field in fields(cls):
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, we can use the model_validate_json
        generator = (th for th in type_hints if is_dataclass(th))
        dataclass_cls = next(generator, None)

        # Get the default value of the field
        if field.default is not MISSING:
            default = field.default
        elif field.default_factory is not MISSING:
            default = field.default_factory()
        else:
            # The field declares no default. Reset explicitly so the value
            # from a previous iteration is never reused (and so the first
            # field cannot hit an unbound local).
            default = None

        # Get the help text for the field
        name = field.name
        help = cls_docs.get(name, "").strip()
        # Escape % for argparse
        help = help.replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        kwargs[name] = {"default": default, "help": help}

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:

            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            kwargs[name]["type"] = parse_dataclass
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            kwargs[name]["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            kwargs[name].update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            types = get_args(type_hint)
            tuple_type = types[0]
            assert all(t is tuple_type for t in types if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {types}.")
            kwargs[name]["type"] = tuple_type
            kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            types = get_args(type_hint)
            list_type = types[0]
            if get_origin(list_type) is Union:
                msg = "List type must contain str if it is a Union."
                assert str in get_args(list_type), msg
                list_type = str
            kwargs[name]["type"] = list_type
            kwargs[name]["nargs"] = "+"
        elif contains_type(type_hints, int):
            kwargs[name]["type"] = int
            # Special case for large integers
            human_readable_ints = {
                "max_model_len",
                "max_num_batched_tokens",
                "kv_cache_memory_bytes",
            }
            if name in human_readable_ints:
                kwargs[name]["type"] = human_readable_int
                kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
        elif (contains_type(type_hints, dict)
              and (contains_type(type_hints, str)
                   or any(is_not_builtin(th) for th in type_hints))):
            kwargs[name]["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            kwargs[name]["type"] = parse_type(json.loads)
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif (contains_type(type_hints, str)
              or any(is_not_builtin(th) for th in type_hints)):
            kwargs[name]["type"] = str
        else:
            raise ValueError(
                f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(kwargs[name].get("type")) is Literal:
            kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            kwargs[name]["type"] = optional_type(kwargs[name]["type"])
            if kwargs[name].get("choices"):
                kwargs[name]["choices"].append("None")
    return kwargs
267
268


269
270
271
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Return argparse kwargs for the given Config dataclass.

    Attribute documentation only appears in the help output when `--help`
    or `mkdocs` is part of the command line.

    The expensive work is memoised with functools.lru_cache; a deep copy of
    the cached mapping is handed back so that callers may modify it freely
    without touching the cached version.
    """
    cached_kwargs = _compute_kwargs(cls)
    return copy.deepcopy(cached_kwargs)


282
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
283
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
284
    """Arguments for vLLM engine."""
285
286
287
288
289
    model: str = ModelConfig.model
    served_model_name: Optional[Union[
        str, List[str]]] = ModelConfig.served_model_name
    tokenizer: Optional[str] = ModelConfig.tokenizer
    hf_config_path: Optional[str] = ModelConfig.hf_config_path
290
291
292
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
    task: Optional[TaskOption] = ModelConfig.task
293
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
294
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
295
296
297
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
298
    download_dir: Optional[str] = LoadConfig.download_dir
299
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
300
    load_format: Union[str, LoadFormats] = LoadConfig.load_format
301
302
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
303
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
304
305
    seed: Optional[int] = ModelConfig.seed
    max_model_len: Optional[int] = ModelConfig.max_model_len
306
307
    cuda_graph_sizes: list[int] = get_field(SchedulerConfig,
                                            "cuda_graph_sizes")
308
309
310
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
311
    distributed_executor_backend: Optional[Union[
312
        str, DistributedExecutorBackend,
313
        Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
314
    # number of P/D disaggregation (or other disaggregation) workers
315
316
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
317
318
    decode_context_parallel_size: int = \
        ParallelConfig.decode_context_parallel_size
319
    data_parallel_size: int = ParallelConfig.data_parallel_size
320
    data_parallel_rank: Optional[int] = None
321
    data_parallel_start_rank: Optional[int] = None
322
323
324
    data_parallel_size_local: Optional[int] = None
    data_parallel_address: Optional[str] = None
    data_parallel_rpc_port: Optional[int] = None
325
    data_parallel_hybrid_lb: bool = False
Rui Qiao's avatar
Rui Qiao committed
326
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
327
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
328
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
329
    enable_eplb: bool = ParallelConfig.enable_eplb
330
331
332
333
    num_redundant_experts: int = EPLBConfig.num_redundant_experts
    eplb_window_size: int = EPLBConfig.window_size
    eplb_step_interval: int = EPLBConfig.step_interval
    eplb_log_balancedness: bool = EPLBConfig.log_balancedness
334
335
    max_parallel_loading_workers: Optional[
        int] = ParallelConfig.max_parallel_loading_workers
336
337
338
339
    block_size: Optional[BlockSize] = CacheConfig.block_size
    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
    prefix_caching_hash_algo: PrefixCachingHashAlgo = \
        CacheConfig.prefix_caching_hash_algo
340
341
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
342
343
344
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
345
    kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes
346
347
348
349
350
351
352
    max_num_batched_tokens: Optional[
        int] = SchedulerConfig.max_num_batched_tokens
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
    long_prefill_token_threshold: int = \
        SchedulerConfig.long_prefill_token_threshold
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
353
    max_logprobs: int = ModelConfig.max_logprobs
354
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
355
    disable_log_stats: bool = False
356
357
358
359
360
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
    rope_theta: Optional[float] = ModelConfig.rope_theta
    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
361
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
362
363
364
365
    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
    enforce_eager: bool = ModelConfig.enforce_eager
    max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
366
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
367
    limit_mm_per_prompt: dict[str, int] = \
368
        get_field(MultiModalConfig, "limit_per_prompt")
369
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
370
371
372
    media_io_kwargs: dict[str, dict[str,
                                    Any]] = get_field(MultiModalConfig,
                                                      "media_io_kwargs")
373
374
    mm_processor_kwargs: Optional[Dict[str, Any]] = \
        MultiModalConfig.mm_processor_kwargs
375
    disable_mm_preprocessor_cache: bool = False  # DEPRECATED
376
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
377
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
378
    io_processor_plugin: Optional[str] = None
379
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
380
    # LoRA fields
381
    enable_lora: bool = False
382
383
384
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
385
386
    default_mm_loras: Optional[Dict[str, str]] = \
        LoRAConfig.default_mm_loras
387
388
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
zhuwenwen's avatar
zhuwenwen committed
389
    lora_target_modules: Optional[List[str]] = LoRAConfig.lora_target_modules
390
391
392
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size

393
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
394
395
    num_gpu_blocks_override: Optional[
        int] = CacheConfig.num_gpu_blocks_override
396
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
397
398
    model_loader_extra_config: dict = \
        get_field(LoadConfig, "model_loader_extra_config")
399
400
    ignore_patterns: Optional[Union[str,
                                    List[str]]] = LoadConfig.ignore_patterns
401
    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
402

403
404
405
406
    scheduler_delay_factor: float = SchedulerConfig.delay_factor
    enable_chunked_prefill: Optional[
        bool] = SchedulerConfig.enable_chunked_prefill
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
407

408
409
410
    disable_hybrid_kv_cache_manager: bool = (
        SchedulerConfig.disable_hybrid_kv_cache_manager)

411
412
413
414
415
416
    guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
    guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
    guided_decoding_disable_any_whitespace: bool = \
        DecodingConfig.disable_any_whitespace
    guided_decoding_disable_additional_properties: bool = \
        DecodingConfig.disable_additional_properties
417
418
    logits_processor_pattern: Optional[
        str] = ModelConfig.logits_processor_pattern
419

420
    speculative_config: Optional[Dict[str, Any]] = None
zhuwenwen's avatar
zhuwenwen committed
421
    num_speculative_heads: Optional[int] = None
422

423
424
425
426
427
428
    show_hidden_metrics_for_version: Optional[str] = \
        ObservabilityConfig.show_hidden_metrics_for_version
    otlp_traces_endpoint: Optional[str] = \
        ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: Optional[list[DetailedTraceModules]] = \
        ObservabilityConfig.collect_detailed_traces
429
    disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
430
431
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
432

433
434
    override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
        ModelConfig.override_pooler_config
435
436
    compilation_config: CompilationConfig = \
        get_field(VllmConfig, "compilation_config")
437
438
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
439

440
    kv_transfer_config: Optional[KVTransferConfig] = None
441
    kv_events_config: Optional[KVEventsConfig] = None
442

443
444
445
446
447
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
    override_generation_config: dict[str, Any] = \
        get_field(ModelConfig, "override_generation_config")
    model_impl: str = ModelConfig.model_impl
448
    override_attention_dtype: str = ModelConfig.override_attention_dtype
449

450
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
451
452
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
453

454
455
    additional_config: dict[str, Any] = \
        get_field(VllmConfig, "additional_config")
456
457
    reasoning_parser: str = DecodingConfig.reasoning_backend

458
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
459
    pt_load_map_location: str = LoadConfig.pt_load_map_location
王敏's avatar
王敏 committed
460

461
462
    # DEPRECATED
    enable_multimodal_encoder_data_parallel: bool = False
王敏's avatar
王敏 committed
463

464
465
466
467
    logits_processors: Optional[list[Union[
        str, type[LogitsProcessor]]]] = ModelConfig.logits_processors
    """Custom logitproc types"""

468
469
    async_scheduling: bool = SchedulerConfig.async_scheduling

470
471
472
    kv_sharing_fast_prefill: bool = \
        CacheConfig.kv_sharing_fast_prefill

473
    def __post_init__(self):
        """Finish initialisation after the dataclass fields are set.

        Converts dict-valued `compilation_config` and `eplb_config` into
        their config objects, loads the general plugins and, when
        `huggingface_hub.constants.HF_HUB_OFFLINE` is set, replaces
        `self.model` with the result of `get_model_path`.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(
                **self.compilation_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)
        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
        # When HF Hub is offline, replace the model id with the local
        # model path.
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
            logger.info(
                "HF_HUB_OFFLINE is True, replace model_id [%s] "
                "to model_path [%s]", model_id, self.model)
492
493

    @staticmethod
494
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
495
        """Shared CLI arguments for vLLM engine."""
496

497
        # Model arguments
498
499
500
501
502
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
Reid's avatar
Reid committed
503
        if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]):
504
            model_group.add_argument("--model", **model_kwargs["model"])
505
506
507
508
509
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
        model_group.add_argument("--task",
                                 **model_kwargs["task"],
                                 deprecated=True)
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode",
                                 **model_kwargs["tokenizer_mode"])
        model_group.add_argument("--trust-remote-code",
                                 **model_kwargs["trust_remote_code"])
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
        model_group.add_argument("--hf-config-path",
                                 **model_kwargs["hf_config_path"])
        model_group.add_argument("--allowed-local-media-path",
                                 **model_kwargs["allowed_local_media_path"])
        model_group.add_argument("--revision", **model_kwargs["revision"])
        model_group.add_argument("--code-revision",
                                 **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling",
                                 **model_kwargs["rope_scaling"])
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
        model_group.add_argument("--tokenizer-revision",
                                 **model_kwargs["tokenizer_revision"])
        model_group.add_argument("--max-model-len",
                                 **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q",
                                 **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager",
                                 **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-seq-len-to-capture",
                                 **model_kwargs["max_seq_len_to_capture"])
        model_group.add_argument("--max-logprobs",
                                 **model_kwargs["max_logprobs"])
539
        model_group.add_argument("--logprobs-mode",
540
                                 choices=[f.value for f in LogprobsMode],
541
                                 **model_kwargs["logprobs_mode"])
542
543
544
545
546
547
        model_group.add_argument("--disable-sliding-window",
                                 **model_kwargs["disable_sliding_window"])
        model_group.add_argument("--disable-cascade-attn",
                                 **model_kwargs["disable_cascade_attn"])
        model_group.add_argument("--skip-tokenizer-init",
                                 **model_kwargs["skip_tokenizer_init"])
548
549
        model_group.add_argument("--enable-prompt-embeds",
                                 **model_kwargs["enable_prompt_embeds"])
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
        model_group.add_argument("--served-model-name",
                                 **model_kwargs["served_model_name"])
        # This one is a special case because it is the
        # opposite of ModelConfig.use_async_output_proc
        model_group.add_argument(
            "--disable-async-output-proc",
            action="store_true",
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
        model_group.add_argument("--config-format",
                                 **model_kwargs["config_format"])
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
        model_group.add_argument("--hf-token",
                                 type=str,
                                 nargs="?",
                                 const=True,
                                 default=model_kwargs["hf_token"]["default"],
                                 help=model_kwargs["hf_token"]["help"])
        model_group.add_argument("--hf-overrides",
                                 **model_kwargs["hf_overrides"])
        model_group.add_argument("--override-pooler-config",
                                 **model_kwargs["override_pooler_config"])
        model_group.add_argument("--logits-processor-pattern",
                                 **model_kwargs["logits_processor_pattern"])
        model_group.add_argument("--generation-config",
                                 **model_kwargs["generation_config"])
        model_group.add_argument("--override-generation-config",
                                 **model_kwargs["override_generation_config"])
        model_group.add_argument("--enable-sleep-mode",
                                 **model_kwargs["enable_sleep_mode"])
        model_group.add_argument("--model-impl",
                                 choices=[f.value for f in ModelImpl],
                                 **model_kwargs["model_impl"])
585
586
        model_group.add_argument("--override-attention-dtype",
                                 **model_kwargs["override_attention_dtype"])
587
588
        model_group.add_argument("--logits-processors",
                                 **model_kwargs["logits_processors"])
589
590
        model_group.add_argument("--io-processor-plugin",
                                 **model_kwargs["io_processor_plugin"])
591

592
593
594
595
596
597
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
598
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
599
        load_group.add_argument("--download-dir",
600
                                **load_kwargs["download_dir"])
601
602
        load_group.add_argument("--safetensors-load-strategy",
                                **load_kwargs["safetensors_load_strategy"])
603
        load_group.add_argument("--model-loader-extra-config",
604
                                **load_kwargs["model_loader_extra_config"])
605
606
607
        load_group.add_argument("--ignore-patterns",
                                **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load",
608
                                **load_kwargs["use_tqdm_on_load"])
609
610
        load_group.add_argument('--pt-load-map-location',
                                **load_kwargs["pt_load_map_location"])
611
612
613
614
615
616
617

        # Guided decoding arguments
        guided_decoding_kwargs = get_kwargs(DecodingConfig)
        guided_decoding_group = parser.add_argument_group(
            title="DecodingConfig",
            description=DecodingConfig.__doc__,
        )
618
619
        guided_decoding_group.add_argument("--guided-decoding-backend",
                                           **guided_decoding_kwargs["backend"])
620
        guided_decoding_group.add_argument(
621
622
            "--guided-decoding-disable-fallback",
            **guided_decoding_kwargs["disable_fallback"])
623
        guided_decoding_group.add_argument(
624
625
626
627
628
            "--guided-decoding-disable-any-whitespace",
            **guided_decoding_kwargs["disable_any_whitespace"])
        guided_decoding_group.add_argument(
            "--guided-decoding-disable-additional-properties",
            **guided_decoding_kwargs["disable_additional_properties"])
629
630
        guided_decoding_group.add_argument(
            "--reasoning-parser",
631
            # This choice is a special case because it's not static
632
633
634
            choices=list(ReasoningParserManager.reasoning_parsers),
            **guided_decoding_kwargs["reasoning_backend"])

635
        # Parallel arguments
636
637
638
639
640
641
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
642
            "--distributed-executor-backend",
643
644
            **parallel_kwargs["distributed_executor_backend"])
        parallel_group.add_argument(
645
            "--pipeline-parallel-size", "-pp",
646
            **parallel_kwargs["pipeline_parallel_size"])
647
        parallel_group.add_argument("--tensor-parallel-size", "-tp",
648
                                    **parallel_kwargs["tensor_parallel_size"])
649
650
651
        parallel_group.add_argument(
            "--decode-context-parallel-size", "-dcp",
            **parallel_kwargs["decode_context_parallel_size"])
652
        parallel_group.add_argument("--data-parallel-size", "-dp",
653
                                    **parallel_kwargs["data_parallel_size"])
654
655
656
657
658
659
        parallel_group.add_argument(
            '--data-parallel-rank',
            '-dpn',
            type=int,
            help='Data parallel rank of this instance. '
            'When set, enables external load balancer mode.')
660
661
662
663
664
        parallel_group.add_argument('--data-parallel-start-rank',
                                    '-dpr',
                                    type=int,
                                    help='Starting data parallel rank '
                                    'for secondary nodes.')
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
        parallel_group.add_argument('--data-parallel-size-local',
                                    '-dpl',
                                    type=int,
                                    help='Number of data parallel replicas '
                                    'to run on this node.')
        parallel_group.add_argument('--data-parallel-address',
                                    '-dpa',
                                    type=str,
                                    help='Address of data parallel cluster '
                                    'head-node.')
        parallel_group.add_argument('--data-parallel-rpc-port',
                                    '-dpp',
                                    type=int,
                                    help='Port for data parallel RPC '
                                    'communication.')
Rui Qiao's avatar
Rui Qiao committed
680
681
682
683
684
685
        parallel_group.add_argument('--data-parallel-backend',
                                    '-dpb',
                                    type=str,
                                    default='mp',
                                    help='Backend for data parallel, either '
                                    '"mp" or "ray".')
686
687
688
        parallel_group.add_argument(
            "--data-parallel-hybrid-lb",
            **parallel_kwargs["data_parallel_hybrid_lb"])
689
        parallel_group.add_argument(
690
            "--enable-expert-parallel",
691
            **parallel_kwargs["enable_expert_parallel"])
692
693
        parallel_group.add_argument("--enable-eplb",
                                    **parallel_kwargs["enable_eplb"])
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
        parallel_group.add_argument("--eplb-config",
                                    **parallel_kwargs["eplb_config"])
        parallel_group.add_argument(
            "--num-redundant-experts",
            type=int,
            help=
            "[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
            deprecated=True)
        parallel_group.add_argument(
            "--eplb-window-size",
            type=int,
            help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
            deprecated=True)
        parallel_group.add_argument(
            "--eplb-step-interval",
            type=int,
            help=
            "[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
            deprecated=True)
        parallel_group.add_argument(
            "--eplb-log-balancedness",
            action=argparse.BooleanOptionalAction,
            help=
            "[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
            deprecated=True)

720
        parallel_group.add_argument(
721
            "--max-parallel-loading-workers",
722
723
            **parallel_kwargs["max_parallel_loading_workers"])
        parallel_group.add_argument(
724
            "--ray-workers-use-nsight",
725
726
            **parallel_kwargs["ray_workers_use_nsight"])
        parallel_group.add_argument(
727
            "--disable-custom-all-reduce",
728
            **parallel_kwargs["disable_custom_all_reduce"])
729
730
731
732
        parallel_group.add_argument("--worker-cls",
                                    **parallel_kwargs["worker_cls"])
        parallel_group.add_argument("--worker-extension-cls",
                                    **parallel_kwargs["worker_extension_cls"])
733
734
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
735
736
            action="store_true",
            deprecated=True)
737

738
739
740
741
742
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
743
        )
744
745
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
        cache_group.add_argument("--gpu-memory-utilization",
746
                                 **cache_kwargs["gpu_memory_utilization"])
747
748
        cache_group.add_argument("--kv-cache-memory-bytes",
                                 **cache_kwargs["kv_cache_memory_bytes"])
749
750
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
        cache_group.add_argument("--kv-cache-dtype",
751
                                 **cache_kwargs["cache_dtype"])
752
        cache_group.add_argument("--num-gpu-blocks-override",
753
754
755
756
757
                                 **cache_kwargs["num_gpu_blocks_override"])
        cache_group.add_argument("--enable-prefix-caching",
                                 **cache_kwargs["enable_prefix_caching"])
        cache_group.add_argument("--prefix-caching-hash-algo",
                                 **cache_kwargs["prefix_caching_hash_algo"])
758
        cache_group.add_argument("--cpu-offload-gb",
759
                                 **cache_kwargs["cpu_offload_gb"])
760
        cache_group.add_argument("--calculate-kv-scales",
761
                                 **cache_kwargs["calculate_kv_scales"])
762
763
        cache_group.add_argument("--kv-sharing-fast-prefill",
                                 **cache_kwargs["kv_sharing_fast_prefill"])
764
765
766
767
        cache_group.add_argument("--mamba-cache-dtype",
                                 **cache_kwargs["mamba_cache_dtype"])
        cache_group.add_argument("--mamba-ssm-cache-dtype",
                                 **cache_kwargs["mamba_ssm_cache_dtype"])
768

769
        # Multimodal related configs
770
771
772
773
774
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
775
        multimodal_group.add_argument("--limit-mm-per-prompt",
776
                                      **multimodal_kwargs["limit_per_prompt"])
777
778
        multimodal_group.add_argument("--media-io-kwargs",
                                      **multimodal_kwargs["media_io_kwargs"])
779
        multimodal_group.add_argument(
780
            "--mm-processor-kwargs",
781
782
            **multimodal_kwargs["mm_processor_kwargs"])
        multimodal_group.add_argument(
783
784
785
            "--mm-processor-cache-gb",
            **multimodal_kwargs["mm_processor_cache_gb"])
        multimodal_group.add_argument("--disable-mm-preprocessor-cache",
786
                                      action="store_true",
787
                                      deprecated=True)
788
789
        multimodal_group.add_argument(
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"])
790
791
792
        multimodal_group.add_argument(
            "--interleave-mm-strings",
            **multimodal_kwargs["interleave_mm_strings"])
793
794
        multimodal_group.add_argument("--skip-mm-profiling",
                                      **multimodal_kwargs["skip_mm_profiling"])
795

796
        # LoRA related configs
797
798
799
800
801
802
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
803
            "--enable-lora",
804
            action=argparse.BooleanOptionalAction,
805
806
            help="If True, enable handling of LoRA adapters.")
        lora_group.add_argument("--enable-lora-bias",
807
                                **lora_kwargs["bias_enabled"])
808
809
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
        lora_group.add_argument("--max-lora-rank",
810
                                **lora_kwargs["max_lora_rank"])
zhuwenwen's avatar
zhuwenwen committed
811
812
        lora_group.add_argument('--lora-target-modules',
                            **lora_kwargs["lora_target_modules"])
813
        lora_group.add_argument("--lora-extra-vocab-size",
814
815
                                **lora_kwargs["lora_extra_vocab_size"])
        lora_group.add_argument(
816
            "--lora-dtype",
817
818
            **lora_kwargs["lora_dtype"],
        )
819
        lora_group.add_argument("--max-cpu-loras",
820
                                **lora_kwargs["max_cpu_loras"])
821
        lora_group.add_argument("--fully-sharded-loras",
822
                                **lora_kwargs["fully_sharded_loras"])
823
824
        lora_group.add_argument("--default-mm-loras",
                                **lora_kwargs["default_mm_loras"])
825

826

827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
            **observability_kwargs["show_hidden_metrics_for_version"])
        observability_group.add_argument(
            "--otlp-traces-endpoint",
            **observability_kwargs["otlp_traces_endpoint"])
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
            ",".join(p)
            for p in permutations(get_args(DetailedTraceModules), r=2)
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
            **observability_kwargs["collect_detailed_traces"])
850

851
852
853
854
855
856
857
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
858
            "--max-num-batched-tokens",
859
            **scheduler_kwargs["max_num_batched_tokens"])
860
        scheduler_group.add_argument("--max-num-seqs",
861
862
863
864
865
866
867
                                     **scheduler_kwargs["max_num_seqs"])
        scheduler_group.add_argument(
            "--max-num-partial-prefills",
            **scheduler_kwargs["max_num_partial_prefills"])
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
            **scheduler_kwargs["max_long_partial_prefills"])
868
869
        scheduler_group.add_argument('--cuda-graph-sizes',
                                     **scheduler_kwargs["cuda_graph_sizes"])
870
871
872
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
            **scheduler_kwargs["long_prefill_token_threshold"])
873
        scheduler_group.add_argument("--num-lookahead-slots",
874
                                     **scheduler_kwargs["num_lookahead_slots"])
875
        scheduler_group.add_argument("--scheduler-delay-factor",
876
                                     **scheduler_kwargs["delay_factor"])
877
        scheduler_group.add_argument("--preemption-mode",
878
                                     **scheduler_kwargs["preemption_mode"])
879
880
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
881
        scheduler_group.add_argument("--scheduling-policy",
882
                                     **scheduler_kwargs["policy"])
883
        scheduler_group.add_argument(
884
            "--enable-chunked-prefill",
885
            **scheduler_kwargs["enable_chunked_prefill"])
886
887
888
        scheduler_group.add_argument(
            "--disable-chunked-mm-input",
            **scheduler_kwargs["disable_chunked_mm_input"])
889
890
        scheduler_group.add_argument("--scheduler-cls",
                                     **scheduler_kwargs["scheduler_cls"])
891
892
893
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"])
894
895
        scheduler_group.add_argument("--async-scheduling",
                                     **scheduler_kwargs["async_scheduling"])
896
897

        # vLLM arguments
898
        vllm_kwargs = get_kwargs(VllmConfig)
899
900
901
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
902
        )
903
904
905
906
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
907
908
        vllm_group.add_argument("--speculative-config",
                                **vllm_kwargs["speculative_config"])
909
910
911
912
913
914
915
916
        vllm_group.add_argument("--kv-transfer-config",
                                **vllm_kwargs["kv_transfer_config"])
        vllm_group.add_argument('--kv-events-config',
                                **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument("--compilation-config", "-O",
                                **vllm_kwargs["compilation_config"])
        vllm_group.add_argument("--additional-config",
                                **vllm_kwargs["additional_config"])
zhuwenwen's avatar
zhuwenwen committed
917
918
919
920
921
922
923
        
        parser.add_argument(
            '--num-speculative-heads',
            type=int,
            default=EngineArgs.num_speculative_heads,
            help='The number of speculative heads to sample from '
                 'the draft model in speculative decoding.')
924

925
926
927
928
        # Other arguments
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')
929

930
        return parser
931
932

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of this dataclass from parsed CLI arguments.

        Every dataclass field of ``cls`` is looked up by name on ``args``.
        Fields that are not present on the namespace are simply not passed
        to the constructor, so they keep their dataclass defaults instead of
        raising ``AttributeError``. Attributes on ``args`` that do not
        correspond to a dataclass field are ignored.

        Args:
            args: Namespace holding the parsed argument values.

        Returns:
            A new ``cls`` instance populated from ``args``.
        """
        # Forward only the fields the namespace actually carries; this lets
        # callers pass a namespace built by a parser that registers just a
        # subset of the engine arguments.
        field_values = {
            field.name: getattr(args, field.name)
            for field in dataclasses.fields(cls)
            if hasattr(args, field.name)
        }
        return cls(**field_values)
939

940
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` from the current engine arguments.

        Before constructing the config this method may rewrite several
        attributes on ``self``:

        * ``quantization`` and ``load_format`` are both forced to ``"gguf"``
          when ``check_gguf_file(self.model)`` is true.
        * ``model`` and ``load_format`` are redirected to the S3 bucket /
          ``"runai_streamer"`` in the CI-only S3 path.
        * ``mm_processor_cache_gb`` and ``mm_encoder_tp_mode`` are overridden
          when their deprecated predecessors are in use (a warning is logged
          in each case).

        Returns:
            The ``ModelConfig`` built from the (possibly adjusted) arguments.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
                and self.model in MODELS_ON_S3 and self.load_format == "auto"):
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
            self.load_format = "runai_streamer"

        # Deprecated multimodal cache controls: the CLI flag takes precedence
        # over the environment variable.
        if self.disable_mm_preprocessor_cache:
            logger.warning(
                "`--disable-mm-preprocessor-cache` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb 0` instead.")

            self.mm_processor_cache_gb = 0
        # NOTE(review): 4 is presumably the default of
        # VLLM_MM_INPUT_CACHE_GIB, so any other value means the user set it
        # explicitly -- confirm against vllm.envs.
        elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
            logger.warning(
                "`VLLM_MM_INPUT_CACHE_GIB` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb %d` instead.",
                envs.VLLM_MM_INPUT_CACHE_GIB,
            )

            self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB

        # Deprecated boolean flag maps onto the newer tensor-parallel mode.
        if self.enable_multimodal_encoder_data_parallel:
            logger.warning(
                "`--enable-multimodal-encoder-data-parallel` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-encoder-tp-mode data` instead.")

            self.mm_encoder_tp_mode = "data"

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            task=self.task,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            # ModelConfig takes the positive form of this switch.
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            io_processor_plugin=self.io_processor_plugin,
            enable_chunked_prefill=self.enable_chunked_prefill,
        )
1026

1027
1028
1029
1030
1031
1032
1033
    def validate_tensorizer_args(self):
        """Mirror tensorizer options into the nested ``tensorizer_config``.

        Each top-level key of ``self.model_loader_extra_config`` that is also
        listed in ``TensorizerConfig._fields`` is copied, with its value,
        into ``self.model_loader_extra_config["tensorizer_config"]``. The
        nested ``"tensorizer_config"`` entry must already exist. Nothing is
        removed from the top level and nothing is returned.
        """
        from vllm.model_executor.model_loader.tensorizer import (
            TensorizerConfig)

        extra_config = self.model_loader_extra_config
        for option in extra_config:
            # Keys that TensorizerConfig does not know about are left alone.
            if option not in TensorizerConfig._fields:
                continue
            extra_config["tensorizer_config"][option] = extra_config[option]
1034

1035
1036
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` from the current engine arguments.

        ``self.load_format`` is forced to ``"bitsandbytes"`` when that
        quantization is selected. For the ``"tensorizer"`` load format,
        ``self.model_loader_extra_config`` is converted via
        ``to_serializable()`` when it offers that method, then given a fresh
        ``"tensorizer_config"`` entry pointing at ``self.model`` before
        ``validate_tensorizer_args()`` is run.

        Returns:
            The ``LoadConfig`` built from the (possibly adjusted) arguments.
        """
        # bitsandbytes quantization dictates the load format.
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            extra_config = self.model_loader_extra_config
            if hasattr(extra_config, "to_serializable"):
                extra_config = extra_config.to_serializable()
                self.model_loader_extra_config = extra_config
            # Start from a clean nested config rooted at the model path.
            extra_config["tensorizer_config"] = {"tensorizer_dir": self.model}
            self.validate_tensorizer_args()

        # The load device is pinned to "cpu" for online quantization and
        # left unset otherwise.
        if is_online_quantization(self.quantization):
            load_device = "cpu"
        else:
            load_device = None

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            device=load_device,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )
        return load_config

1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig from `speculative_config`, if any.

        `speculative_config` is either the JSON value given on the CLI or a
        dictionary supplied directly to the engine. When it is unset, the
        model's HF config is loaded; if that config is a `SpeculatorsConfig`,
        the speculative settings are derived from it, otherwise no
        speculative decoding is configured and `None` is returned.

        The four arguments are merged into `self.speculative_config` (which
        is updated in place) before the SpeculativeConfig is constructed.
        """
        from vllm.transformers_utils.config import get_config
        from vllm.transformers_utils.configs.speculators.base import (
            SpeculatorsConfig)

        if self.speculative_config is None:
            config_source = self.hf_config_path or target_model_config.model
            hf_config = get_config(config_source, self.trust_remote_code,
                                   self.revision, self.code_revision,
                                   self.config_format)

            # Only a SpeculatorsConfig carries enough information to set up
            # speculative decoding without any user input.
            if not isinstance(hf_config, SpeculatorsConfig):
                return None

            # The user supplied nothing, so build the config from the
            # speculators checkpoint itself.
            self.speculative_config = {
                "num_speculative_tokens": hf_config.num_lookahead_tokens,
                "model": target_model_config.model,
                "method": hf_config.method,
            }

        # Note(Shangming): These values do not come from the cli arg
        # '--speculative-config'; they have to be injected here while the
        # engine config is being created.
        self.speculative_config.update(
            target_model_config=target_model_config,
            target_parallel_config=target_parallel_config,
            enable_chunked_prefill=enable_chunked_prefill,
            disable_log_stats=disable_log_stats,
        )
        return SpeculativeConfig(**self.speculative_config)
1110

1111
1112
1113
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
1114
        headless: bool = False,
1115
1116
1117
    ) -> VllmConfig:
        """
        Create the VllmConfig.
1118

1119
1120
1121
        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
1122

1123
1124
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.
1125

1126
1127
1128
        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1129
        current_platform.pre_register_and_update()
1130

1131
1132
        device_config = DeviceConfig(
            device=cast(Device, current_platform.device_type))
1133
1134
        model_config = self.create_model_config()

1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
1154
            self._set_default_args_v1(usage_context, model_config)
1155
            # Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1
1156
1157
            if current_platform.is_cpu(
            ) and current_platform.get_cpu_architecture() in (
1158
                    CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM):
1159
                logger.info(
1160
1161
                    "Chunked prefill is not supported for ARM and POWER "
                    "and S390X CPUs; "
1162
1163
                    "disabling it for V1 backend.")
                self.enable_chunked_prefill = False
1164
1165
        else:
            self._set_default_args_v0(model_config)
1166
        assert self.enable_chunked_prefill is not None
1167

1168
1169
1170
1171
1172
        if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
            assert self.enforce_eager, (
                "Cuda graph is not supported with DualChunkFlashAttention. "
                "To run the model in eager mode, set 'enforce_eager=True' "
                "or use '--enforce-eager' in the CLI.")
1173
1174
            assert current_platform.is_cuda() or current_platform.is_rocm(), (
                "DualChunkFlashAttention is only supported on CUDA/ROCM platform.")
1175
1176
1177
1178
            assert not use_v1, (
                "DualChunkFlashAttention is not supported on V1 engine. "
                "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")

1179
1180
1181
1182
1183
1184
1185
        sliding_window: Optional[int] = None
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1186
1187
1188
        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
1189
        # reuses the GPUs of TP group, and split one TP group into
1190
1191
1192
1193
1194
1195
1196
        # tp_size//dcp_size DCP groups.
        assert self.tensor_parallel_size % self.decode_context_parallel_size \
            == 0, (
            f"tp_size={self.tensor_parallel_size} must be divisible by"
            f"dcp_size={self.decode_context_parallel_size}."
        )

1197
        cache_config = CacheConfig(
1198
            block_size=self.block_size,
1199
            gpu_memory_utilization=self.gpu_memory_utilization,
1200
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1201
1202
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1203
            is_attention_free=model_config.is_attention_free,
1204
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1205
            sliding_window=sliding_window,
1206
            enable_prefix_caching=self.enable_prefix_caching,
1207
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1208
            cpu_offload_gb=self.cpu_offload_gb,
1209
            calculate_kv_scales=self.calculate_kv_scales,
1210
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1211
1212
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1213
        )
1214

1215
1216
1217
1218
1219
1220
1221
1222
1223
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
            ray_runtime_env = ray.get_runtime_context().runtime_env
            logger.info("Using ray runtime env: %s", ray_runtime_env)

1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1235
1236
1237
1238
        assert not headless or not self.data_parallel_hybrid_lb, (
            "data_parallel_hybrid_lb is not applicable in "
            "headless mode")

1239
        data_parallel_external_lb = self.data_parallel_rank is not None
1240
        # Local DP rank = 1, use pure-external LB.
1241
1242
1243
1244
1245
        if data_parallel_external_lb:
            assert self.data_parallel_size_local in (1, None), (
                "data_parallel_size_local must be 1 when data_parallel_rank "
                "is set")
            data_parallel_size_local = 1
1246
1247
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1248
1249
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

            self.data_parallel_rank = self.data_parallel_start_rank or 0
1265
        else:
1266
1267
1268
1269
            assert not self.data_parallel_hybrid_lb, (
                "data_parallel_size_local must be set to use "
                "data_parallel_hybrid_lb.")

1270
1271
            # Local DP size defaults to global DP size if not set.
            data_parallel_size_local = self.data_parallel_size
1272
1273
1274

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
                    "Using host IP %s as ray-based data parallel address",
                    host_ip)
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
                    self.data_parallel_backend)
                data_parallel_address = ParallelConfig.data_parallel_master_ip
        else:
            data_parallel_address = self.data_parallel_address
1289
1290
1291
1292
1293
1294
1295

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
        data_parallel_rpc_port = self.data_parallel_rpc_port if (
            self.data_parallel_rpc_port
            is not None) else ParallelConfig.data_parallel_rpc_port

1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
        if self.async_scheduling:
            # Async scheduling does not work with the uniprocess backend.
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "mp"
                logger.info("Using mp-based distributed executor backend "
                            "for async scheduling.")
            if self.distributed_executor_backend == "uni":
                raise ValueError("Async scheduling is not supported with "
                                 "uni-process backend.")
            if self.pipeline_parallel_size > 1:
                raise ValueError("Async scheduling is not supported with "
                                 "pipeline-parallel-size > 1.")

            # Currently, async scheduling does not support speculative decoding.
            # TODO(woosuk): Support it.
            if self.speculative_config is not None:
                raise ValueError(
                    "Currently, speculative decoding is not supported with "
                    "async scheduling.")

1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
        # Forward the deprecated CLI args to the EPLB config.
        if self.num_redundant_experts is not None:
            self.eplb_config.num_redundant_experts = self.num_redundant_experts
        if self.eplb_window_size is not None:
            self.eplb_config.window_size = self.eplb_window_size
        if self.eplb_step_interval is not None:
            self.eplb_config.step_interval = self.eplb_step_interval
        if self.eplb_log_balancedness is not None:
            self.eplb_config.log_balancedness = self.eplb_log_balancedness

1326
        parallel_config = ParallelConfig(
1327
1328
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1329
            data_parallel_size=self.data_parallel_size,
1330
1331
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1332
1333
1334
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1335
            data_parallel_backend=self.data_parallel_backend,
1336
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1337
            enable_expert_parallel=self.enable_expert_parallel,
1338
            enable_eplb=self.enable_eplb,
1339
            eplb_config=self.eplb_config,
1340
1341
1342
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1343
            ray_runtime_env=ray_runtime_env,
1344
            placement_group=placement_group,
1345
1346
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1347
            worker_extension_cls=self.worker_extension_cls,
1348
            decode_context_parallel_size=self.decode_context_parallel_size,
1349
        )
1350

1351
        speculative_config = self.create_speculative_config(
1352
1353
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1354
            enable_chunked_prefill=self.enable_chunked_prefill,
王敏's avatar
王敏 committed
1355
            disable_log_stats=self.disable_log_stats,
1356
1357
        )

1358
1359
1360
1361
1362
        # make sure num_lookahead_slots is set appropriately depending on
        # whether speculative decoding is enabled
        num_lookahead_slots = self.num_lookahead_slots
        if speculative_config is not None:
            num_lookahead_slots = speculative_config.num_lookahead_slots
1363

1364
        scheduler_config = SchedulerConfig(
1365
            runner_type=model_config.runner_type,
1366
1367
1368
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1369
            cuda_graph_sizes=self.cuda_graph_sizes,
1370
            num_lookahead_slots=num_lookahead_slots,
1371
1372
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1373
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1374
            is_multimodal_model=model_config.is_multimodal_model,
1375
            preemption_mode=self.preemption_mode,
1376
1377
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1378
            policy=self.scheduling_policy,
1379
            scheduler_cls=self.scheduler_cls,
1380
1381
1382
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1383
1384
            disable_hybrid_kv_cache_manager=self.
            disable_hybrid_kv_cache_manager,
1385
            async_scheduling=self.async_scheduling,
1386
        )
1387

1388
1389
1390
1391
1392
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
                "non multimodal model")

1393
        lora_config = LoRAConfig(
1394
            bias_enabled=self.enable_lora_bias,
1395
1396
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1397
            default_mm_loras=self.default_mm_loras,
1398
            fully_sharded_loras=self.fully_sharded_loras,
1399
1400
1401
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
1402
1403
            and self.max_cpu_loras > 0 else None,
            lora_target_modules=self.lora_target_modules) if self.enable_lora else None
1404

1405
1406
1407
1408
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1409
        load_config = self.create_load_config()
1410

1411
        decoding_config = DecodingConfig(
1412
1413
1414
1415
1416
            backend=self.guided_decoding_backend,
            disable_fallback=self.guided_decoding_disable_fallback,
            disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
            disable_additional_properties=\
                self.guided_decoding_disable_additional_properties,
1417
1418
            reasoning_backend=self.reasoning_parser
        )
1419

1420
        observability_config = ObservabilityConfig(
1421
1422
            show_hidden_metrics_for_version=(
                self.show_hidden_metrics_for_version),
1423
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1424
            collect_detailed_traces=self.collect_detailed_traces,
1425
        )
1426

1427
        config = VllmConfig(
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1438
            compilation_config=self.compilation_config,
1439
            kv_transfer_config=self.kv_transfer_config,
1440
            kv_events_config=self.kv_events_config,
1441
            additional_config=self.additional_config,
1442
        )
1443

1444
1445
        return config

1446
1447
1448
1449
1450
1451
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Return whether this configuration can run on the V1 engine.

        The checks below run in a fixed order and stop at the first
        unsupported feature. Each rejection goes through
        `_raise_or_fallback`, which raises `NotImplementedError` when
        `envs.VLLM_USE_V1` is truthy and otherwise logs a warning; in the
        latter case this method returns `False`.

        Args:
            model_config: Model configuration consulted for V1
                compatibility, architectures and sliding-window settings.

        Returns:
            `True` when no check rejected the configuration, `False`
            otherwise.

        Raises:
            NotImplementedError: If speculative decoding with a draft model
                is requested, or if `_raise_or_fallback` raises for an
                unsupported feature.
        """

        def reject(feature, *, remove: bool) -> bool:
            # Report the offending feature, then signal "not supported".
            _raise_or_fallback(feature_name=feature,
                               recommend_to_remove=remove)
            return False

        #############################################################
        # Feature flags that V1 does not support.

        if self.load_format == "sharded_state":
            return reject(f"--load_format {self.load_format}", remove=False)

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            return reject("--logits-processor-pattern", remove=False)

        if self.preemption_mode != SchedulerConfig.preemption_mode:
            return reject("--preemption-mode", remove=True)

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            return reject("--disable-async-output-proc", remove=True)

        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
            return reject("--scheduler-delay-factor", remove=True)

        if self.kv_cache_dtype != "auto":
            platform_ok = current_platform.is_kv_cache_dtype_supported(
                self.kv_cache_dtype, model_config)
            # Any "int8*" KV cache dtype is accepted regardless of what the
            # platform query reported.
            if not (self.kv_cache_dtype.startswith("int8") or platform_ok):
                return reject("--kv-cache-dtype", remove=False)

        # Prompt-embedding inputs are rejected.
        if self.enable_prompt_embeds:
            return reject("--enable-prompt-embeds", remove=False)

        # The model itself must declare V1 compatibility.
        if not model_config.is_v1_compatible:
            return reject(model_config.architectures, remove=False)

        # Concurrent partial prefills: both knobs must be at their
        # SchedulerConfig defaults.
        partial_prefill_is_default = (
            self.max_num_partial_prefills
            == SchedulerConfig.max_num_partial_prefills
            and self.max_long_partial_prefills
            == SchedulerConfig.max_long_partial_prefills)
        if not partial_prefill_is_default:
            return reject("Concurrent Partial Prefill", remove=False)

        # Draft-model speculative decoding is a hard error rather than a
        # fallback.
        spec_cfg = self.speculative_config
        if spec_cfg is not None and spec_cfg.get("method") == "draft_model":
            raise NotImplementedError(
                "Speculative decoding with draft model is not supported yet. "
                "Please consider using other speculative decoding methods "
                "such as ngram, medusa, eagle, or deepseek_mtp.")

        v1_attention_backends = (
            "FLASH_ATTN_VLLM_V1",
            "FLASH_ATTN",
            "PALLAS",
            "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1",
            "TRITON_MLA",
            "CUTLASS_MLA",
            "FLASHMLA",
            "FLASHMLA_VLLM_V1",
            "FLASH_ATTN_MLA",
            "FLASHINFER",
            "FLASHINFER_VLLM_V1",
            "FLASHINFER_MLA",
            "ROCM_AITER_MLA",
            "TORCH_SDPA_VLLM_V1",
            "FLEX_ATTENTION",
            "TREE_ATTN",
            "XFORMERS_VLLM_V1",
        )
        # Only an explicitly set attention backend can disqualify V1.
        if envs.is_set("VLLM_ATTENTION_BACKEND"):
            backend = envs.VLLM_ATTENTION_BACKEND
            if backend not in v1_attention_backends:
                return reject(f"VLLM_ATTENTION_BACKEND={backend}",
                              remove=True)

        # The platform gets the final say on V1 support for this model.
        if not current_platform.supports_v1(model_config=model_config):
            return reject(f"device type={current_platform.device_type}",
                          remove=False)

        #############################################################
        # Experimental features - users may opt in.

        if self.pipeline_parallel_size > 1:
            executor = self.distributed_executor_backend
            # Accept an executor that advertises `supports_pp`, or one of
            # the listed backend values.
            if not (getattr(executor, 'supports_pp', False) or executor in (
                    ParallelConfig.distributed_executor_backend, "ray", "mp",
                    "external_launcher")):
                return reject(
                    "Pipeline Parallelism without Ray distributed "
                    "executor or multiprocessing executor or external "
                    "launcher",
                    remove=False)

        # The platform may support V1 but keep it off by default; in that
        # case `_warn_or_fallback` decides whether to fall back.
        if not current_platform.default_v1(model_config=model_config):
            if _warn_or_fallback(current_platform.device_name):
                return False

        if (current_platform.is_cpu()
                and model_config.get_sliding_window() is not None):
            return reject("sliding window (CPU backend)", remove=False)

        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        Mutates `self` in place:

        - `enable_chunked_prefill`: when `None`, it becomes `False` for
          multimodal or MLA models; `True` for long-context models
          (`max_model_len > 32768`) on CUDA without sliding window,
          speculative decoding or LoRA; and `False` in every other case.
        - `enable_prefix_caching`: forced to `False` for multimodal models.
        - `max_num_seqs`: set to 256 when `None`.

        Args:
            model_config: Model configuration providing `max_model_len`,
                `is_multimodal_model`, `use_mla` and the sliding window.
        """

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_config is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # Anything not decided above defaults to disabled.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            # NOTE: the trailing space after "cause" matters; adjacent
            # string literals are concatenated verbatim.
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)

        # Disable prefix caching for multimodal models for VLLM_V0.
        if self.enable_prefix_caching and model_config.is_multimodal_model:
            logger.warning(
                "--enable-prefix-caching is not supported for multimodal "
                "models in V0 and has been disabled.")
            self.enable_prefix_caching = False

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

1633
1634
    def _set_default_args_v1(self, usage_context: UsageContext,
                             model_config: ModelConfig) -> None:
        """Fill in V1-engine defaults for arguments the user left unset.

        Mutates `self` in place: resolves `enable_chunked_prefill` and
        `enable_prefix_caching`, swaps in the V1 scheduler class when the
        default one is still selected, and picks `max_num_batched_tokens`
        and `max_num_seqs` from per-usage-context tables when they are
        `None`.

        Args:
            usage_context: Key into the default tables; values not present
                in the tables (including `None`) leave the limits untouched.
            model_config: Model configuration providing the runner type,
                pooler settings, `enable_chunked_prefill` and
                `max_model_len`.
        """

        if model_config.runner_type != "pooling":
            # Non-pooling runners use chunked prefill unless the model
            # config explicitly sets it to False; prefix caching defaults
            # to on.
            self.enable_chunked_prefill = (
                model_config.enable_chunked_prefill is not False)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = True
        else:
            # Pooling runners: the default for both features depends on
            # whether the pooler is "last" pooling on a causal model.
            pooler_kind = model_config.pooler_config.pooling_type
            causal = getattr(model_config.hf_config, "is_causal", True)
            supports_incremental = (pooler_kind is not None
                                    and pooler_kind.lower() == "last"
                                    and causal)
            verb = "Enabling" if supports_incremental else "Disabling"

            # An explicit False in the model config takes precedence over
            # the computed default.
            if model_config.enable_chunked_prefill is False:
                self.enable_chunked_prefill = False

            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = supports_incremental
                logger.info("(%s) chunked prefill by default", verb)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = supports_incremental
                logger.info("(%s) prefix caching by default", verb)

        # Replace the scheduler class only if it still has the EngineArgs
        # default value.
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # Without a user override, the defaults below depend on the usage
        # context and on the hardware.

        # Querying the device can fail, e.g. when the process importing vLLM
        # has no GPUs (such as when scaling vLLM with Ray). In that case the
        # smaller set of defaults is used.
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # Only feeds the choice of default tables below. `device_name`
            # stays unbound here; the `and` below short-circuits on
            # device_memory == 0 before it would be read.
            device_memory = 0

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # Hence the extra device-name check to prevent that regression.
        from vllm.usage.usage_lib import UsageContext

        llm_ctx = UsageContext.LLM_CLASS
        server_ctx = UsageContext.OPENAI_API_SERVER

        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
            # For GPUs like H100 and MI300x, use larger default values.
            batched_token_defaults = {llm_ctx: 16384, server_ctx: 8192}
            seq_defaults = {llm_ctx: 1024, server_ctx: 1024}
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            batched_token_defaults = {llm_ctx: 8192, server_ctx: 2048}
            seq_defaults = {llm_ctx: 256, server_ctx: 256}

        # TPU: per-chip token budgets, keyed by device name.
        if current_platform.is_tpu():
            tpu_batched_token_defaults = {
                llm_ctx: {'V6E': 2048, 'V5E': 1024, 'V5P': 512},
                server_ctx: {'V6E': 1024, 'V5E': 512, 'V5P': 256},
            }

        # CPU: defaults scale with the pipeline * tensor parallel world size.
        if current_platform.is_cpu():
            world_size = (self.pipeline_parallel_size *
                          self.tensor_parallel_size)
            batched_token_defaults = {
                llm_ctx: 4096 * world_size,
                server_ctx: 2048 * world_size,
            }
            seq_defaults = {
                llm_ctx: 256 * world_size,
                server_ctx: 128 * world_size,
            }

        context_label = usage_context.value if usage_context else None

        if (self.max_num_batched_tokens is None
                and usage_context in batched_token_defaults):
            generic_default = batched_token_defaults[usage_context]
            if current_platform.is_tpu():
                # Use the chip-specific value when the chip is listed,
                # otherwise the generic default.
                chip_name = current_platform.get_device_name()
                self.max_num_batched_tokens = tpu_batched_token_defaults[
                    usage_context].get(chip_name, generic_default)
            elif self.enable_chunked_prefill:
                self.max_num_batched_tokens = generic_default
            else:
                # Without chunked prefill the budget is the model length.
                self.max_num_batched_tokens = model_config.max_model_len
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, context_label)

        if self.max_num_seqs is None and usage_context in seq_defaults:
            # Cap the sequence count by the token budget when one is set.
            token_cap = self.max_num_batched_tokens or sys.maxsize
            self.max_num_seqs = min(seq_defaults[usage_context], token_cap)

            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, context_label)
1772

1773

1774
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends `EngineArgs` with the request-logging switch and keeps the
    deprecated `disable_log_requests` spelling available as an inverted
    view of `enable_log_requests`.
    """
    enable_log_requests: bool = False

    @property
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
        "`enable_log_requests` instead.")
    def disable_log_requests(self) -> bool:
        """Deprecated: the negation of `enable_log_requests`."""
        return not self.enable_log_requests

    @disable_log_requests.setter
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
        "`enable_log_requests` instead.")
    def disable_log_requests(self, value: bool):
        """Deprecated: store the negation into `enable_log_requests`."""
        self.enable_log_requests = not value

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI flags on `parser`.

        Args:
            parser: Parser to extend.
            async_args_only: When `True`, skip `EngineArgs.add_cli_args`
                and register only the request-logging flags.

        Returns:
            The parser with the flags added.
        """
        # Plugins are loaded first because they may extend what the parser
        # accepts, e.g. a new --quantization method or a new --device.
        load_general_plugins()

        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        # Both flags are boolean-optional; the second is the deprecated,
        # inverted spelling of the first.
        log_request_flags = {
            "--enable-log-requests": {
                "default": AsyncEngineArgs.enable_log_requests,
                "help": "Enable logging requests.",
            },
            "--disable-log-requests": {
                "default": not AsyncEngineArgs.enable_log_requests,
                "help": "[DEPRECATED] Disable logging requests.",
                "deprecated": True,
            },
        }
        for flag, options in log_request_flags.items():
            parser.add_argument(flag,
                                action=argparse.BooleanOptionalAction,
                                **options)

        current_platform.pre_register_and_update(parser)
        return parser
1815
1816


1817
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Reject `feature_name` for V1: raise if V1 is on, otherwise warn.

    Args:
        feature_name: Feature description interpolated into the message.
        recommend_to_remove: When `True`, the warning also advises removing
            the feature from the config.

    Raises:
        NotImplementedError: If `envs.VLLM_USE_V1` is truthy.
    """
    # An earlier form of this check also required
    # envs.is_set("VLLM_USE_V1"); only the value is consulted now.
    if envs.VLLM_USE_V1:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")

    parts = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        parts.append(
            f"We recommend to remove {feature_name} from your config ")
        parts.append("in favor of the V1 Engine.")
    logger.warning("".join(parts))


def _warn_or_fallback(feature_name: str) -> bool:
    """Log about an experimental V1 feature and say whether to fall back.

    Args:
        feature_name: Name interpolated into the log message.

    Returns:
        `False` when `envs.VLLM_USE_V1` is truthy (a warning is logged and
        V1 is kept), `True` otherwise (an info message is logged and the
        caller should fall back to V0).
    """
    # An earlier form of this check also required
    # envs.is_set("VLLM_USE_V1"); only the value is consulted now.
    if envs.VLLM_USE_V1:
        logger.warning(
            "Detected VLLM_USE_V1=1 with %s. Usage should "
            "be considered experimental. Please report any "
            "issues on Github.", feature_name)
        return False

    logger.info(
        "%s is experimental on VLLM_USE_V1=1. "
        "Falling back to V0 Engine.", feature_name)
    return True


1846
1847
1848
def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.

    Lowercase suffixes are decimal multipliers (powers of 1000) and accept
    a fractional number. Uppercase suffixes are binary multipliers (powers
    of 1024) and require a whole number.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '1t' -> 1,000,000,000,000
    - '1T' -> 1,099,511,627,776

    Args:
        value: String to parse; surrounding whitespace is ignored.

    Returns:
        The parsed integer. A fractional result after scaling is truncated
        toward zero.

    Raises:
        argparse.ArgumentTypeError: If a fractional number is combined with
            a binary (uppercase) suffix.
        ValueError: If `value` has no recognised suffix and is not a plain
            integer.
    """
    from decimal import Decimal

    # Every suffix accepted by the regex below must appear in exactly one
    # of these two tables.
    decimal_multiplier = {
        'k': 10**3,
        'm': 10**6,
        'g': 10**9,
        't': 10**12,
    }
    binary_multiplier = {
        'K': 2**10,
        'M': 2**20,
        'G': 2**30,
        'T': 2**40,
    }

    value = value.strip()
    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
    if match is None:
        # Regular plain number.
        return int(value)

    number, suffix = match.groups()
    if suffix in decimal_multiplier:
        # Decimal keeps the scaling exact; going through float can land
        # just below the true value and truncate one short.
        return int(Decimal(number) * decimal_multiplier[suffix])

    mult = binary_multiplier[suffix]
    # Do not allow decimals with binary multipliers
    try:
        return int(number) * mult
    except ValueError as e:
        raise argparse.ArgumentTypeError("Decimals are not allowed " \
        f"with binary suffixes like {suffix}. Did you mean to use " \
        f"{number}{suffix.lower()} instead?") from e