arg_utils.py 79.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# yapf: disable
5
import argparse
6
import copy
7
import dataclasses
8
import functools
9
import json
10
import sys
11
from dataclasses import MISSING, dataclass, fields, is_dataclass
12
from itertools import permutations
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
29

30
import huggingface_hub
31
import regex as re
32
import torch
33
from pydantic import TypeAdapter, ValidationError
34
from typing_extensions import TypeIs, deprecated
35

36
import vllm.envs as envs
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from vllm.config import (
    BlockSize,
    CacheConfig,
    CacheDType,
    CompilationConfig,
    ConfigType,
    ConvertOption,
    DetailedTraceModules,
    Device,
    DeviceConfig,
    DistributedExecutorBackend,
    EPLBConfig,
    HfOverrides,
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LogprobsMode,
    LoRAConfig,
    MambaDType,
    MMEncoderTPMode,
    ModelConfig,
    ModelDType,
    ObservabilityConfig,
    ParallelConfig,
    PoolerConfig,
    PrefixCachingHashAlgo,
    RunnerOption,
    SchedulerConfig,
    SchedulerPolicy,
    SpeculativeConfig,
    StructuredOutputsConfig,
    TaskOption,
    TokenizerMode,
    VllmConfig,
    get_attr_docs,
)
73
from vllm.config.multimodal import MMCacheType, MultiModalConfig
74
from vllm.config.parallel import ExpertPlacementStrategy
75
from vllm.config.utils import get_field
76
from vllm.logger import init_logger
77
from vllm.platforms import CpuArchEnum, current_platform
78
from vllm.plugins import load_general_plugins
79
from vllm.ray.lazy_utils import is_ray_initialized
80
from vllm.reasoning import ReasoningParserManager
81
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
82
83
84
85
86
from vllm.transformers_utils.config import (
    get_model_path,
    is_interleaved,
    maybe_override_with_speculators,
)
87
from vllm.transformers_utils.utils import check_gguf_file
88
from vllm.utils import FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor
89
from vllm.v1.sample.logits_processor import LogitsProcessor
90
91

# yapf: enable
92

93
94
95
if TYPE_CHECKING:
    from vllm.executor.executor_base import ExecutorBase
    from vllm.model_executor.layers.quantization import QuantizationMethods
96
    from vllm.model_executor.model_loader import LoadFormats
97
98
99
100
    from vllm.usage.usage_lib import UsageContext
else:
    ExecutorBase = Any
    QuantizationMethods = Any
101
    LoadFormats = Any
102
103
    UsageContext = Any

104
105
logger = init_logger(__name__)

# Type-hint aliases used by the argparse helpers below. `object` is part of
# each union so that special typing forms (e.g. `Literal[...]`,
# `Annotated[...]`, unions), which are not classes, are accepted as hints.
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]

111

112
113
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap a string converter so its failures surface as argparse errors.

    Args:
        return_type: Callable that converts the raw command-line string.

    Returns:
        A converter that calls ``return_type`` and re-raises any
        ``ValueError`` as ``argparse.ArgumentTypeError`` (chained to the
        original exception).
    """

    def _parse_type(val: str) -> T:
        try:
            converted = return_type(val)
        except ValueError as err:
            message = f"Value {val} cannot be converted to {return_type}."
            raise argparse.ArgumentTypeError(message) from err
        return converted

    return _parse_type


124
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Build a converter that maps ``""`` and ``"None"`` to ``None``.

    Any other string is converted with ``parse_type(return_type)``, so
    conversion failures become ``argparse.ArgumentTypeError``.
    """

    def _optional_type(val: str) -> Optional[T]:
        # Both the empty string and the literal text "None" mean "no value".
        means_none = val in ("", "None")
        return None if means_none else parse_type(return_type)(val)

    return _optional_type
131
132


133
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
    """Decode ``val`` as JSON if it is brace-delimited, else return it as str.

    A value that (ignoring surrounding whitespace, and possibly spanning
    several lines) starts with ``{`` and ends with ``}`` is parsed with
    ``json.loads``; decoding errors are reported as
    ``argparse.ArgumentTypeError`` via ``optional_type``. Everything else is
    returned unchanged as a string.
    """
    brace_wrapped = re.match(r"(?s)^\s*{.*}\s*$", val) is not None
    if brace_wrapped:
        return optional_type(json.loads)(val)
    return str(val)
137
138


139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return True if ``type_hint`` is ``type`` or a parametrisation of it.

    For example both ``list`` and ``list[int]`` match ``list``.
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return True if at least one hint in ``type_hints`` matches ``type``.

    Matching follows ``is_type`` (exact identity or matching origin).
    """
    for candidate in type_hints:
        if is_type(candidate, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> Optional[TypeHintT]:
    """Return a hint from ``type_hints`` that matches ``type``, if any.

    Matching follows ``is_type``. ``type_hints`` is a set, so when several
    hints match, which one is returned is unspecified.

    Returns:
        The matching hint, or ``None`` when no hint matches (the annotation
        reflects this; callers usually check ``contains_type`` first).
    """
    return next((th for th in type_hints if is_type(th, type)), None)


154
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Derive argparse ``type`` and ``choices`` from a ``Literal`` hint.

    The ``Literal`` is looked up in ``type_hints``; ``type`` becomes the
    Python type of its options and the sorted options are returned under
    ``choices``. If ``type_hints`` also contains ``str`` (free-form strings
    are allowed), the options are returned under ``metavar`` instead so
    argparse does not reject other values.

    Raises:
        ValueError: If the ``Literal`` options are not all of one type.
    """
    literal_hint = get_type(type_hints, Literal)
    literal_options = get_args(literal_hint)
    first_type = type(literal_options[0])
    if any(not isinstance(opt, first_type) for opt in literal_options):
        raise ValueError(
            "All options must be of the same type. "
            f"Got {literal_options} with types {[type(c) for c in literal_options]}"
        )

    options_key = "choices" if not contains_type(type_hints, str) else "metavar"
    return {"type": first_type, options_key: sorted(literal_options)}
169
170


171
172
173
174
175
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True if ``type_hint`` is defined outside the ``builtins`` module."""
    defining_module = type_hint.__module__
    return defining_module != "builtins"


176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Flatten ``type_hint`` into the set of hints it is composed of.

    ``Annotated[X, ...]`` is unwrapped to ``X``. Unions are expanded
    recursively into their members; this covers both ``typing.Union`` /
    ``Optional`` and PEP 604 ``X | Y`` unions. Any other hint is returned
    as a single-element set.
    """
    import types

    # On Python 3.10+, `X | Y` reports `types.UnionType` as its origin rather
    # than `typing.Union`; without this check such unions would be treated as
    # one opaque hint instead of being expanded into their members.
    union_type = getattr(types, "UnionType", None)

    origin = get_origin(type_hint)
    if origin is Annotated:
        return get_type_hints(get_args(type_hint)[0])
    if origin is Union or (union_type is not None and origin is union_type):
        type_hints: set[TypeHint] = set()
        for member in get_args(type_hint):
            type_hints.update(get_type_hints(member))
        return type_hints
    return {type_hint}


