arg_utils.py 79 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import argparse
5
import copy
6
import dataclasses
7
import functools
8
import json
9
import sys
import types
10
from dataclasses import MISSING, dataclass, fields, is_dataclass
11
from itertools import permutations
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
28

29
import huggingface_hub
30
import regex as re
31
import torch
32
from pydantic import TypeAdapter, ValidationError
33
from typing_extensions import TypeIs, deprecated
34

35
import vllm.envs as envs
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from vllm.config import (
    BlockSize,
    CacheConfig,
    CacheDType,
    CompilationConfig,
    ConfigType,
    ConvertOption,
    DetailedTraceModules,
    Device,
    DeviceConfig,
    DistributedExecutorBackend,
    EPLBConfig,
    HfOverrides,
    KVEventsConfig,
    KVTransferConfig,
    LoadConfig,
    LogprobsMode,
    LoRAConfig,
    MambaDType,
    MMEncoderTPMode,
    ModelConfig,
    ModelDType,
    ObservabilityConfig,
    ParallelConfig,
    PoolerConfig,
    PrefixCachingHashAlgo,
    RunnerOption,
    SchedulerConfig,
    SchedulerPolicy,
    SpeculativeConfig,
    StructuredOutputsConfig,
    TaskOption,
    TokenizerMode,
    VllmConfig,
    get_attr_docs,
)
72
from vllm.config.multimodal import MMCacheType, MultiModalConfig
73
from vllm.config.parallel import ExpertPlacementStrategy
74
from vllm.config.utils import get_field
75
from vllm.logger import init_logger
76
from vllm.platforms import CpuArchEnum, current_platform
77
from vllm.plugins import load_general_plugins
78
from vllm.ray.lazy_utils import is_ray_initialized
79
from vllm.reasoning import ReasoningParserManager
80
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
81
82
83
84
85
from vllm.transformers_utils.config import (
    get_model_path,
    is_interleaved,
    maybe_override_with_speculators,
)
86
from vllm.transformers_utils.utils import check_gguf_file
87
from vllm.utils import FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor
88
from vllm.v1.sample.logits_processor import LogitsProcessor
89

90
91
92
if TYPE_CHECKING:
    from vllm.executor.executor_base import ExecutorBase
    from vllm.model_executor.layers.quantization import QuantizationMethods
93
    from vllm.model_executor.model_loader import LoadFormats
94
95
96
97
    from vllm.usage.usage_lib import UsageContext
else:
    ExecutorBase = Any
    QuantizationMethods = Any
98
    LoadFormats = Any
99
100
    UsageContext = Any

101
102
logger = init_logger(__name__)

103
104
105
106
107
# object is used to allow for special typing forms
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]

108

109
110
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap a converter so failures surface as argparse errors.

    The returned callable applies ``return_type`` to its string argument. A
    ``ValueError`` raised by ``return_type`` is re-raised as
    ``argparse.ArgumentTypeError`` (chained to the original), which argparse
    reports as a usage error for the offending option.
    """

    def _parse_type(val: str) -> T:
        try:
            converted = return_type(val)
        except ValueError as exc:
            message = f"Value {val} cannot be converted to {return_type}."
            raise argparse.ArgumentTypeError(message) from exc
        return converted

    return _parse_type


121
def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Like ``parse_type``, but map the strings ``""`` and ``"None"`` to ``None``.

    Any other string is converted with ``parse_type(return_type)``, so
    conversion failures raise ``argparse.ArgumentTypeError``.
    """

    def _optional_type(val: str) -> Optional[T]:
        if val not in ("", "None"):
            return parse_type(return_type)(val)
        return None

    return _optional_type
128
129


130
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
    """Parse a CLI value that may be either a JSON object or a plain string.

    A value shaped like ``{...}`` (surrounding whitespace allowed, newlines
    allowed inside) is decoded with ``json.loads`` via ``optional_type``, so
    malformed JSON raises ``argparse.ArgumentTypeError``. Any other value is
    returned as a string.
    """
    looks_like_json_object = re.match(r"(?s)^\s*{.*}\s*$", val)
    if looks_like_json_object:
        return optional_type(json.loads)(val)
    return str(val)
134
135


136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return True if ``type_hint`` is ``type`` itself or a parametrisation of it.

    For example, both ``list`` and ``list[int]`` match ``list``, because
    ``get_origin(list[int])`` is ``list``.
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return True if any hint in ``type_hints`` matches ``type`` (see ``is_type``)."""
    for candidate in type_hints:
        if is_type(candidate, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return the first hint in ``type_hints`` matching ``type``, or ``None``.

    Matching follows ``is_type``. If several hints match, which one is
    returned depends on set iteration order.
    """
    for candidate in type_hints:
        if is_type(candidate, type):
            return candidate
    return None


151
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Derive argparse ``type`` and ``choices``/``metavar`` from a ``Literal`` hint.

    The ``Literal`` found in ``type_hints`` supplies the allowed values, and
    their common Python type becomes ``type``. If ``type_hints`` also contains
    ``str``, the sorted values go into ``metavar`` (display only) instead of
    ``choices``, so other strings are not rejected.

    Raises:
        ValueError: if the ``Literal`` values are not all of the same type.
    """
    options = get_args(get_type(type_hints, Literal))
    first_type = type(options[0])
    if any(not isinstance(option, first_type) for option in options):
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}"
        )
    key = "choices"
    if contains_type(type_hints, str):
        key = "metavar"
    return {"type": first_type, key: sorted(options)}
166
167


168
169
170
171
172
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True when ``type_hint`` is defined outside the ``builtins`` module."""
    defining_module = type_hint.__module__
    return defining_module != "builtins"


173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints."""
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
    elif origin is Union:
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


190
191
192
193
def is_online_quantization(quantization: Any) -> bool:
    """Return True if ``quantization`` is one of the online quantization methods.

    Currently the only such method is ``"inc"``.
    """
    online_methods = ("inc",)
    return quantization in online_methods


194
# True when help text will actually be rendered, so attribute docs are worth
# extracting:
#   * `vllm SUBCOMMAND --help`  -> some argument contains "--help"
#   * `mkdocs SUBCOMMAND`       -> argv[0] ends with "mkdocs"
#   * `python -m mkdocs ...`    -> argv[0] ends with "mkdocs/__main__.py"
NEEDS_HELP = any("--help" in arg for arg in sys.argv) or (
    argv0 := sys.argv[0]
).endswith(("mkdocs", "mkdocs/__main__.py"))


201
202
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Build ``add_argument`` kwargs for every field of the config dataclass ``cls``.

    Returns a dict mapping each field name to a kwargs dict. Every kwargs dict
    holds ``default`` and ``help``. Depending on the field's type annotation it
    may also hold ``type``, ``action``, ``choices``/``metavar`` and ``nargs``.

    The result is memoised by ``functools.lru_cache`` and must not be mutated;
    ``get_kwargs`` returns a deep copy for callers.

    Raises:
        ValueError: if a field's type cannot be mapped onto an argparse type.
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}

    # Loop-invariant constants.
    json_tip = (
        "Should either be a valid JSON string or JSON keys passed individually."
    )
    # Integer fields parsed with `human_readable_int` instead of plain `int`.
    human_readable_ints = {
        "max_model_len",
        "max_num_batched_tokens",
        "kv_cache_memory_bytes",
    }
    # `typing.Union`/`Optional` and PEP 604 `X | Y` unions report different
    # origins; `types.UnionType` only exists on Python 3.10+.
    union_origins = (Union, getattr(types, "UnionType", Union))

    kwargs: dict[str, Any] = {}
    for field in fields(cls):
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # Dataclass-typed fields are parsed from JSON with pydantic's TypeAdapter
        dataclass_cls = next((th for th in type_hints if is_dataclass(th)), None)

        # Get the default value of the field. Fields with no default of any
        # kind fall back to None; without this branch `default` would be
        # unbound, or would still hold the previous field's value.
        if field.default is not MISSING:
            default = field.default
        elif field.default_factory is not MISSING:
            default = field.default_factory()
        else:
            default = None

        # Get the help text for the field, escaping % because argparse
        # %-formats help strings
        name = field.name
        help_text = cls_docs.get(name, "").strip().replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        kwargs[name] = {"default": default, "help": help_text}

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:

            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            kwargs[name]["type"] = parse_dataclass
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            kwargs[name]["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            kwargs[name].update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            # NOTE: named `tuple_args`, not `types`, so the `types` module
            # used above is not shadowed by a function-local name.
            tuple_args = get_args(get_type(type_hints, tuple))
            tuple_type = tuple_args[0]
            assert all(t is tuple_type for t in tuple_args if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {tuple_args}."
            )
            kwargs[name]["type"] = tuple_type
            kwargs[name]["nargs"] = "+" if Ellipsis in tuple_args else len(tuple_args)
        elif contains_type(type_hints, list):
            list_type = get_args(get_type(type_hints, list))[0]
            # A union element type cannot be used as an argparse `type`, so
            # fall back to str (which the union must then include).
            if get_origin(list_type) in union_origins:
                msg = "List type must contain str if it is a Union."
                assert str in get_args(list_type), msg
                list_type = str
            kwargs[name]["type"] = list_type
            kwargs[name]["nargs"] = "+"
        elif contains_type(type_hints, int):
            kwargs[name]["type"] = int
            # Special case for large integers
            if name in human_readable_ints:
                kwargs[name]["type"] = human_readable_int
                kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
        elif contains_type(type_hints, dict) and (
            contains_type(type_hints, str)
            or any(is_not_builtin(th) for th in type_hints)
        ):
            kwargs[name]["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            kwargs[name]["type"] = parse_type(json.loads)
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, str) or any(
            is_not_builtin(th) for th in type_hints
        ):
            kwargs[name]["type"] = str
        else:
            raise ValueError(f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(kwargs[name].get("type")) is Literal:
            kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            kwargs[name]["type"] = optional_type(kwargs[name]["type"])
            if kwargs[name].get("choices"):
                # NOTE(review): argparse checks `choices` against the value
                # *after* `type` conversion, and `optional_type` turns "None"
                # into None. The string "None" appended here may therefore not
                # let `--flag None` pass validation; confirm against
                # FlexibleArgumentParser.
                kwargs[name]["choices"].append("None")
    return kwargs
308
309


310
311
312
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Return argparse kwargs for the given Config dataclass.

    Attribute documentation is included in the help output only when the
    command line contains `--help` or the process is `mkdocs`. Otherwise the
    help strings are empty.

    The expensive work happens in the `functools.lru_cache`-memoised
    `_compute_kwargs`. Each call returns a deep copy, so callers may mutate
    the result freely without corrupting the cached version.
    """
    cached_kwargs = _compute_kwargs(cls)
    # The cached dict is shared by every caller; hand out a private copy.
    return copy.deepcopy(cached_kwargs)


323
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
324
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
325
    """Arguments for vLLM engine."""
326

327
    model: str = ModelConfig.model
328
    served_model_name: Optional[Union[str, List[str]]] = ModelConfig.served_model_name
329
330
    tokenizer: Optional[str] = ModelConfig.tokenizer
    hf_config_path: Optional[str] = ModelConfig.hf_config_path
331
332
333
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
    task: Optional[TaskOption] = ModelConfig.task
334
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
335
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
336
337
338
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
339
    allowed_media_domains: Optional[list[str]] = ModelConfig.allowed_media_domains
340
    download_dir: Optional[str] = LoadConfig.download_dir
341
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
342
    load_format: Union[str, LoadFormats] = LoadConfig.load_format
343
344
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
345
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
346
347
    seed: Optional[int] = ModelConfig.seed
    max_model_len: Optional[int] = ModelConfig.max_model_len
348
    cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes")
349
350
351
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
352
353
354
    distributed_executor_backend: Optional[
        Union[str, DistributedExecutorBackend, Type[ExecutorBase]]
    ] = ParallelConfig.distributed_executor_backend
355
    # number of P/D disaggregation (or other disaggregation) workers
356
357
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
358
    decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
359
    data_parallel_size: int = ParallelConfig.data_parallel_size
360
    data_parallel_rank: Optional[int] = None
361
    data_parallel_start_rank: Optional[int] = None
362
363
364
    data_parallel_size_local: Optional[int] = None
    data_parallel_address: Optional[str] = None
    data_parallel_rpc_port: Optional[int] = None
365
    data_parallel_hybrid_lb: bool = False
Rui Qiao's avatar
Rui Qiao committed
366
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
367
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
368
    enable_dbo: bool = ParallelConfig.enable_dbo
369
370
    dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
    dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
371
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
372
    enable_eplb: bool = ParallelConfig.enable_eplb
373
    expert_placement_strategy: ExpertPlacementStrategy = (
374
        ParallelConfig.expert_placement_strategy
375
    )
376
377
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
378
379
380
381
    num_redundant_experts: int = EPLBConfig.num_redundant_experts
    eplb_window_size: int = EPLBConfig.window_size
    eplb_step_interval: int = EPLBConfig.step_interval
    eplb_log_balancedness: bool = EPLBConfig.log_balancedness
382
383
384
    max_parallel_loading_workers: Optional[int] = (
        ParallelConfig.max_parallel_loading_workers
    )
385
386
    block_size: Optional[BlockSize] = CacheConfig.block_size
    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
387
    prefix_caching_hash_algo: PrefixCachingHashAlgo = (
388
        CacheConfig.prefix_caching_hash_algo
389
    )
390
391
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
392
393
394
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
395
    kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes
396
    max_num_batched_tokens: Optional[int] = SchedulerConfig.max_num_batched_tokens
397
398
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
399
    long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
400
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
401
    max_logprobs: int = ModelConfig.max_logprobs
402
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
403
    disable_log_stats: bool = False
404
405
406
407
408
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
    rope_theta: Optional[float] = ModelConfig.rope_theta
    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
409
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
410
411
412
    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
    enforce_eager: bool = ModelConfig.enforce_eager
413
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
414
415
416
    limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = get_field(
        MultiModalConfig, "limit_per_prompt"
    )
417
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
418
419
420
421
    media_io_kwargs: dict[str, dict[str, Any]] = get_field(
        MultiModalConfig, "media_io_kwargs"
    )
    mm_processor_kwargs: Optional[Dict[str, Any]] = MultiModalConfig.mm_processor_kwargs