193
194
195
196
def is_online_quantization(quantization: Any) -> bool:
    """Return True if ``quantization`` is one of the methods this module
    treats as online quantization (currently only ``"inc"``)."""
    online_methods = ("inc",)
    return quantization in online_methods


197
# True when help text is going to be rendered. Attribute docs are only
# collected in that case, because gathering them is comparatively slow.
NEEDS_HELP = (
    # vllm SUBCOMMAND --help
    any("--help" in arg for arg in sys.argv)
    # mkdocs SUBCOMMAND  /  python -m mkdocs SUBCOMMAND
    or (argv0 := sys.argv[0]).endswith(("mkdocs", "mkdocs/__main__.py"))
)


204
205
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Build ``add_argument`` keyword arguments for every field of ``cls``.

    Each dataclass field is mapped to a dict holding at least ``default`` and
    ``help``, plus whichever of ``type``, ``nargs``, ``choices``/``metavar``
    and ``action`` its type hints call for.

    Results are memoised with ``functools.lru_cache``; use ``get_kwargs`` to
    obtain a deep copy that is safe to mutate.

    Raises:
        ValueError: If a field's type hints cannot be mapped to an argparse
            ``type``.
    """
    import types

    # On Python 3.10+, `X | Y` unions report `types.UnionType` as their origin
    # instead of `typing.Union`.
    union_type = getattr(types, "UnionType", None)

    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}

    # Loop-invariant constants (previously rebuilt for every field).
    json_tip = (
        "Should either be a valid JSON string or JSON keys passed individually."
    )
    # Special case for large integers: these fields are parsed with
    # `human_readable_int` instead of plain `int`.
    human_readable_ints = {
        "max_model_len",
        "max_num_batched_tokens",
        "kv_cache_memory_bytes",
    }

    kwargs = {}
    for field in fields(cls):
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, we can use the model_validate_json
        dataclass_cls = next((th for th in type_hints if is_dataclass(th)), None)

        # Get the default value of the field. `default` is reset for every
        # field so that a field without any default gets `None` instead of
        # silently inheriting the previous field's default (or raising
        # UnboundLocalError when it is the first field).
        default = None
        if field.default is not MISSING:
            default = field.default
        elif field.default_factory is not MISSING:
            default = field.default_factory()

        # Get the help text for the field and escape % for argparse
        name = field.name
        help_text = cls_docs.get(name, "").strip().replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        field_kwargs: dict[str, Any] = {"default": default, "help": help_text}
        kwargs[name] = field_kwargs

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:

            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            field_kwargs["type"] = parse_dataclass
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            field_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            field_kwargs.update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            tuple_args = get_args(type_hint)
            tuple_type = tuple_args[0]
            assert all(t is tuple_type for t in tuple_args if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {tuple_args}."
            )
            field_kwargs["type"] = tuple_type
            field_kwargs["nargs"] = "+" if Ellipsis in tuple_args else len(tuple_args)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            list_args = get_args(type_hint)
            list_type = list_args[0]
            list_type_origin = get_origin(list_type)
            if list_type_origin is Union or (
                union_type is not None and list_type_origin is union_type
            ):
                msg = "List type must contain str if it is a Union."
                assert str in get_args(list_type), msg
                list_type = str
            field_kwargs["type"] = list_type
            field_kwargs["nargs"] = "+"
        elif contains_type(type_hints, int):
            if name in human_readable_ints:
                field_kwargs["type"] = human_readable_int
                field_kwargs["help"] += f"\n\n{human_readable_int.__doc__}"
            else:
                field_kwargs["type"] = int
        elif contains_type(type_hints, float):
            field_kwargs["type"] = float
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
            field_kwargs["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            field_kwargs["type"] = parse_type(json.loads)
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
            field_kwargs["type"] = str
        else:
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(field_kwargs.get("type")) is Literal:
            field_kwargs.update(literal_to_kwargs({field_kwargs["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            field_kwargs["type"] = optional_type(field_kwargs["type"])
            if field_kwargs.get("choices"):
                field_kwargs["choices"].append("None")
    return kwargs
311
312


313
314
315
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Return argparse kwargs for every field of the Config dataclass ``cls``.

    Attribute documentation is only pulled into the ``help`` entries when
    ``--help`` or ``mkdocs`` appears on the command line; otherwise it is
    omitted.

    The expensive work is memoised by ``_compute_kwargs``. This wrapper hands
    back a deep copy so callers can mutate the result without corrupting the
    cached dictionary.
    """
    cached_kwargs = _compute_kwargs(cls)
    return copy.deepcopy(cached_kwargs)


326
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
327
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
328
    """Arguments for vLLM engine."""
329

330
    model: str = ModelConfig.model
331
    served_model_name: Optional[Union[str, List[str]]] = ModelConfig.served_model_name
332
333
    tokenizer: Optional[str] = ModelConfig.tokenizer
    hf_config_path: Optional[str] = ModelConfig.hf_config_path
334
335
336
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
    task: Optional[TaskOption] = ModelConfig.task
337
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
338
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
339
340
341
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
342
    allowed_media_domains: Optional[list[str]] = ModelConfig.allowed_media_domains
343
    download_dir: Optional[str] = LoadConfig.download_dir
344
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
345
    load_format: Union[str, LoadFormats] = LoadConfig.load_format
346
347
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
348
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
349
350
    seed: Optional[int] = ModelConfig.seed
    max_model_len: Optional[int] = ModelConfig.max_model_len
351
    cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes")
352
353
354
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
355
356
357
    distributed_executor_backend: Optional[
        Union[str, DistributedExecutorBackend, Type[ExecutorBase]]
    ] = ParallelConfig.distributed_executor_backend
358
    # number of P/D disaggregation (or other disaggregation) workers
359
360
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
361
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
362
    data_parallel_size: int = ParallelConfig.data_parallel_size
363
    data_parallel_rank: Optional[int] = None
364
    data_parallel_start_rank: Optional[int] = None
365
366
367
    data_parallel_size_local: Optional[int] = None
    data_parallel_address: Optional[str] = None
    data_parallel_rpc_port: Optional[int] = None
368
    data_parallel_hybrid_lb: bool = False