422
    disable_mm_preprocessor_cache: bool = False  # DEPRECATED
423
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
424
    mm_processor_cache_type: Optional[MMCacheType] = (
425
        MultiModalConfig.mm_processor_cache_type
426
427
    )
    mm_shm_cache_max_object_size_mb: int = (
428
        MultiModalConfig.mm_shm_cache_max_object_size_mb
429
    )
430
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
431
    io_processor_plugin: Optional[str] = None
432
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
433
    video_pruning_rate: float = MultiModalConfig.video_pruning_rate
434
    # LoRA fields
435
    enable_lora: bool = False
436
437
438
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
439
    default_mm_loras: Optional[Dict[str, str]] = LoRAConfig.default_mm_loras
440
441
442
443
444
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size

445
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
446
    num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override
447
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
448
449
    model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
    ignore_patterns: Optional[Union[str, List[str]]] = LoadConfig.ignore_patterns
450

451
    enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill
452
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
453

454
    disable_hybrid_kv_cache_manager: bool = (
455
456
        SchedulerConfig.disable_hybrid_kv_cache_manager
    )
457

458
    structured_outputs_config: StructuredOutputsConfig = get_field(
459
460
        VllmConfig, "structured_outputs_config"
    )
461
462
463
464
465
466
467
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
    # Deprecated guided decoding fields
    guided_decoding_backend: Optional[str] = None
    guided_decoding_disable_fallback: Optional[bool] = None
    guided_decoding_disable_any_whitespace: Optional[bool] = None
    guided_decoding_disable_additional_properties: Optional[bool] = None

468
    logits_processor_pattern: Optional[str] = ModelConfig.logits_processor_pattern
469

470
    speculative_config: Optional[Dict[str, Any]] = None
471

472
    show_hidden_metrics_for_version: Optional[str] = (
473
        ObservabilityConfig.show_hidden_metrics_for_version
474
475
476
    )
    otlp_traces_endpoint: Optional[str] = ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: Optional[list[DetailedTraceModules]] = (
477
        ObservabilityConfig.collect_detailed_traces
478
    )
479
480
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
481

482
    pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config
483
    override_pooler_config: Optional[Union[dict, PoolerConfig]] = (
484
        ModelConfig.override_pooler_config
485
486
    )
    compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
487
488
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
489

490
    kv_transfer_config: Optional[KVTransferConfig] = None
491
    kv_events_config: Optional[KVEventsConfig] = None
492

493
494
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
495
496
497
    override_generation_config: dict[str, Any] = get_field(
        ModelConfig, "override_generation_config"
    )
498
    model_impl: str = ModelConfig.model_impl
499
    override_attention_dtype: str = ModelConfig.override_attention_dtype
500

501
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
502
503
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
504

505
    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
506

507
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
508
    pt_load_map_location: str = LoadConfig.pt_load_map_location
509

510
511
    # DEPRECATED
    enable_multimodal_encoder_data_parallel: bool = False
512

513
514
515
    logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = (
        ModelConfig.logits_processors
    )
516
517
    """Custom logitproc types"""

518
519
    async_scheduling: bool = SchedulerConfig.async_scheduling

520
    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
521

522
    def __post_init__(self):
        """Normalise dict-valued config fields, load plugins, apply HF offline mode."""
        # Accept plain dicts for nested configs so callers can write e.g.
        # `EngineArgs(compilation_config={...})` without building a
        # CompilationConfig object themselves.
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(**self.compilation_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)

        # Load the general plugins.
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        # Outside HF offline mode the model id is used unchanged.
        if not huggingface_hub.constants.HF_HUB_OFFLINE:
            return

        # In HF offline mode, replace the model id with the path returned by
        # `get_model_path` for this model and revision.
        original_model_id = self.model
        self.model = get_model_path(self.model, self.revision)
        logger.info(
            "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]",
            original_model_id,
            self.model,
        )
543
544

    @staticmethod
545
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
546
        """Shared CLI arguments for vLLM engine."""
547

548
        # Model arguments
549
550
551
552
553
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
554
        if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]):
555
            model_group.add_argument("--model", **model_kwargs["model"])
556
557
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
558
        model_group.add_argument("--task", **model_kwargs["task"], deprecated=True)
559
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
560
561
562
563
        model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
        model_group.add_argument(
            "--trust-remote-code", **model_kwargs["trust_remote_code"]
        )
564
565
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
566
567
568
569
570
571
572
        model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"])
        model_group.add_argument(
            "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"]
        )
        model_group.add_argument(
            "--allowed-media-domains", **model_kwargs["allowed_media_domains"]
        )
573
        model_group.add_argument("--revision", **model_kwargs["revision"])
574
575
        model_group.add_argument("--code-revision", **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling", **model_kwargs["rope_scaling"])
576
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
        model_group.add_argument(
            "--tokenizer-revision", **model_kwargs["tokenizer_revision"]
        )
        model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
        model_group.add_argument(
            "--disable-sliding-window", **model_kwargs["disable_sliding_window"]
        )
        model_group.add_argument(
            "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"]
        )
        model_group.add_argument(
            "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"]
        )
        model_group.add_argument(
            "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"]
        )
        model_group.add_argument(
            "--served-model-name", **model_kwargs["served_model_name"]
        )
        model_group.add_argument("--config-format", **model_kwargs["config_format"])
601
602
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
        model_group.add_argument(
            "--hf-token",
            type=str,
            nargs="?",
            const=True,
            default=model_kwargs["hf_token"]["default"],
            help=model_kwargs["hf_token"]["help"],
        )
        model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
        model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
        model_group.add_argument(
            "--override-pooler-config",
            **model_kwargs["override_pooler_config"],
            deprecated=True,
        )
        model_group.add_argument(
            "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
        )
        model_group.add_argument(
            "--generation-config", **model_kwargs["generation_config"]
        )
        model_group.add_argument(
            "--override-generation-config", **model_kwargs["override_generation_config"]
        )
        model_group.add_argument(
            "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"]
        )
630
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
631
632
633
634
635
636
637
638
639
        model_group.add_argument(
            "--override-attention-dtype", **model_kwargs["override_attention_dtype"]
        )
        model_group.add_argument(
            "--logits-processors", **model_kwargs["logits_processors"]
        )
        model_group.add_argument(
            "--io-processor-plugin", **model_kwargs["io_processor_plugin"]
        )
640

641
642
643
644
645
646
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
647
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
648
649
650
651
652
653
654
655
656
657
658
659
        load_group.add_argument("--download-dir", **load_kwargs["download_dir"])
        load_group.add_argument(
            "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"]
        )
        load_group.add_argument(
            "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]
        )
        load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"])
        load_group.add_argument(
            "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
        )
660

661
662
663
664
665
        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
666
        )
667
        structured_outputs_group.add_argument(
668
            "--reasoning-parser",
669
            # This choice is a special case because it's not static
670
            choices=list(ReasoningParserManager.reasoning_parsers),
671
672
            **structured_outputs_kwargs["reasoning_parser"],
        )
673
674
675
676
677
678
679
680
681
682
683
        # Deprecated guided decoding arguments
        for arg, type in [
            ("--guided-decoding-backend", str),
            ("--guided-decoding-disable-fallback", bool),
            ("--guided-decoding-disable-any-whitespace", bool),
            ("--guided-decoding-disable-additional-properties", bool),
        ]:
            structured_outputs_group.add_argument(
                arg,
                type=type,
                help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
684
685
                deprecated=True,
            )
686

687
        # Parallel arguments
688
689
690
691
692
693
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
694
            "--distributed-executor-backend",
695
696
            **parallel_kwargs["distributed_executor_backend"],
        )
697
        parallel_group.add_argument(
698
699
700
701
            "--pipeline-parallel-size",
            "-pp",
            **parallel_kwargs["pipeline_parallel_size"],
        )
702
        parallel_group.add_argument(
703
704
            "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
        )
705
        parallel_group.add_argument(
706
707
708
709
710
711
712
713
714
715
            "--decode-context-parallel-size",
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )
        parallel_group.add_argument(
            "--data-parallel-rank",
            "-dpn",
716
            type=int,
717
718
719
            help="Data parallel rank of this instance. "
            "When set, enables external load balancer mode.",
        )
720
        parallel_group.add_argument(
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
            "--data-parallel-start-rank",
            "-dpr",
            type=int,
            help="Starting data parallel rank for secondary nodes.",
        )
        parallel_group.add_argument(
            "--data-parallel-size-local",
            "-dpl",
            type=int,
            help="Number of data parallel replicas to run on this node.",
        )
        parallel_group.add_argument(
            "--data-parallel-address",
            "-dpa",
            type=str,
            help="Address of data parallel cluster head-node.",
        )
        parallel_group.add_argument(
            "--data-parallel-rpc-port",
            "-dpp",
            type=int,
            help="Port for data parallel RPC communication.",
        )
        parallel_group.add_argument(
            "--data-parallel-backend",
            "-dpb",
            type=str,
            default="mp",
            help='Backend for data parallel, either "mp" or "ray".',
        )
751
        parallel_group.add_argument(
752
753
754
755
756
757
            "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"]
        )
        parallel_group.add_argument(
            "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
        )
        parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
758
759
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
760
761
            **parallel_kwargs["dbo_decode_token_threshold"],
        )
762
763
        parallel_group.add_argument(
            "--dbo-prefill-token-threshold",
764
765
766
767
            **parallel_kwargs["dbo_prefill_token_threshold"],
        )
        parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"])
768
769
        parallel_group.add_argument(
            "--expert-placement-strategy",
770
771
            **parallel_kwargs["expert_placement_strategy"],
        )
772
773
774
        parallel_group.add_argument(
            "--num-redundant-experts",
            type=int,
775
776
777
            help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
            deprecated=True,
        )
778
779
780
781
        parallel_group.add_argument(
            "--eplb-window-size",
            type=int,
            help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
782
783
            deprecated=True,
        )
784
785
786
        parallel_group.add_argument(
            "--eplb-step-interval",
            type=int,
787
788
789
            help="[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
            deprecated=True,
        )
790
791
792
        parallel_group.add_argument(
            "--eplb-log-balancedness",
            action=argparse.BooleanOptionalAction,
793
794
795
            help="[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
            deprecated=True,
        )
796

797
        parallel_group.add_argument(
798
            "--max-parallel-loading-workers",
799
800
            **parallel_kwargs["max_parallel_loading_workers"],
        )
801
        parallel_group.add_argument(
802
803
            "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"]
        )
804
        parallel_group.add_argument(
805
            "--disable-custom-all-reduce",
806
807
808
809
810
811
            **parallel_kwargs["disable_custom_all_reduce"],
        )
        parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"])
        parallel_group.add_argument(
            "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
        )
812
813
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
814
            action="store_true",
815
816
            deprecated=True,
        )
817

818
819
820
821
822
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
823
        )
824
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
825
826
827
828
829
830
        cache_group.add_argument(
            "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
        )
        cache_group.add_argument(
            "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
        )
831
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
        cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"])
        cache_group.add_argument(
            "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
        )
        cache_group.add_argument(
            "--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"]
        )
        cache_group.add_argument(
            "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
        )
        cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument(
            "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
        )
        cache_group.add_argument(
            "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
        )
        cache_group.add_argument(
            "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"]
        )
        cache_group.add_argument(
            "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
        )
855

856
        # Multimodal related configs
857
858
859
860
861
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
862
        multimodal_group.add_argument(
863
864
865
866
867
868
869
870
871
872
873
            "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"]
        )
        multimodal_group.add_argument(
            "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]
        )
        multimodal_group.add_argument(
            "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
        )
874
        multimodal_group.add_argument(
875
876
            "--disable-mm-preprocessor-cache", action="store_true", deprecated=True
        )
877
        multimodal_group.add_argument(
878
879
            "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
        )
880
881
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
882
883
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
        )
884
        multimodal_group.add_argument(
885
886
887
888
889
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
        )
        multimodal_group.add_argument(
            "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
        )
890
        multimodal_group.add_argument(
891
892
            "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"]
        )
893

894
        multimodal_group.add_argument(
895
896
            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]
        )
897

898
        # LoRA related configs
899
900
901
902
903
904
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
905
            "--enable-lora",
906
            action=argparse.BooleanOptionalAction,
907
908
909
            help="If True, enable handling of LoRA adapters.",
        )
        lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"])
910
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
911
912
913
914
        lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"])
        lora_group.add_argument(
            "--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"]
        )
915
        lora_group.add_argument(
916
            "--lora-dtype",
917
918
            **lora_kwargs["lora_dtype"],
        )
919
920
921
922
923
        lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument(
            "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
        )
        lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
924

925
926
927
928
929
930
931
932
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
933
934
            **observability_kwargs["show_hidden_metrics_for_version"],
        )
935
        observability_group.add_argument(
936
937
            "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"]
        )
938
939
940
941
942
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
943
            ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2)
944
945
946
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
947
948
            **observability_kwargs["collect_detailed_traces"],
        )
949

950
951
952
953
954
955
956
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
957
958
            "--max-num-batched-tokens", **scheduler_kwargs["max_num_batched_tokens"]
        )
959
        scheduler_group.add_argument(
960
961
962
963
964
            "--max-num-seqs", **scheduler_kwargs["max_num_seqs"]
        )
        scheduler_group.add_argument(
            "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"]
        )
965
966
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
967
968
969
970
971
            **scheduler_kwargs["max_long_partial_prefills"],
        )
        scheduler_group.add_argument(
            "--cuda-graph-sizes", **scheduler_kwargs["cuda_graph_sizes"]
        )
972
973
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
974
975
976
977
978
            **scheduler_kwargs["long_prefill_token_threshold"],
        )
        scheduler_group.add_argument(
            "--num-lookahead-slots", **scheduler_kwargs["num_lookahead_slots"]
        )
979
980
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
981
        scheduler_group.add_argument(
982
983
            "--scheduling-policy", **scheduler_kwargs["policy"]
        )
984
        scheduler_group.add_argument(
985
986
987
988
989
990
991
992
            "--enable-chunked-prefill", **scheduler_kwargs["enable_chunked_prefill"]
        )
        scheduler_group.add_argument(
            "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
        )
        scheduler_group.add_argument(
            "--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
        )