Rui Qiao's avatar
Rui Qiao committed
369
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
370
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
371
    enable_dbo: bool = ParallelConfig.enable_dbo
372
373
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
374
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
375
    enable_eplb: bool = ParallelConfig.enable_eplb
376
    expert_placement_strategy: ExpertPlacementStrategy = (
377
        ParallelConfig.expert_placement_strategy
378
    )
379
380
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
381
382
383
384
    num_redundant_experts: int = EPLBConfig.num_redundant_experts
    eplb_window_size: int = EPLBConfig.window_size
    eplb_step_interval: int = EPLBConfig.step_interval
    eplb_log_balancedness: bool = EPLBConfig.log_balancedness
385
386
387
    max_parallel_loading_workers: Optional[int] = (
        ParallelConfig.max_parallel_loading_workers
    )
388
389
    block_size: Optional[BlockSize] = CacheConfig.block_size
    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
390
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
391
        CacheConfig.prefix_caching_hash_algo
392
    )
393
394
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
395
396
397
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
398
    kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes
399
    max_num_batched_tokens: Optional[int] = SchedulerConfig.max_num_batched_tokens
400
401
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
402
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
403
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
404
    max_logprobs: int = ModelConfig.max_logprobs
405
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
406
    disable_log_stats: bool = False
407
408
409
410
411
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
    rope_theta: Optional[float] = ModelConfig.rope_theta
    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
412
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
413
414
415
    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
    enforce_eager: bool = ModelConfig.enforce_eager
416
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
417
418
419
    limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = get_field(
        MultiModalConfig, "limit_per_prompt"
    )
420
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
421
422
423
424
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
    mm_processor_kwargs: Optional[Dict[str, Any]] = MultiModalConfig.mm_processor_kwargs
425
    disable_mm_preprocessor_cache: bool = False  # DEPRECATED
426
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
427
    mm_processor_cache_type: Optional[MMCacheType] = (
428
        MultiModalConfig.mm_processor_cache_type
429
430
    )
    mm_shm_cache_max_object_size_mb: int = (
431
        MultiModalConfig.mm_shm_cache_max_object_size_mb
432
    )
433
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
434
    io_processor_plugin: Optional[str] = None
435
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
436
    video_pruning_rate: float = MultiModalConfig.video_pruning_rate
437
    # LoRA fields
438
    enable_lora: bool = False
439
440
441
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
442
    default_mm_loras: Optional[Dict[str, str]] = LoRAConfig.default_mm_loras
443
444
445
446
447
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size

448
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
449
    num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override
450
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
451
452
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
    ignore_patterns: Optional[Union[str, List[str]]] = LoadConfig.ignore_patterns
453

454
    enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill
455
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
456

457
    disable_hybrid_kv_cache_manager: bool = (
458
459
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
460

461
    structured_outputs_config: StructuredOutputsConfig = get_field(
462
463
        VllmConfig, "structured_outputs_config"
    )
464
465
466
467
468
469
470
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
    # Deprecated guided decoding fields
    guided_decoding_backend: Optional[str] = None
    guided_decoding_disable_fallback: Optional[bool] = None
    guided_decoding_disable_any_whitespace: Optional[bool] = None
    guided_decoding_disable_additional_properties: Optional[bool] = None

471
    logits_processor_pattern: Optional[str] = ModelConfig.logits_processor_pattern
472

473
    speculative_config: Optional[Dict[str, Any]] = None
474

475
    show_hidden_metrics_for_version: Optional[str] = (
476
        ObservabilityConfig.show_hidden_metrics_for_version
477
478
479
    )
    otlp_traces_endpoint: Optional[str] = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: Optional[list[DetailedTraceModules]] = (
480
        ObservabilityConfig.collect_detailed_traces
481
    )
482
483
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
484

485
    pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config
486
    override_pooler_config: Optional[Union[dict, PoolerConfig]] = (
487
        ModelConfig.override_pooler_config
488
489
    )
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
490
491
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
492

493
    kv_transfer_config: Optional[KVTransferConfig] = None
494
    kv_events_config: Optional[KVEventsConfig] = None
495

496
497
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
498
499
500
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
501
    model_impl: str = ModelConfig.model_impl
502
    override_attention_dtype: str = ModelConfig.override_attention_dtype
503

504
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
505
506
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
507

508
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
509

510
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
511
    pt_load_map_location: str = LoadConfig.pt_load_map_location
512

513
514
    # DEPRECATED
    enable_multimodal_encoder_data_parallel: bool = False
515

516
517
518
    logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = (
        ModelConfig.logits_processors
    )
519
520
    """Custom logitproc types"""

521
522
    async_scheduling: bool = SchedulerConfig.async_scheduling

523
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
524

525
    def __post_init__(self):
        """Normalise nested configs, load plugins and resolve offline paths."""
        # Accept plain dicts for the nested configs, e.g.
        # `EngineArgs(compilation_config={...})`, so callers need not build
        # CompilationConfig / EPLBConfig objects by hand.
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(**self.compilation_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)

        # Setup plugins
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        # Only in HF offline mode is the model id swapped for the path
        # returned by `get_model_path`.
        if not huggingface_hub.constants.HF_HUB_OFFLINE:
            return
        requested_model = self.model
        self.model = get_model_path(self.model, self.revision)
        logger.info(
            "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
            requested_model,
            self.model,
        )
546
547

    @staticmethod
548
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
549
        """Shared CLI arguments for vLLM engine."""
550

551
        # Model arguments
552
553
554
555
556
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
557
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
558
            model_group.add_argument("--model", **model_kwargs["model"])
559
560
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
561
        model_group.add_argument("--task", **model_kwargs["task"], deprecated=True)
562
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
563
564
565
566
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
567
568
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
569
570
571
572
573
574
575
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
576
        model_group.add_argument("--revision", **model_kwargs["revision"])
577
578
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling", **model_kwargs["rope_scaling"])
579
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
604
605
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--override-pooler-config",
            **model_kwargs["override_pooler_config"],
            deprecated=True,
        )
        model_group.add_argument(
            "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
        )
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
633
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
634
635
636
637
638
639
640
641
642
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
        )
643

644
645
646
647
648
649
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
650
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
651
652
653
654
655
656
657
658
659
660
661
662
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )
663

664
665
666
667
668
        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
669
        )
670
        structured_outputs_group.add_argument(
671
            "--reasoning-parser",
672
            # This choice is a special case because it's not static
673
            choices=list(ReasoningParserManager.reasoning_parsers),
674
675
            **structured_outputs_kwargs["reasoning_parser"],
        )