993
994
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
995
996
997
998
999
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"],
        )
        scheduler_group.add_argument(
            "--async-scheduling", **scheduler_kwargs["async_scheduling"]
        )
1000
1001

        # vLLM arguments
1002
        vllm_kwargs = get_kwargs(VllmConfig)
1003
1004
1005
1006
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
1007
1008
1009
1010
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
        vllm_group.add_argument(
            "--speculative-config", **vllm_kwargs["speculative_config"]
        )
        vllm_group.add_argument(
            "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
        )
        vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument(
            "--compilation-config", "-O", **vllm_kwargs["compilation_config"]
        )
        vllm_group.add_argument(
            "--additional-config", **vllm_kwargs["additional_config"]
        )
        vllm_group.add_argument(
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )
1027

1028
        # Other arguments
1029
1030
1031
1032
1033
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Disable logging statistics.",
        )
1034

1035
        return parser
1036
1037

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from a parsed argparse namespace.

        Every dataclass field of ``cls`` that also exists as an attribute on
        ``args`` is passed to the constructor. Fields absent from ``args`` are
        not passed at all, so the constructor's own defaults apply to them.
        Extra attributes on ``args`` that are not fields are ignored.
        """
        field_names = (field.name for field in dataclasses.fields(cls))
        init_kwargs = {}
        for name in field_names:
            if hasattr(args, name):
                init_kwargs[name] = getattr(args, name)
        return cls(**init_kwargs)
1046

1047
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` for this engine from the stored arguments.

        Before the config is constructed, a few arguments are normalized in
        place on ``self``:

        * If ``self.model`` is a GGUF file, both ``quantization`` and
          ``load_format`` are forced to ``"gguf"``.
        * When ``VLLM_CI_USE_S3`` is set (and this is not an
          ``AsyncEngineArgs`` with a non-``"auto"`` load format), models listed
          in ``MODELS_ON_S3`` are redirected to ``MODEL_WEIGHTS_S3_BUCKET``.
        * Deprecated multimodal options (``--disable-mm-preprocessor-cache``,
          ``VLLM_MM_INPUT_CACHE_GIB`` and
          ``--enable-multimodal-encoder-data-parallel``) are translated into
          their replacements, with a deprecation warning for each.

        Returns:
            A ``ModelConfig`` populated from the (possibly updated) fields.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (
            not isinstance(self, AsyncEngineArgs)
            and envs.VLLM_CI_USE_S3
            and self.model in MODELS_ON_S3
            and self.load_format == "auto"
        ):
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"

        # The explicit (deprecated) CLI flag takes precedence over the
        # deprecated env var; both map onto --mm-processor-cache-gb.
        if self.disable_mm_preprocessor_cache:
            logger.warning(
                "`--disable-mm-preprocessor-cache` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb 0` instead.",
            )

            self.mm_processor_cache_gb = 0
        elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
            # Presumably 4 is the env var's default, so any other value means
            # the user set it explicitly -- TODO confirm against vllm.envs.
            logger.warning(
                "`VLLM_MM_INPUT_CACHE_GIB` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb %d` instead.",
                envs.VLLM_MM_INPUT_CACHE_GIB,
            )

            self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB

        if self.enable_multimodal_encoder_data_parallel:
            logger.warning(
                "`--enable-multimodal-encoder-data-parallel` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-encoder-tp-mode data` instead."
            )

            self.mm_encoder_tp_mode = "data"

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            task=self.task,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            allowed_media_domains=self.allowed_media_domains,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            pooler_config=self.pooler_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            video_pruning_rate=self.video_pruning_rate,
            io_processor_plugin=self.io_processor_plugin,
        )
1140

1141
    def validate_tensorizer_args(self):
        """Mirror top-level tensorizer options into the nested config mapping.

        Every key of ``self.model_loader_extra_config`` that is also listed in
        ``TensorizerConfig._fields`` is copied into
        ``self.model_loader_extra_config["tensorizer_config"]``. The nested
        entry is only looked up when such a key is found.
        """
        from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

        extra_config = self.model_loader_extra_config
        for option_name in extra_config:
            if option_name not in TensorizerConfig._fields:
                continue
            nested = extra_config["tensorizer_config"]
            nested[option_name] = extra_config[option_name]
1149

1150
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` describing how model weights are loaded.

        Side effects on ``self``:

        * bitsandbytes quantization forces ``load_format = "bitsandbytes"``.
        * For the ``"tensorizer"`` load format, ``model_loader_extra_config``
          is first converted with ``to_serializable()`` when it offers that
          method. It then receives a fresh ``"tensorizer_config"`` entry
          seeded with ``tensorizer_dir = self.model``, and
          ``validate_tensorizer_args()`` is called.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            extra_config = self.model_loader_extra_config
            if hasattr(extra_config, "to_serializable"):
                extra_config = extra_config.to_serializable()
                self.model_loader_extra_config = extra_config
            # Reset the nested section so only this call's values end up in it.
            extra_config["tensorizer_config"] = {"tensorizer_dir": self.model}
            self.validate_tensorizer_args()

        # Online quantization loads weights onto the CPU first.
        load_device = "cpu" if is_online_quantization(self.quantization) else None
        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            device=load_device,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )
1175

1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Build a ``SpeculativeConfig`` from ``self.speculative_config``.

        ``self.speculative_config`` is a mapping. It is either parsed from the
        ``--speculative-config`` JSON CLI argument or handed in directly by the
        engine. When it is ``None``, speculative decoding is disabled and
        ``None`` is returned.

        The target-model settings passed to this method cannot come from the
        CLI. They are merged into ``self.speculative_config`` in place, so the
        stored mapping is modified, before the config object is built.
        """
        spec_options = self.speculative_config
        if spec_options is not None:
            # Note(Shangming): These parameters are not obtained from the cli
            # arg '--speculative-config' and must be passed in when creating
            # the engine config.
            spec_options.update(
                target_model_config=target_model_config,
                target_parallel_config=target_parallel_config,
                enable_chunked_prefill=enable_chunked_prefill,
                disable_log_stats=disable_log_stats,
            )
            return SpeculativeConfig(**spec_options)
        return None
1206

1207
1208
1209
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
1210
        headless: bool = False,
1211
1212
1213
1214
1215
1216
1217
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1218

1219
1220
1221
1222
1223
1224
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1225
        current_platform.pre_register_and_update()
1226

1227
        device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
1228

1229
1230
1231
1232
        model_config = self.create_model_config()
        self.model = model_config.model
        self.tokenizer = model_config.tokenizer

1233
1234
1235
1236
1237
1238
1239
1240
1241
        (self.model, self.tokenizer, self.speculative_config) = (
            maybe_override_with_speculators(
                model=self.model,
                tokenizer=self.tokenizer,
                revision=self.revision,
                trust_remote_code=self.trust_remote_code,
                vllm_speculative_config=self.speculative_config,
            )
        )
1242

1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

1260
1261
        # Set default arguments for V1 Engine.
        self._set_default_args(usage_context, model_config)