676
677
678
679
680
681
682
683
684
685
686
        # Deprecated guided decoding arguments
        for arg, type in [
            ("--guided-decoding-backend", str),
            ("--guided-decoding-disable-fallback", bool),
            ("--guided-decoding-disable-any-whitespace", bool),
            ("--guided-decoding-disable-additional-properties", bool),
        ]:
            structured_outputs_group.add_argument(
                arg,
                type=type,
                help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
687
688
                deprecated=True,
            )
689

690
        # Parallel arguments
691
692
693
694
695
696
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
697
            "--distributed-executor-backend",
698
699
            **parallel_kwargs["distributed_executor_backend"],
        )
700
        parallel_group.add_argument(
701
702
703
704
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
705
        parallel_group.add_argument(
706
707
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
708
        parallel_group.add_argument(
709
710
711
712
713
714
715
716
717
718
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
        parallel_group.add_argument(
            "--data-parallel-rank",
            "-dpn",
719
            type=int,
720
721
722
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
723
        parallel_group.add_argument(
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
            "--data-parallel-start-rank",
            "-dpr",
            type=int,
            help="Starting data parallel rank for secondary nodes.",
        )
        parallel_group.add_argument(
            "--data-parallel-size-local",
            "-dpl",
            type=int,
            help="Number of data parallel replicas to run on this node.",
        )
        parallel_group.add_argument(
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
            type=int,
            help="Port for data parallel RPC communication.",
        )
        parallel_group.add_argument(
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
754
        parallel_group.add_argument(
755
756
757
758
759
760
            "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"]
        )
        parallel_group.add_argument(
            "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
        )
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
761
762
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
763
764
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
765
766
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
767
768
769
770
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
771
772
        parallel_group.add_argument(
            "--expert-placement-strategy",
773
774
            **parallel_kwargs["expert_placement_strategy"],
        )
775
776
777
        parallel_group.add_argument(
            "--num-redundant-experts",
            type=int,
778
779
780
            help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
            deprecated=True,
        )
781
782
783
784
        parallel_group.add_argument(
            "--eplb-window-size",
            type=int,
            help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
785
786
            deprecated=True,
        )
787
788
789
        parallel_group.add_argument(
            "--eplb-step-interval",
            type=int,
790
791
792
            help="[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
            deprecated=True,
        )
793
794
795
        parallel_group.add_argument(
            "--eplb-log-balancedness",
            action=argparse.BooleanOptionalAction,
796
797
798
            help="[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
            deprecated=True,
        )
799

800
        parallel_group.add_argument(
801
            "--max-parallel-loading-workers",
802
803
            **parallel_kwargs["max_parallel_loading_workers"],
        )
804
        parallel_group.add_argument(
805
806
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
807
        parallel_group.add_argument(
808
            "--disable-custom-all-reduce",
809
810
811
812
813
814
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )
815
816
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
817
            action="store_true",
818
819
            deprecated=True,
        )
820

821
822
823
824
825
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
826
        )
827
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
828
829
830
831
832
833
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
834
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        cache_group.add_argument(
            "--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"]
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
858

859
        # Multimodal related configs
860
861
862
863
864
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
865
        multimodal_group.add_argument(
866
867
868
869
870
871
872
873
874
875
876
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
877
        multimodal_group.add_argument(
878
879
            "--disable-mm-preprocessor-cache", action="store_true", deprecated=True
        )
880
        multimodal_group.add_argument(
881
882
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
883
884
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
885
886
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
887
        multimodal_group.add_argument(
888
889
890
891
892
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
893
        multimodal_group.add_argument(
894
895
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
896

897
        multimodal_group.add_argument(
898
899
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )
900

901
        # LoRA related configs
902
903
904
905
906
907
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
908
            "--enable-lora",
909
            action=argparse.BooleanOptionalAction,
910
911
912
            help="If True, enable handling of LoRA adapters.",
        )
        lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"])
913
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
914
915
916
917
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
        lora_group.add_argument(
            "--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"]
        )
918
        lora_group.add_argument(
919
            "--lora-dtype",
920
921
            **lora_kwargs["lora_dtype"],
        )
922
923
924
925
926
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
927

928
929
930
931
932
933
934
935
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
936
937
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
938
        observability_group.add_argument(
939
940
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
941
942
943
944
945
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
946
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
947
948
949
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
950
951
            **observability_kwargs["collect_detailed_traces"],
        )
952

953
954
955
956
957
958
959
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
960
961
            "--max-num-batched-tokens", **scheduler_kwargs["max_num_batched_tokens"]
        )
962
        scheduler_group.add_argument(
963
964
965
966
967
            "--max-num-seqs", **scheduler_kwargs["max_num_seqs"]
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
968
969
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
970
971
972
973
974
            **scheduler_kwargs["max_long_partial_prefills"],
        )
        scheduler_group.add_argument(
            "--cuda-graph-sizes", **scheduler_kwargs["cuda_graph_sizes"]
        )
975
976
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
977
978
979
980
981
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
        scheduler_group.add_argument(
            "--num-lookahead-slots", **scheduler_kwargs["num_lookahead_slots"]
        )
982
983
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
984
        scheduler_group.add_argument(
985
986
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
987
        scheduler_group.add_argument(
988
989
990
991
992
993
994
995
            "--enable-chunked-prefill", **scheduler_kwargs["enable_chunked_prefill"]
        )
        scheduler_group.add_argument(
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
996
997
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
998
999
1000
1001
1002
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
1003
1004

        # vLLM arguments
1005
        vllm_kwargs = get_kwargs(VllmConfig)
1006
1007
1008
1009
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
1010
1011
1012
1013
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
        vllm_group.add_argument(
            "--speculative-config", **vllm_kwargs["speculative_config"]
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument(
            "--compilation-config", "-O", **vllm_kwargs["compilation_config"]
        )
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
1030

1031
        # Other arguments
1032
1033
1034
1035
1036
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
1037

1038
        return parser
1039
1040

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from a parsed CLI namespace.

        Every dataclass field of ``cls`` that also exists as an attribute on
        ``args`` is forwarded to the constructor. Fields absent from ``args``
        are not passed, so ``cls`` applies its own defaults for them.
        """
        field_names = (field.name for field in dataclasses.fields(cls))
        init_kwargs = {
            name: getattr(args, name)
            for name in field_names
            if hasattr(args, name)
        }
        return cls(**init_kwargs)

    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` described by these engine arguments.

        Before constructing the config, a few fields on ``self`` are
        normalized in place:

        * a GGUF checkpoint forces both ``quantization`` and ``load_format``
          to ``"gguf"``;
        * when ``VLLM_CI_USE_S3`` is set (and this is not an
          ``AsyncEngineArgs``), models listed in ``MODELS_ON_S3`` that use the
          ``"auto"`` load format are redirected to ``MODEL_WEIGHTS_S3_BUCKET``;
        * deprecated multimodal options are translated to their replacements,
          with a warning.

        Returns:
            A new ``ModelConfig`` built from the (possibly updated) fields.
        """
        # A GGUF file needs a specific model loader and doesn't use hf_repo.
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI.
        if (
            not isinstance(self, AsyncEngineArgs)
            and envs.VLLM_CI_USE_S3
            and self.model in MODELS_ON_S3
            and self.load_format == "auto"
        ):
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"

        # Deprecated ways to size/disable the multimodal processor cache.
        if self.disable_mm_preprocessor_cache:
            logger.warning(
                "`--disable-mm-preprocessor-cache` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb 0` instead.",
            )
            self.mm_processor_cache_gb = 0
        elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
            # NOTE(review): 4 is presumably the env var's default, so any other
            # value means the deprecated variable was set explicitly — confirm
            # against vllm.envs.
            logger.warning(
                "`VLLM_MM_INPUT_CACHE_GIB` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb %d` instead.",
                envs.VLLM_MM_INPUT_CACHE_GIB,
            )
            self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB

        # Deprecated flag superseded by `--mm-encoder-tp-mode data`.
        if self.enable_multimodal_encoder_data_parallel:
            logger.warning(
                "`--enable-multimodal-encoder-data-parallel` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-encoder-tp-mode data` instead."
            )
            self.mm_encoder_tp_mode = "data"

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            task=self.task,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            pooler_config=self.pooler_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            video_pruning_rate=self.video_pruning_rate,
            io_processor_plugin=self.io_processor_plugin,
        )

    def validate_tensorizer_args(self):
        """Mirror tensorizer-related keys into the nested tensorizer config.

        Every top-level key of ``model_loader_extra_config`` whose name is one
        of ``TensorizerConfig._fields`` is copied into
        ``model_loader_extra_config["tensorizer_config"]``. That nested mapping
        must already exist whenever at least one key matches.
        """
        # Imported lazily so this module does not depend on the tensorizer
        # loader unless tensorizer loading is actually requested.
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

        extra_config = self.model_loader_extra_config
        for name, value in extra_config.items():
            if name not in TensorizerConfig._fields:
                continue
            extra_config["tensorizer_config"][name] = value

    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` for these engine arguments.

        Fields on ``self`` are adjusted in place first:

        * ``quantization == "bitsandbytes"`` forces
          ``load_format = "bitsandbytes"``;
        * for the ``"tensorizer"`` load format, ``model_loader_extra_config``
          is replaced by its ``to_serializable()`` result when it provides
          that method. Its ``"tensorizer_config"`` entry is reset to
          ``{"tensorizer_dir": self.model}``, and ``validate_tensorizer_args``
          then copies matching top-level keys into it.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            extra_config = self.model_loader_extra_config
            if hasattr(extra_config, "to_serializable"):
                extra_config = extra_config.to_serializable()
                self.model_loader_extra_config = extra_config
            extra_config["tensorizer_config"] = {"tensorizer_dir": self.model}
            self.validate_tensorizer_args()

        # Online quantization pins the load device to "cpu"; otherwise the
        # device is left unset (None) for LoadConfig to decide.
        load_device = "cpu" if is_online_quantization(self.quantization) else None

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            device=load_device,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )

    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Initialize and return a SpeculativeConfig from `speculative_config`.

        `speculative_config` is either provided on the CLI as a JSON string
        (decoded with `json.loads` by `--speculative-config`) or passed in
        directly as a dictionary from the engine. The target model/parallel
        configs and the two flags below are not part of that user-facing
        dict, so they are merged in here.

        Args:
            target_model_config: ModelConfig of the target (verifier) model.
            target_parallel_config: ParallelConfig of the target model.
            enable_chunked_prefill: Resolved chunked-prefill setting.
            disable_log_stats: Whether stats logging is disabled.

        Returns:
            None if no speculative config was given, otherwise the
            constructed SpeculativeConfig.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # Merge into a fresh dict instead of calling `.update()` on
        # `self.speculative_config`. That dict may be owned by the caller, and
        # mutating it would leak engine-internal config objects into it and
        # make it non-JSON-serializable.
        speculative_kwargs = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        }
        return SpeculativeConfig(**speculative_kwargs)

    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
1213
        headless: bool = False,
1214
1215
1216
1217
1218
1219
1220
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1221

1222
1223
1224
1225
1226
1227
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1228
        current_platform.pre_register_and_update()
1229

1230
        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
1231

1232
1233
1234
1235
        model_config = self.create_model_config()
        self.model = model_config.model
        self.tokenizer = model_config.tokenizer

1236
1237
1238
1239
1240
1241
1242
1243
1244
        (self.model, self.tokenizer, self.speculative_config) = (
            maybe_override_with_speculators(
                model=self.model,
                tokenizer=self.tokenizer,
                revision=self.revision,
                trust_remote_code=self.trust_remote_code,
                vllm_speculative_config=self.speculative_config,
            )
        )
1245

1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

1263
1264
        # Set default arguments for V1 Engine.
        self._set_default_args(usage_context, model_config)
1265
        # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.S390X,
            CpuArchEnum.ARM,
            CpuArchEnum.RISCV,
        ):
            logger.info(
                "Chunked prefill is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
1277
            self.enable_chunked_prefill = False
1278
1279
        assert self.enable_chunked_prefill is not None

1280
1281
1282
1283
1284
1285
1286
        sliding_window: Optional[int] = None
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1287
1288
1289
        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
1290
        # reuses the GPUs of TP group, and split one TP group into
1291
        # tp_size//dcp_size DCP groups.
1292
        assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, (
1293
1294
1295
1296
            f"tp_size={self.tensor_parallel_size} must be divisible by"
            f"dcp_size={self.decode_context_parallel_size}."
        )

1297
        cache_config = CacheConfig(
1298
            block_size=self.block_size,
1299
            gpu_memory_utilization=self.gpu_memory_utilization,
1300
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1301
1302
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1303
            is_attention_free=model_config.is_attention_free,
1304
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1305
            sliding_window=sliding_window,
1306
            enable_prefix_caching=self.enable_prefix_caching,
1307
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1308
            cpu_offload_gb=self.cpu_offload_gb,
1309
            calculate_kv_scales=self.calculate_kv_scales,
1310
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1311
1312
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1313
        )
1314

1315
1316
1317
1318
1319
1320
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
1321

1322
1323
1324
            ray_runtime_env = ray.get_runtime_context().runtime_env
            logger.info("Using ray runtime env: %s", ray_runtime_env)

1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1336
        assert not headless or not self.data_parallel_hybrid_lb, (
1337
1338
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
1339

1340
        data_parallel_external_lb = self.data_parallel_rank is not None
1341
        # Local DP rank = 1, use pure-external LB.
1342
1343
        if data_parallel_external_lb:
            assert self.data_parallel_size_local in (1, None), (
1344
1345
                "data_parallel_size_local must be 1 when data_parallel_rank is set"
            )
1346
            data_parallel_size_local = 1
1347
1348
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1349
1350
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

            self.data_parallel_rank = self.data_parallel_start_rank or 0
1366
        else:
1367
            assert not self.data_parallel_hybrid_lb, (
1368
1369
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )
1370

1371
1372
            # Local DP size defaults to global DP size if not set.
            data_parallel_size_local = self.data_parallel_size
1373
1374
1375

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1376
1377
1378
1379
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
1380
1381
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1382
1383
1384
1385
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
1386
1387
                    self.data_parallel_backend,
                )
Rui Qiao's avatar
Rui Qiao committed
1388
1389
1390
                data_parallel_address = ParallelConfig.data_parallel_master_ip
        else:
            data_parallel_address = self.data_parallel_address
1391
1392
1393

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
1394
        data_parallel_rpc_port = (
1395
            self.data_parallel_rpc_port
1396
1397
1398
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )
1399

1400
1401
1402
1403
        if self.async_scheduling:
            # Async scheduling does not work with the uniprocess backend.
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "mp"
1404
1405
1406
1407
                logger.info(
                    "Defaulting to mp-based distributed executor "
                    "backend for async scheduling."
                )
1408
            if self.pipeline_parallel_size > 1:
1409
1410
1411
                raise ValueError(
                    "Async scheduling is not supported with pipeline-parallel-size > 1."
                )
1412
1413
1414
1415
1416
1417

            # Currently, async scheduling does not support speculative decoding.
            # TODO(woosuk): Support it.
            if self.speculative_config is not None:
                raise ValueError(
                    "Currently, speculative decoding is not supported with "
1418
1419
                    "async scheduling."
                )
1420

1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
        # Forward the deprecated CLI args to the EPLB config.
        if self.num_redundant_experts is not None:
            self.eplb_config.num_redundant_experts = self.num_redundant_experts
        if self.eplb_window_size is not None:
            self.eplb_config.window_size = self.eplb_window_size
        if self.eplb_step_interval is not None:
            self.eplb_config.step_interval = self.eplb_step_interval
        if self.eplb_log_balancedness is not None:
            self.eplb_config.log_balancedness = self.eplb_log_balancedness

1431
        parallel_config = ParallelConfig(
1432
1433
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1434
            data_parallel_size=self.data_parallel_size,
1435
1436
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1437
1438
1439
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1440
            data_parallel_backend=self.data_parallel_backend,
1441
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1442
            enable_expert_parallel=self.enable_expert_parallel,
1443
1444
            enable_dbo=self.enable_dbo,
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
1445
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
1446
            enable_eplb=self.enable_eplb,
1447
            eplb_config=self.eplb_config,
1448
            expert_placement_strategy=self.expert_placement_strategy,
1449
1450
1451
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1452
            ray_runtime_env=ray_runtime_env,
1453
            placement_group=placement_group,
1454
1455
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1456
            worker_extension_cls=self.worker_extension_cls,
1457
            decode_context_parallel_size=self.decode_context_parallel_size,
1458
1459
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
1460
        )
1461

1462
        speculative_config = self.create_speculative_config(
1463
1464
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1465
            enable_chunked_prefill=self.enable_chunked_prefill,
1466
            disable_log_stats=self.disable_log_stats,
1467
1468
        )

1469
1470
1471
1472
1473
        # make sure num_lookahead_slots is set appropriately depending on
        # whether speculative decoding is enabled
        num_lookahead_slots = self.num_lookahead_slots
        if speculative_config is not None:
            num_lookahead_slots = speculative_config.num_lookahead_slots
1474

1475
        scheduler_config = SchedulerConfig(
1476
            runner_type=model_config.runner_type,
1477
1478
1479
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1480
            cuda_graph_sizes=self.cuda_graph_sizes,
1481
            num_lookahead_slots=num_lookahead_slots,
1482
            enable_chunked_prefill=self.enable_chunked_prefill,
1483
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1484
            is_multimodal_model=model_config.is_multimodal_model,
1485
            is_encoder_decoder=model_config.is_encoder_decoder,
1486
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray),
1487
            policy=self.scheduling_policy,
1488
            scheduler_cls=self.scheduler_cls,
1489
1490
1491
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1492
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
1493
            async_scheduling=self.async_scheduling,
1494
        )
1495

1496
1497
1498
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                bias_enabled=self.enable_lora_bias,
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )
1518

1519
1520
1521
1522
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1523
        load_config = self.create_load_config()
1524

1525
1526
        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
1527
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser
1528
1529
1530
1531

        # Forward the deprecated CLI args to the StructuredOutputsConfig
        so_config = self.structured_outputs_config
        if self.guided_decoding_backend is not None:
1532
            so_config.guided_decoding_backend = self.guided_decoding_backend
1533
        if self.guided_decoding_disable_fallback is not None:
1534
1535
1536
            so_config.guided_decoding_disable_fallback = (
                self.guided_decoding_disable_fallback
            )
1537
        if self.guided_decoding_disable_any_whitespace is not None:
1538
1539
1540
            so_config.guided_decoding_disable_any_whitespace = (
                self.guided_decoding_disable_any_whitespace
            )
1541
        if self.guided_decoding_disable_additional_properties is not None:
1542
1543
1544
            so_config.guided_decoding_disable_additional_properties = (
                self.guided_decoding_disable_additional_properties
            )
1545

1546
        observability_config = ObservabilityConfig(
1547
            show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version),
1548
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1549
            collect_detailed_traces=self.collect_detailed_traces,
1550
        )
1551

1552
        config = VllmConfig(
1553
1554
1555
1556
1557
1558
1559
1560
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
1561
            structured_outputs_config=self.structured_outputs_config,
1562
            observability_config=observability_config,
1563
            compilation_config=self.compilation_config,
1564
            kv_transfer_config=self.kv_transfer_config,
1565
            kv_events_config=self.kv_events_config,
1566
            additional_config=self.additional_config,
1567
        )
1568

1569
1570
        return config

1571
1572
1573
1574
1575
1576
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Decide whether the V1 engine can serve this configuration.

        The checks run in order. When one trips, ``_raise_or_fallback`` is
        called and ``False`` is returned. That helper raises if
        ``VLLM_USE_V1`` was explicitly enabled and otherwise logs a warning.
        A ``draft_model`` speculative-decoding method is rejected outright
        with ``NotImplementedError``.

        Args:
            model_config: The resolved model configuration.

        Returns:
            ``True`` when every check passes, ``False`` otherwise.
        """
        # ---- Feature flags that V1 does not support --------------------------

        if self.logits_processor_pattern != EngineArgs.logits_processor_pattern:
            _raise_or_fallback(
                feature_name="--logits-processor-pattern", recommend_to_remove=False
            )
            return False

        # Mamba and encoder-decoder architectures are not V1-compatible yet.
        if not model_config.is_v1_compatible:
            _raise_or_fallback(
                feature_name=model_config.architectures, recommend_to_remove=False
            )
            return False

        # Concurrent partial prefills are only accepted at their defaults.
        partial_prefills_customized = (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        )
        if partial_prefills_customized:
            _raise_or_fallback(
                feature_name="Concurrent Partial Prefill", recommend_to_remove=False
            )
            return False

        # V1 supports N-gram, Medusa, and Eagle speculative decoding, but not
        # a separate draft model.
        spec = self.speculative_config
        if spec is not None:
            # The speculative config may still be a raw dict at this point.
            method = (
                spec.get("method", None) if isinstance(spec, dict) else spec.method
            )
            if method == "draft_model":
                raise NotImplementedError(
                    "Draft model speculative decoding is not supported yet. "
                    "Please consider using other speculative decoding methods "
                    "such as ngram, medusa, eagle, or mtp."
                )

        supported_backends = (
            "FLASH_ATTN",
            "PALLAS",
            "TRITON_ATTN",
            "TRITON_MLA",
            "CUTLASS_MLA",
            "FLASHMLA",
            "FLASH_ATTN_MLA",
            "FLASHINFER",
            "FLASHINFER_MLA",
            "ROCM_AITER_MLA",
            "TORCH_SDPA",
            "FLEX_ATTENTION",
            "TREE_ATTN",
            "XFORMERS",
            "ROCM_ATTN",
        )
        # Only an explicitly set backend is validated against the list.
        backend_override_unsupported = (
            envs.is_set("VLLM_ATTENTION_BACKEND")
            and envs.VLLM_ATTENTION_BACKEND not in supported_backends
        )
        if backend_override_unsupported:
            _raise_or_fallback(
                feature_name=f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}",
                recommend_to_remove=True,
            )
            return False

        # ---- Experimental features that users can opt into ------------------

        if self.pipeline_parallel_size > 1:
            executor_backend = self.distributed_executor_backend
            # A custom executor class may advertise PP support itself.
            backend_supports_pp = getattr(executor_backend, "supports_pp", False)
            pp_capable_backends = (
                ParallelConfig.distributed_executor_backend,
                "ray",
                "mp",
                "external_launcher",
            )
            if not backend_supports_pp and executor_backend not in pp_capable_backends:
                _raise_or_fallback(
                    feature_name=(
                        "Pipeline Parallelism without Ray distributed "
                        "executor or multiprocessing executor or external "
                        "launcher"
                    ),
                    recommend_to_remove=False,
                )
                return False

        cpu_sliding_window = (
            current_platform.is_cpu() and model_config.get_sliding_window() is not None
        )
        if cpu_sliding_window:
            _raise_or_fallback(
                feature_name="sliding window (CPU backend)", recommend_to_remove=False
            )
            return False

        return True

1672
1673
1674
    def _set_default_args(
        self, usage_context: UsageContext, model_config: ModelConfig
    ) -> None:
        """Fill in V1-engine defaults for arguments the user left unset.

        Mutates ``self`` in place:

        * ``enable_chunked_prefill`` / ``enable_prefix_caching``:
          - Non-pooling runners always get chunked prefill.
          - For non-pooling runners, prefix caching defaults to on, except
            for hybrid models. It is forced off when prompt embeddings are
            enabled.
          - For pooling runners, both default to whether the pooler uses
            ``LAST`` pooling on a causal model.
        * ``scheduler_cls``: switched to the V1 scheduler when it still has
          the ``EngineArgs`` default value.
        * ``max_num_batched_tokens`` / ``max_num_seqs``: picked from
          per-usage-context tables that depend on the platform:
          - large-memory non-A100 GPU
          - other devices
          - TPU chip name
          - CPU world size

        Args:
            usage_context: How the engine is being used. It selects the
                table entry. Contexts that are absent from the tables leave
                the corresponding value untouched.
            model_config: The resolved model configuration.
        """
        # ---- Chunked prefill and prefix caching -----------------------------
        if model_config.runner_type != "pooling":
            # Non-pooling runners on V1 always use chunked prefill.
            self.enable_chunked_prefill = True

            # TODO: When prefix caching supports prompt embeds inputs, this
            # check can be removed.
            if self.enable_prompt_embeds and self.enable_prefix_caching is not False:
                logger.warning(
                    "--enable-prompt-embeds and --enable-prefix-caching "
                    "are not supported together in V1. Prefix caching has "
                    "been disabled."
                )
                self.enable_prefix_caching = False

            if self.enable_prefix_caching is None:
                # Prefix caching is still experimental for hybrid models, so
                # it is not enabled by default for them.
                if model_config.is_hybrid:
                    self.enable_prefix_caching = False
                else:
                    self.enable_prefix_caching = True
        else:
            # Pooling runners only default these features on when the
            # pooler uses LAST pooling and the model is causal. Otherwise
            # they default off. Explicit user settings are respected.
            pooling_type = model_config.pooler_config.pooling_type
            is_causal = getattr(model_config.hf_config, "is_causal", True)
            incremental_prefill_supported = (
                pooling_type is not None
                and pooling_type.lower() == "last"
                and is_causal
            )

            action = "Enabling" if incremental_prefill_supported else "Disabling"

            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = incremental_prefill_supported
                logger.info("(%s) chunked prefill by default", action)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = incremental_prefill_supported
                logger.info("(%s) prefix caching by default", action)

        # ---- Scheduler class ------------------------------------------------
        # V1 should use the new scheduler by default. Swap it only if this
        # arg is still set to the original V0 default.
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # ---- Batch-size defaults (per usage context and hardware) ----------
        # Try to query the device on the current platform. This can fail when
        # the platform that imports vLLM differs from the one vLLM runs on
        # (e.g. scaling vLLM with Ray from a GPU-less driver). In that case
        # the defaults for non-H100/H200 GPUs are used.
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # Both values are only used to pick the default tables below.
            # device_name is bound here too, so the A100 check never relies
            # on short-circuit evaluation to avoid an unbound local.
            device_memory = 0
            device_name = ""

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        from vllm.usage.usage_lib import UsageContext

        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
            # For GPUs like H100 and MI300x, use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 1024,
                UsageContext.OPENAI_API_SERVER: 1024,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256,
                UsageContext.OPENAI_API_SERVER: 256,
            }

        # TPU-specific default values, keyed by chip name.
        if current_platform.is_tpu():
            default_max_num_batched_tokens_tpu = {
                UsageContext.LLM_CLASS: {
                    "V6E": 2048,
                    "V5E": 1024,
                    "V5P": 512,
                },
                UsageContext.OPENAI_API_SERVER: {
                    "V6E": 1024,
                    "V5E": 512,
                    "V5P": 256,
                },
            }

        # CPU-specific default values scale with the number of workers.
        if current_platform.is_cpu():
            world_size = self.pipeline_parallel_size * self.tensor_parallel_size
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 4096 * world_size,
                UsageContext.OPENAI_API_SERVER: 2048 * world_size,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256 * world_size,
                UsageContext.OPENAI_API_SERVER: 128 * world_size,
            }

        use_context_value = usage_context.value if usage_context else None
        if (
            self.max_num_batched_tokens is None
            and usage_context in default_max_num_batched_tokens
        ):
            if current_platform.is_tpu():
                chip_name = current_platform.get_device_name()
                if chip_name in default_max_num_batched_tokens_tpu[usage_context]:
                    self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[
                        usage_context
                    ][chip_name]
                else:
                    self.max_num_batched_tokens = default_max_num_batched_tokens[
                        usage_context
                    ]
            else:
                # Without chunked prefill a whole prompt must fit in one
                # batch, so the budget is the model's max length.
                if not self.enable_chunked_prefill:
                    self.max_num_batched_tokens = model_config.max_model_len
                else:
                    self.max_num_batched_tokens = default_max_num_batched_tokens[
                        usage_context
                    ]
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                use_context_value,
            )

        if self.max_num_seqs is None and usage_context in default_max_num_seqs:
            # Never allow more sequences than batched tokens.
            self.max_num_seqs = min(
                default_max_num_seqs[usage_context],
                self.max_num_batched_tokens or sys.maxsize,
            )

            logger.debug(
                "Setting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                use_context_value,
            )
1831

1832

1833
_DISABLE_LOG_REQUESTS_DEPRECATION = (
    "`disable_log_requests` is deprecated and has been replaced with "
    "`enable_log_requests`. This will be removed in v0.12.0. Please use "
    "`enable_log_requests` instead."
)


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends ``EngineArgs`` with request-logging control. The legacy
    ``disable_log_requests`` attribute is kept as a deprecated property that
    always mirrors the inverse of ``enable_log_requests``.
    """

    # Whether incoming requests are logged (CLI: --enable-log-requests).
    enable_log_requests: bool = False

    @property
    @deprecated(_DISABLE_LOG_REQUESTS_DEPRECATION)
    def disable_log_requests(self) -> bool:
        return not self.enable_log_requests

    @disable_log_requests.setter
    @deprecated(_DISABLE_LOG_REQUESTS_DEPRECATION)
    def disable_log_requests(self, value: bool):
        self.enable_log_requests = not value

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
        """Register the async-engine CLI arguments on ``parser``.

        Args:
            parser: The argument parser to extend.
            async_args_only: When ``True``, skip the base ``EngineArgs``
                arguments and register only the async-specific ones.

        Returns:
            The parser with the arguments registered.
        """
        # Plugins go first: they may extend the parser, e.g. by adding a new
        # quantization method to --quantization or a new device to --device.
        load_general_plugins()

        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        logging_on_by_default = AsyncEngineArgs.enable_log_requests
        parser.add_argument(
            "--enable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=logging_on_by_default,
            help="Enable logging requests.",
        )
        # Kept for backward compatibility; its default is the inverse of
        # --enable-log-requests.
        parser.add_argument(
            "--disable-log-requests",
            action=argparse.BooleanOptionalAction,
            default=not logging_on_by_default,
            help="[DEPRECATED] Disable logging requests.",
            deprecated=True,
        )

        current_platform.pre_register_and_update(parser)
        return parser
1882
1883


1884
1885
1886
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Reject a feature the V1 engine lacks, or warn and fall back to V0.

    When ``VLLM_USE_V1`` is explicitly set to a truthy value, the user has
    insisted on V1, so the unsupported feature is a hard error. Otherwise a
    warning says the engine falls back to V0. It can optionally recommend
    removing the feature.

    Args:
        feature_name: Name of the unsupported feature, interpolated into the
            messages.
        recommend_to_remove: Whether the warning should suggest removing the
            feature in favor of the V1 engine.

    Raises:
        NotImplementedError: If ``VLLM_USE_V1`` is explicitly enabled.
    """
    v1_forced = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_forced:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}."
        )

    pieces = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        pieces.extend(
            (
                f"We recommend to remove {feature_name} from your config ",
                "in favor of the V1 Engine.",
            )
        )
    logger.warning("".join(pieces))


1897
1898
1899
def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.

    Lower-case suffixes are decimal (SI) multipliers and accept a decimal
    mantissa. Upper-case suffixes are binary (IEC) multipliers and require
    an integer mantissa. A value without a suffix is parsed with ``int()``.

    Supported suffixes:
    - ``k`` / ``m`` / ``g`` / ``t``: 10**3, 10**6, 10**9, 10**12
    - ``K`` / ``M`` / ``G`` / ``T``: 2**10, 2**20, 2**30, 2**40

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '1T' -> 1,099,511,627,776

    With a decimal suffix, a fractional result is truncated toward zero
    ('1.0005k' -> 1,000).

    Args:
        value: The string to parse. Surrounding whitespace is ignored.

    Returns:
        The parsed integer.

    Raises:
        argparse.ArgumentTypeError: If a decimal mantissa is combined with a
            binary suffix (e.g. '1.5K').
        ValueError: If ``value`` is neither a plain integer nor a number with
            a supported suffix.
    """
    # Fraction gives exact arithmetic on the decimal mantissa. With floats,
    # float("1.001") * 1000 == 1000.9999999999999, which truncates to 1000.
    from fractions import Fraction

    value = value.strip()
    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
    if match:
        # Every suffix accepted by the regex must appear in one of these
        # tables. Otherwise it would fall through to int() below.
        decimal_multiplier = {
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
            "t": 10**12,
        }
        binary_multiplier = {
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
            "T": 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            return int(Fraction(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?"
                ) from e

    # Regular plain number.
    return int(value)