1262
        # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.S390X,
            CpuArchEnum.ARM,
            CpuArchEnum.RISCV,
        ):
            logger.info(
                "Chunked prefill is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
1274
            self.enable_chunked_prefill = False
1275
1276
        assert self.enable_chunked_prefill is not None

1277
1278
1279
1280
1281
1282
1283
        sliding_window: Optional[int] = None
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1284
1285
1286
        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
1287
        # reuses the GPUs of TP group, and split one TP group into
1288
        # tp_size//dcp_size DCP groups.
1289
        assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, (
1290
1291
1292
1293
            f"tp_size={self.tensor_parallel_size} must be divisible by"
            f"dcp_size={self.decode_context_parallel_size}."
        )

1294
        cache_config = CacheConfig(
1295
            block_size=self.block_size,
1296
            gpu_memory_utilization=self.gpu_memory_utilization,
1297
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1298
1299
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1300
            is_attention_free=model_config.is_attention_free,
1301
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1302
            sliding_window=sliding_window,
1303
            enable_prefix_caching=self.enable_prefix_caching,
1304
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1305
            cpu_offload_gb=self.cpu_offload_gb,
1306
            calculate_kv_scales=self.calculate_kv_scales,
1307
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1308
1309
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1310
        )
1311

1312
1313
1314
1315
1316
1317
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
1318

1319
1320
1321
            ray_runtime_env = ray.get_runtime_context().runtime_env
            logger.info("Using ray runtime env: %s", ray_runtime_env)

1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1333
        assert not headless or not self.data_parallel_hybrid_lb, (
1334
1335
            "data_parallel_hybrid_lb is not applicable in headless mode"
        )
1336

1337
        data_parallel_external_lb = self.data_parallel_rank is not None
1338
        # Local DP rank = 1, use pure-external LB.
1339
1340
        if data_parallel_external_lb:
            assert self.data_parallel_size_local in (1, None), (
1341
1342
                "data_parallel_size_local must be 1 when data_parallel_rank is set"
            )
1343
            data_parallel_size_local = 1
1344
1345
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1346
1347
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

            self.data_parallel_rank = self.data_parallel_start_rank or 0
1363
        else:
1364
            assert not self.data_parallel_hybrid_lb, (
1365
1366
                "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
            )
1367

1368
1369
            # Local DP size defaults to global DP size if not set.
            data_parallel_size_local = self.data_parallel_size
1370
1371
1372

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1373
1374
1375
1376
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
1377
1378
                    "Using host IP %s as ray-based data parallel address", host_ip
                )
Rui Qiao's avatar
Rui Qiao committed
1379
1380
1381
1382
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
1383
1384
                    self.data_parallel_backend,
                )
Rui Qiao's avatar
Rui Qiao committed
1385
1386
1387
                data_parallel_address = ParallelConfig.data_parallel_master_ip
        else:
            data_parallel_address = self.data_parallel_address
1388
1389
1390

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
1391
        data_parallel_rpc_port = (
1392
            self.data_parallel_rpc_port
1393
1394
1395
            if (self.data_parallel_rpc_port is not None)
            else ParallelConfig.data_parallel_rpc_port
        )
1396

1397
1398
1399
1400
        if self.async_scheduling:
            # Async scheduling does not work with the uniprocess backend.
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "mp"
1401
1402
1403
1404
                logger.info(
                    "Defaulting to mp-based distributed executor "
                    "backend for async scheduling."
                )
1405
            if self.pipeline_parallel_size > 1:
1406
1407
1408
                raise ValueError(
                    "Async scheduling is not supported with pipeline-parallel-size > 1."
                )
1409
1410
1411
1412
1413
1414

            # Currently, async scheduling does not support speculative decoding.
            # TODO(woosuk): Support it.
            if self.speculative_config is not None:
                raise ValueError(
                    "Currently, speculative decoding is not supported with "
1415
1416
                    "async scheduling."
                )
1417

1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
        # Forward the deprecated CLI args to the EPLB config.
        if self.num_redundant_experts is not None:
            self.eplb_config.num_redundant_experts = self.num_redundant_experts
        if self.eplb_window_size is not None:
            self.eplb_config.window_size = self.eplb_window_size
        if self.eplb_step_interval is not None:
            self.eplb_config.step_interval = self.eplb_step_interval
        if self.eplb_log_balancedness is not None:
            self.eplb_config.log_balancedness = self.eplb_log_balancedness

1428
        parallel_config = ParallelConfig(
1429
1430
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1431
            data_parallel_size=self.data_parallel_size,
1432
1433
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1434
1435
1436
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1437
            data_parallel_backend=self.data_parallel_backend,
1438
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1439
            enable_expert_parallel=self.enable_expert_parallel,
1440
1441
            enable_dbo=self.enable_dbo,
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
1442
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
1443
            enable_eplb=self.enable_eplb,
1444
            eplb_config=self.eplb_config,
1445
            expert_placement_strategy=self.expert_placement_strategy,
1446
1447
1448
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1449
            ray_runtime_env=ray_runtime_env,
1450
            placement_group=placement_group,
1451
1452
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1453
            worker_extension_cls=self.worker_extension_cls,
1454
            decode_context_parallel_size=self.decode_context_parallel_size,
1455
1456
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
1457
        )
1458

1459
        speculative_config = self.create_speculative_config(
1460
1461
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1462
            enable_chunked_prefill=self.enable_chunked_prefill,
1463
            disable_log_stats=self.disable_log_stats,
1464
1465
        )

1466
1467
1468
1469
1470
        # make sure num_lookahead_slots is set appropriately depending on
        # whether speculative decoding is enabled
        num_lookahead_slots = self.num_lookahead_slots
        if speculative_config is not None:
            num_lookahead_slots = speculative_config.num_lookahead_slots
1471

1472
        scheduler_config = SchedulerConfig(
1473
            runner_type=model_config.runner_type,
1474
1475
1476
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1477
            cuda_graph_sizes=self.cuda_graph_sizes,
1478
            num_lookahead_slots=num_lookahead_slots,
1479
            enable_chunked_prefill=self.enable_chunked_prefill,
1480
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1481
            is_multimodal_model=model_config.is_multimodal_model,
1482
            is_encoder_decoder=model_config.is_encoder_decoder,
1483
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray),
1484
            policy=self.scheduling_policy,
1485
            scheduler_cls=self.scheduler_cls,
1486
1487
1488
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1489
            disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
1490
            async_scheduling=self.async_scheduling,
1491
        )
1492

1493
1494
1495
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
                "non multimodal model"
            )

        lora_config = (
            LoRAConfig(
                bias_enabled=self.enable_lora_bias,
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                default_mm_loras=self.default_mm_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=self.max_cpu_loras
                if self.max_cpu_loras and self.max_cpu_loras > 0
                else None,
            )
            if self.enable_lora
            else None
        )
1515

1516
1517
1518
1519
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1520
        load_config = self.create_load_config()
1521

1522
1523
        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
1524
            self.structured_outputs_config.reasoning_parser = self.reasoning_parser
1525
1526
1527
1528

        # Forward the deprecated CLI args to the StructuredOutputsConfig
        so_config = self.structured_outputs_config
        if self.guided_decoding_backend is not None:
1529
            so_config.guided_decoding_backend = self.guided_decoding_backend
1530
        if self.guided_decoding_disable_fallback is not None:
1531
1532
1533
            so_config.guided_decoding_disable_fallback = (
                self.guided_decoding_disable_fallback
            )
1534
        if self.guided_decoding_disable_any_whitespace is not None:
1535
1536
1537
            so_config.guided_decoding_disable_any_whitespace = (
                self.guided_decoding_disable_any_whitespace
            )
1538
        if self.guided_decoding_disable_additional_properties is not None:
1539
1540
1541
            so_config.guided_decoding_disable_additional_properties = (
                self.guided_decoding_disable_additional_properties
            )
1542

1543
        observability_config = ObservabilityConfig(
1544
            show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version),
1545
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1546
            collect_detailed_traces=self.collect_detailed_traces,
1547
        )
1548

1549
        config = VllmConfig(
1550
1551
1552
1553
1554
1555
1556
1557
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
1558
            structured_outputs_config=self.structured_outputs_config,
1559
            observability_config=observability_config,
1560
            compilation_config=self.compilation_config,
1561
            kv_transfer_config=self.kv_transfer_config,
1562
            kv_events_config=self.kv_events_config,
1563
            additional_config=self.additional_config,
1564
        )
1565

1566
1567
        return config

1568
1569
1570
1571
1572
1573
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        The checks run in a fixed order. When one finds an unsupported
        feature it calls ``_raise_or_fallback`` (which raises if
        ``VLLM_USE_V1`` is explicitly enabled, otherwise logs a warning) and
        returns ``False`` right away, so only the first offending feature is
        reported.

        Args:
            model_config: The resolved model configuration.

        Returns:
            ``True`` when no unsupported feature was found, ``False``
            otherwise.

        Raises:
            NotImplementedError: If draft-model speculative decoding is
                requested, or if ``_raise_or_fallback`` raises because V1 was
                forced while an unsupported feature is in use.
        """
        # ---- Feature flags V1 does not support -----------------------------

        # A logits processor pattern that differs from the class default.
        if self.logits_processor_pattern != EngineArgs.logits_processor_pattern:
            _raise_or_fallback(
                feature_name="--logits-processor-pattern", recommend_to_remove=False
            )
            return False

        # Model architectures flagged as not V1-compatible
        # (no Mamba or Encoder-Decoder so far).
        if not model_config.is_v1_compatible:
            _raise_or_fallback(
                feature_name=model_config.architectures, recommend_to_remove=False
            )
            return False

        # Concurrent partial prefills: any override of the scheduler defaults.
        partial_prefill_limits_changed = (
            self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills
        )
        if partial_prefill_limits_changed:
            _raise_or_fallback(
                feature_name="Concurrent Partial Prefill", recommend_to_remove=False
            )
            return False

        # Speculative decoding: only the "draft_model" method is rejected, and
        # it raises directly instead of falling back.
        spec_config = self.speculative_config
        if spec_config is not None:
            # The config may still be an unparsed dict at this point.
            spec_method = (
                spec_config.get("method", None)
                if isinstance(spec_config, dict)
                else spec_config.method
            )
            if spec_method == "draft_model":
                raise NotImplementedError(
                    "Draft model speculative decoding is not supported yet. "
                    "Please consider using other speculative decoding methods "
                    "such as ngram, medusa, eagle, or mtp."
                )

        # Attention backends accepted when VLLM_ATTENTION_BACKEND is set.
        v1_attention_backends = (
            "FLASH_ATTN",
            "PALLAS",
            "TRITON_ATTN",
            "TRITON_MLA",
            "CUTLASS_MLA",
            "FLASHMLA",
            "FLASH_ATTN_MLA",
            "FLASHINFER",
            "FLASHINFER_MLA",
            "ROCM_AITER_MLA",
            "TORCH_SDPA",
            "FLEX_ATTENTION",
            "TREE_ATTN",
            "XFORMERS",
            "ROCM_ATTN",
        )
        if envs.is_set("VLLM_ATTENTION_BACKEND"):
            chosen_backend = envs.VLLM_ATTENTION_BACKEND
            if chosen_backend not in v1_attention_backends:
                _raise_or_fallback(
                    feature_name=f"VLLM_ATTENTION_BACKEND={chosen_backend}",
                    recommend_to_remove=True,
                )
                return False

        # ---- Experimental features (users may opt in) ----------------------

        # Pipeline parallelism needs an executor that either advertises
        # `supports_pp` or is one of the built-in choices listed here.
        if self.pipeline_parallel_size > 1:
            executor_backend = self.distributed_executor_backend
            executor_supports_pp = getattr(executor_backend, "supports_pp", False)
            if not executor_supports_pp and executor_backend not in (
                ParallelConfig.distributed_executor_backend,
                "ray",
                "mp",
                "external_launcher",
            ):
                _raise_or_fallback(
                    feature_name=(
                        "Pipeline Parallelism without Ray distributed "
                        "executor or multiprocessing executor or external "
                        "launcher"
                    ),
                    recommend_to_remove=False,
                )
                return False

        # A sliding-window model on the CPU backend is reported as unsupported.
        if current_platform.is_cpu() and model_config.get_sliding_window() is not None:
            _raise_or_fallback(
                feature_name="sliding window (CPU backend)", recommend_to_remove=False
            )
            return False

        return True

    def _set_default_args(
        self, usage_context: UsageContext, model_config: ModelConfig
    ) -> None:
        """Set Default Arguments for V1 Engine.

        Mutates ``self`` in place:

        * ``enable_chunked_prefill`` / ``enable_prefix_caching`` are defaulted
          from the runner type (see the comment on the first branch below).
        * ``scheduler_cls`` is switched to the V1 scheduler if it still holds
          the class-level default.
        * ``max_num_batched_tokens`` and ``max_num_seqs`` are filled in (only
          when unset) from per-usage-context tables chosen by platform and
          device.

        Args:
            usage_context: Where the engine is launched from; selects the row
                of the default tables. It may be falsy, in which case it is
                logged as ``None``.
            model_config: The resolved model configuration.
        """
        # Non-pooling runners: V1 always uses chunked prefill, and prefix
        # caching defaults to on unless disabled below.
        # Pooling runners: both default to on only when incremental prefill
        # is supported ("last" pooling on a causal model), otherwise off.
        if model_config.runner_type != "pooling":
            self.enable_chunked_prefill = True

            # TODO: When prefix caching supports prompt embeds inputs, this
            # check can be removed.
            if self.enable_prompt_embeds and self.enable_prefix_caching is not False:
                logger.warning(
                    "--enable-prompt-embeds and --enable-prefix-caching "
                    "are not supported together in V1. Prefix caching has "
                    "been disabled."
                )
                self.enable_prefix_caching = False

            if self.enable_prefix_caching is None:
                # Disable prefix caching default for hybrid models
                # since the feature is still experimental.
                if model_config.is_hybrid:
                    self.enable_prefix_caching = False
                else:
                    self.enable_prefix_caching = True
        else:
            pooling_type = model_config.pooler_config.pooling_type
            is_causal = getattr(model_config.hf_config, "is_causal", True)
            incremental_prefill_supported = (
                pooling_type is not None
                and pooling_type.lower() == "last"
                and is_causal
            )

            action = "Enabling" if incremental_prefill_supported else "Disabling"

            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = incremental_prefill_supported
                logger.info("(%s) chunked prefill by default", action)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = incremental_prefill_supported
                logger.info("(%s) prefix caching by default", action)

        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # When no user override, set the default values based on the usage
        # context.
        # Use different default values for different hardware.

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # Best effort: both values are only used to pick the default
            # tables below. Bind `device_name` as well so the check below can
            # never hit an unbound local, whatever the condition order.
            device_memory = 0
            device_name = ""

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        from vllm.usage.usage_lib import UsageContext

        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
            # For GPUs like H100 and MI300x, use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 1024,
                UsageContext.OPENAI_API_SERVER: 1024,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256,
                UsageContext.OPENAI_API_SERVER: 256,
            }

        # tpu specific default values, keyed by the chip name returned by
        # current_platform.get_device_name().
        if current_platform.is_tpu():
            default_max_num_batched_tokens_tpu = {
                UsageContext.LLM_CLASS: {
                    "V6E": 2048,
                    "V5E": 1024,
                    "V5P": 512,
                },
                UsageContext.OPENAI_API_SERVER: {
                    "V6E": 1024,
                    "V5E": 512,
                    "V5P": 256,
                },
            }

        # cpu specific default values, scaled by the number of workers
        # (pipeline * tensor parallel size).
        if current_platform.is_cpu():
            world_size = self.pipeline_parallel_size * self.tensor_parallel_size
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 4096 * world_size,
                UsageContext.OPENAI_API_SERVER: 2048 * world_size,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256 * world_size,
                UsageContext.OPENAI_API_SERVER: 128 * world_size,
            }

        use_context_value = usage_context.value if usage_context else None
        if (
            self.max_num_batched_tokens is None
            and usage_context in default_max_num_batched_tokens
        ):
            if current_platform.is_tpu():
                chip_name = current_platform.get_device_name()
                if chip_name in default_max_num_batched_tokens_tpu[usage_context]:
                    self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[
                        usage_context
                    ][chip_name]
                else:
                    self.max_num_batched_tokens = default_max_num_batched_tokens[
                        usage_context
                    ]
            else:
                # Without chunked prefill a whole prompt must fit in one
                # batch, so the budget is the model's max length.
                if not self.enable_chunked_prefill:
                    self.max_num_batched_tokens = model_config.max_model_len
                else:
                    self.max_num_batched_tokens = default_max_num_batched_tokens[
                        usage_context
                    ]
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens,
                use_context_value,
            )

        if self.max_num_seqs is None and usage_context in default_max_num_seqs:
            # Never allow more sequences than batched tokens (when known).
            self.max_num_seqs = min(
                default_max_num_seqs[usage_context],
                self.max_num_batched_tokens or sys.maxsize,
            )

            logger.debug(
                "Setting max_num_seqs to %d for %s usage context.",
                self.max_num_seqs,
                use_context_value,
            )


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends ``EngineArgs`` with request-logging control. The old
    ``disable_log_requests`` name survives as a deprecated property that is
    the logical negation of ``enable_log_requests``.
    """

    # Whether incoming requests are logged (``--enable-log-requests``).
    enable_log_requests: bool = False

    @property
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
        "`enable_log_requests` instead."
    )
    def disable_log_requests(self) -> bool:
        # Deprecated alias: reading it yields the inverse of the real field.
        return not self.enable_log_requests

    @disable_log_requests.setter
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
        "`enable_log_requests` instead."
    )
    def disable_log_requests(self, value: bool):
        # Deprecated alias: writing it stores the inverse into the real field.
        self.enable_log_requests = not value

    @staticmethod
    def add_cli_args(
        parser: FlexibleArgumentParser, async_args_only: bool = False
    ) -> FlexibleArgumentParser:
        """Register the async-engine CLI flags on ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: When True, skip ``EngineArgs.add_cli_args`` and
                add only the async-specific flags.

        Returns:
            The parser returned by ``EngineArgs.add_cli_args`` (or the input
            parser when ``async_args_only`` is True), with the new flags.
        """
        # Plugins are loaded first because they may extend the parser, e.g.
        # add a new quantization method to --quantization or a new device to
        # --device.
        load_general_plugins()
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        # Both flags are --flag / --no-flag switches; the disable variant is
        # the deprecated spelling.
        log_request_flags = (
            (
                "--enable-log-requests",
                {
                    "default": AsyncEngineArgs.enable_log_requests,
                    "help": "Enable logging requests.",
                },
            ),
            (
                "--disable-log-requests",
                {
                    "default": not AsyncEngineArgs.enable_log_requests,
                    "help": "[DEPRECATED] Disable logging requests.",
                    "deprecated": True,
                },
            ),
        )
        for flag, options in log_request_flags:
            parser.add_argument(
                flag, action=argparse.BooleanOptionalAction, **options
            )

        current_platform.pre_register_and_update(parser)
        return parser


def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Reject an unsupported V1 feature, or warn that vLLM falls back to V0.

    If ``VLLM_USE_V1`` is explicitly set to a truthy value, the user has
    demanded V1, so this raises. Otherwise it logs a warning.

    Args:
        feature_name: Name of the unsupported feature; it is interpolated
            into the messages via f-strings.
        recommend_to_remove: When True, the warning also advises removing
            the feature from the config.

    Raises:
        NotImplementedError: If V1 was explicitly requested.
    """
    v1_forced = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_forced:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}."
        )

    message_parts = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        message_parts.append(
            f"We recommend to remove {feature_name} from your config "
        )
        message_parts.append("in favor of the V1 Engine.")
    logger.warning("".join(message_parts))


def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.

    Lower-case suffixes are decimal multipliers (k=10**3, m=10**6, g=10**9,
    t=10**12) and accept a fractional mantissa. Upper-case suffixes are
    binary multipliers (K=2**10, M=2**20, G=2**30, T=2**40) and require an
    integer mantissa. Anything else is parsed as a plain integer.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600

    Args:
        value: The string to parse; surrounding whitespace is ignored.

    Returns:
        The parsed integer. A fractional result from a decimal suffix is
        truncated (e.g. '1.0005k' -> 1000).

    Raises:
        argparse.ArgumentTypeError: If a fractional mantissa is combined with
            a binary suffix (e.g. '1.5K').
        ValueError: If ``value`` is neither a suffixed number nor a plain
            integer.
    """
    value = value.strip()
    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
    if match:
        # Every suffix the regex accepts must appear in exactly one table;
        # otherwise it would silently fall through to int() and fail.
        decimal_multiplier = {
            "k": 10**3,
            "m": 10**6,
            "g": 10**9,
            "t": 10**12,
        }
        binary_multiplier = {
            "K": 2**10,
            "M": 2**20,
            "G": 2**30,
            "T": 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            # Exact integer arithmetic: float(number) * mult can land just
            # below the true value (1.005 * 1000 == 1004.999...) and then
            # truncate to the wrong integer.
            whole, _, fraction = number.partition(".")
            return int(whole + fraction) * mult // 10 ** len(fraction)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?"
                ) from e

    # Regular plain number.
    return int(value)