arg_utils.py 82.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# yapf: disable
5
import os
6
import argparse
7
import copy
8
import dataclasses
9
import functools
10
import json
11
import sys
12
import threading
13
import warnings
14
from dataclasses import MISSING, dataclass, fields, is_dataclass
15
from itertools import permutations
16
17
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
                    Type, TypeVar, Union, cast, get_args, get_origin)
18

19
import regex as re
20
import torch
21
from pydantic import TypeAdapter, ValidationError
22
from typing_extensions import TypeIs, deprecated
23

24
import vllm.envs as envs
25
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
26
27
28
29
                         ConfigFormat, ConfigType, DecodingConfig,
                         DetailedTraceModules, Device, DeviceConfig,
                         DistributedExecutorBackend, GuidedDecodingBackend,
                         GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
30
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
31
                         ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
32
                         ObservabilityConfig, ParallelConfig, PoolerConfig,
33
                         PrefixCachingHashAlgo, PromptAdapterConfig,
34
                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
35
36
                         TaskOption, TokenizerMode, TokenizerPoolConfig,
                         VllmConfig, get_attr_docs, get_field)
37
from vllm.executor.executor_base import ExecutorBase
38
from vllm.logger import init_logger
39
from vllm.model_executor.layers.quantization import QuantizationMethods
40
from vllm.plugins import load_general_plugins
41
from vllm.reasoning import ReasoningParserManager
42
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
43
from vllm.transformers_utils.utils import check_gguf_file
44
from vllm.usage.usage_lib import UsageContext
45
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
Rui Qiao's avatar
Rui Qiao committed
46
                        GiB_bytes, get_ip, is_in_ray_actor)
47
48

# yapf: enable
49

50
51
# Module-level logger for this file, created through vLLM's ``init_logger``.
logger = init_logger(__name__)

52
53
54
55
# Type aliases for the type-hint helpers in this module.
# ``object`` is part of the unions so that special typing forms (``Literal``,
# ``Annotated``, ``Union``, ...), which are not ``type`` instances, are
# accepted by static type checkers.
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]
56

57

58
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap ``return_type`` so conversion failures become argparse errors.

    When ``return_type`` is :func:`json.loads` and the value does not look
    like a JSON object, the deprecated ``KEY=VALUE,...`` syntax is parsed
    with ``nullable_kvs`` instead.

    Args:
        return_type: Callable that converts the raw CLI string.

    Returns:
        A converter suitable for argparse's ``type=`` keyword. It raises
        ``argparse.ArgumentTypeError`` when ``return_type`` raises
        ``ValueError``.
    """

    def _parse_type(val: str) -> T:
        try:
            # ``(?s)`` lets ``.`` span newlines and ``\s*`` tolerates padding,
            # so pretty-printed multi-line JSON objects (e.g. passed as
            # "$(cat cfg.json)") are recognised as JSON rather than being
            # routed to the legacy KEY=VALUE parser.
            if (return_type is json.loads
                    and not re.match(r"(?s)^\s*{.*}\s*$", val)):
                return cast(T, nullable_kvs(val))
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}.") from e

    return _parse_type


def optional_type(
        return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Build an argparse ``type=`` converter that also accepts "no value".

    The returned callable maps the strings ``""`` and ``"None"`` to ``None``.
    Any other string is converted with ``parse_type(return_type)``.
    """

    def _optional_type(val: str) -> Optional[T]:
        # The "unset" sentinels are checked before any conversion is tried.
        if val in ("", "None"):
            return None
        converter = parse_type(return_type)
        return converter(val)

    return _optional_type
81
82


83
84
85
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
    """Parse a CLI value that may be either a JSON object or a plain string.

    Values that look like a JSON object (optionally spanning several lines
    and/or surrounded by whitespace) are decoded with ``json.loads`` via
    ``optional_type``. Anything else is returned unchanged as a string.
    """
    # ``(?s)`` lets ``.`` match newlines, and ``\s*`` accepts padding, so
    # pretty-printed JSON is not mistaken for a plain string.
    if not re.match(r"(?s)^\s*{.*}\s*$", val):
        return str(val)
    return optional_type(json.loads)(val)
87
88


89
90
91
92
93
@deprecated(
    "Passing a JSON argument as a string containing comma separated key=value "
    "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
    "string instead.")
def nullable_kvs(val: str) -> dict[str, int]:
    """Parse comma-separated ``KEY=VALUE`` pairs into a ``dict[str, int]``.

    Each key and value is lower-cased and stripped of surrounding
    whitespace. A key may repeat only if every occurrence has the same value.

    Args:
        val: String such as ``"image=2,video=1"``.

    Returns:
        Mapping from each key to its integer value.

    Raises:
        argparse.ArgumentTypeError: If an item is not exactly ``KEY=VALUE``,
            a value is not an integer, or a key is given conflicting values.
    """
    result: dict[str, int] = {}
    for chunk in val.split(","):
        pieces = [piece.lower().strip() for piece in chunk.split("=")]
        if len(pieces) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, raw = pieces

        try:
            number = int(raw)
        except ValueError as exc:
            raise argparse.ArgumentTypeError(
                f"Failed to parse value of item {key}={raw}") from exc

        # An unseen key defaults to ``number`` and therefore never conflicts.
        if result.get(key, number) != number:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        result[key] = number

    return result


125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Tell whether ``type_hint`` is ``type`` itself or a generic of it.

    For example, both ``list`` and ``list[int]`` match ``list``.
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Report whether any member of ``type_hints`` matches ``type``.

    Matching is delegated to ``is_type``.
    """
    for hint in type_hints:
        if is_type(hint, type):
            return True
    return False


def get_type(type_hints: set[TypeHint],
             type: TypeHintT) -> Optional[TypeHintT]:
    """Return a hint from ``type_hints`` that matches ``type``.

    Matching is delegated to ``is_type``. Because ``type_hints`` is a set,
    which match is returned is unspecified when several hints match.

    Returns:
        The matching hint, or ``None`` if no hint matches. The annotation is
        ``Optional`` to reflect that ``None`` case.
    """
    return next((th for th in type_hints if is_type(th, type)), None)


140
141
142
143
144
145
146
147
148
149
150
151
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Translate a ``Literal[...]`` hint into argparse ``type``/``choices``.

    The argparse ``type`` is the Python type of the first literal value, and
    ``choices`` is the sorted list of all literal values.

    Raises:
        ValueError: If the literal values are not all instances of that type.
    """
    literal_hint = get_type(type_hints, Literal)
    options = get_args(literal_hint)
    first_type = type(options[0])
    if any(not isinstance(opt, first_type) for opt in options):
        raise ValueError(
            "All choices must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}")
    sorted_options = sorted(options)
    return {"type": first_type, "choices": sorted_options}


152
153
154
155
156
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True if ``type_hint`` is defined outside ``builtins``."""
    defining_module = type_hint.__module__
    return defining_module != "builtins"


157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints."""
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
    elif origin is Union:
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


174
175
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Derive argparse ``add_argument`` kwargs for every field of ``cls``.

    For each dataclass field this sets ``default`` and ``help``. The help
    text is the field's attribute docstring from ``get_attr_docs``, with
    ``%`` escaped for argparse. It then sets ``type``/``action``/``nargs``/
    ``choices`` according to the field's type hint.

    Results are memoised per ``cls`` by ``lru_cache``. Callers that mutate
    the result should use ``get_kwargs``, which returns a deep copy.

    Args:
        cls: The config dataclass whose fields are converted.

    Returns:
        Mapping from field name to the kwargs for that field's argument.

    Raises:
        KeyError: If a field has no entry in ``get_attr_docs(cls)``.
        ValueError: If a field's type hint has no argparse mapping, or a
            ``Literal`` mixes value types.
    """
    cls_docs = get_attr_docs(cls)
    kwargs = {}
    for field in fields(cls):
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, its CLI value is parsed into an
        # instance of that dataclass (see ``parse_dataclass`` below).
        generator = (th for th in type_hints if is_dataclass(th))
        dataclass_cls = next(generator, None)

        # Get the default value of the field. It is assigned on every
        # iteration, so a field without a default never inherits the
        # previous field's value (or hits UnboundLocalError when it is the
        # first field).
        if field.default is not MISSING:
            default = field.default
        elif field.default_factory is not MISSING:
            default = field.default_factory()
        else:
            default = None

        # Get the help text for the field
        name = field.name
        help_text = cls_docs[name].strip()
        # Escape % for argparse
        help_text = help_text.replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        kwargs[name] = {"default": default, "help": help_text}

        # Set other kwargs based on the type hints
        json_tip = """\n\nShould either be a valid JSON string or JSON keys
        passed individually. For example, the following sets of arguments are
        equivalent:\n\n
        - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
        - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n
        Additionally, list elements can be passed individually using '+':
        - `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
        - `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n"""
        if dataclass_cls is not None:

            # ``cls`` is bound as a default to avoid late binding in the loop.
            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    if hasattr(cls, "from_cli"):
                        return cls.from_cli(val)
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            kwargs[name]["type"] = parse_dataclass
            kwargs[name]["help"] += json_tip
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            kwargs[name]["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            kwargs[name].update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            types = get_args(type_hint)
            tuple_type = types[0]
            assert all(t is tuple_type for t in types if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {types}.")
            kwargs[name]["type"] = tuple_type
            kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            types = get_args(type_hint)
            assert len(types) == 1, (
                "List type must have exactly one type. Got "
                f"{type_hint} with types {types}")
            kwargs[name]["type"] = types[0]
            kwargs[name]["nargs"] = "+"
        elif contains_type(type_hints, int):
            kwargs[name]["type"] = int
            # Special case for large integers
            if name in {"max_model_len", "max_num_batched_tokens"}:
                kwargs[name]["type"] = human_readable_int
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
        elif (contains_type(type_hints, dict)
              and (contains_type(type_hints, str)
                   or any(is_not_builtin(th) for th in type_hints))):
            kwargs[name]["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            kwargs[name]["type"] = parse_type(json.loads)
            kwargs[name]["help"] += json_tip
        elif (contains_type(type_hints, str)
              or any(is_not_builtin(th) for th in type_hints)):
            kwargs[name]["type"] = str
        else:
            raise ValueError(
                f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(kwargs[name].get("type")) is Literal:
            kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            kwargs[name]["type"] = optional_type(kwargs[name]["type"])
            if kwargs[name].get("choices"):
                kwargs[name]["choices"].append("None")
    return kwargs
277
278


279
280
281
282
283
284
285
286
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Return argparse kwargs for the fields of the Config dataclass ``cls``.

    The expensive derivation is memoised by ``_compute_kwargs``. This wrapper
    returns an independent deep copy, so callers may mutate the result
    without corrupting the cached value.
    """
    cached = _compute_kwargs(cls)
    independent_copy = copy.deepcopy(cached)
    return independent_copy
287
288
class EnvironmentConfigError(Exception):
    """Raised when mutually incompatible environment settings are enabled."""


def check_incompatible_config(env1: bool, env2: bool):
    """Reject enabling two mutually exclusive settings at the same time.

    The error message names ``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT`` and
    ``USE_FUSED_RMS_QUANT``. Presumably ``env1``/``env2`` are those two
    flags; callers are not visible here, so confirm.

    Args:
        env1: First flag.
        env2: Second flag.

    Raises:
        EnvironmentConfigError: If both flags are truthy.
    """
    # Truthiness instead of ``is True`` (PEP 8): the identity check let
    # truthy non-bool values such as ``1`` silently bypass this guard.
    if env1 and env2:
        raise EnvironmentConfigError(
            "USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and USE_FUSED_RMS_QUANT "
            "must not be enabled simultaneously!\n\n")
294

295
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
296
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
297
    """Arguments for vLLM engine."""
298
299
300
301
302
303
304
    model: str = ModelConfig.model
    served_model_name: Optional[Union[
        str, List[str]]] = ModelConfig.served_model_name
    tokenizer: Optional[str] = ModelConfig.tokenizer
    hf_config_path: Optional[str] = ModelConfig.hf_config_path
    task: TaskOption = ModelConfig.task
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
305
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
306
307
308
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
309
310
    download_dir: Optional[str] = LoadConfig.download_dir
    load_format: str = LoadConfig.load_format
311
312
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
313
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
314
315
    seed: Optional[int] = ModelConfig.seed
    max_model_len: Optional[int] = ModelConfig.max_model_len
316
317
    cuda_graph_sizes: list[int] = get_field(SchedulerConfig,
                                            "cuda_graph_sizes")
318
319
320
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
321
    distributed_executor_backend: Optional[Union[
322
323
        DistributedExecutorBackend,
        Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
324
    # number of P/D disaggregation (or other disaggregation) workers
325
326
327
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
    data_parallel_size: int = ParallelConfig.data_parallel_size
328
    data_parallel_rank: Optional[int] = None
329
330
331
    data_parallel_size_local: Optional[int] = None
    data_parallel_address: Optional[str] = None
    data_parallel_rpc_port: Optional[int] = None
Rui Qiao's avatar
Rui Qiao committed
332
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
333
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
334
335
336
337
338
    enable_eplb: bool = ParallelConfig.enable_eplb
    num_redundant_experts: int = ParallelConfig.num_redundant_experts
    eplb_window_size: int = ParallelConfig.eplb_window_size
    eplb_step_interval: int = ParallelConfig.eplb_step_interval
    eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness
339
340
    max_parallel_loading_workers: Optional[
        int] = ParallelConfig.max_parallel_loading_workers
341
342
343
344
    block_size: Optional[BlockSize] = CacheConfig.block_size
    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
    prefix_caching_hash_algo: PrefixCachingHashAlgo = \
        CacheConfig.prefix_caching_hash_algo
345
346
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
347
    use_v2_block_manager: bool = True
348
349
350
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
351
352
353
354
355
356
357
    max_num_batched_tokens: Optional[
        int] = SchedulerConfig.max_num_batched_tokens
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
    long_prefill_token_threshold: int = \
        SchedulerConfig.long_prefill_token_threshold
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
358
    max_logprobs: int = ModelConfig.max_logprobs
359
    disable_log_stats: bool = False
360
361
362
363
364
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
    rope_theta: Optional[float] = ModelConfig.rope_theta
    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
365
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
366
367
368
369
    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
    enforce_eager: bool = ModelConfig.enforce_eager
    max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
370
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
371
372
373
    # The following three fields are deprecated and will be removed in a future
    # release. Setting them will have no effect. Please remove them from your
    # configurations.
374
    tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
375
376
    tokenizer_pool_type: str = TokenizerPoolConfig.pool_type
    tokenizer_pool_extra_config: dict = \
377
        get_field(TokenizerPoolConfig, "extra_config")
378
    limit_mm_per_prompt: dict[str, int] = \
379
        get_field(MultiModalConfig, "limit_per_prompt")
380
381
382
    media_io_kwargs: dict[str, dict[str,
                                    Any]] = get_field(MultiModalConfig,
                                                      "media_io_kwargs")
383
384
385
386
    mm_processor_kwargs: Optional[Dict[str, Any]] = \
        MultiModalConfig.mm_processor_kwargs
    disable_mm_preprocessor_cache: bool = \
        MultiModalConfig.disable_mm_preprocessor_cache
387
    # LoRA fields
388
    enable_lora: bool = False
389
390
391
392
393
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
zhuwenwen's avatar
zhuwenwen committed
394
    lora_target_modules: Optional[List[str]] = LoRAConfig.lora_target_modules
395
396
397
398
399
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
        LoRAConfig.long_lora_scaling_factors
    # PromptAdapter fields
400
    enable_prompt_adapter: bool = False
401
402
403
    max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
    max_prompt_adapter_token: int = \
        PromptAdapterConfig.max_prompt_adapter_token
404
    device: Device = DeviceConfig.device
405
406
    num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
    multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
407
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
408
409
    num_gpu_blocks_override: Optional[
        int] = CacheConfig.num_gpu_blocks_override
410
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
411
412
    model_loader_extra_config: dict = \
        get_field(LoadConfig, "model_loader_extra_config")
413
414
    ignore_patterns: Optional[Union[str,
                                    List[str]]] = LoadConfig.ignore_patterns
415
    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
416

417
418
419
420
    scheduler_delay_factor: float = SchedulerConfig.delay_factor
    enable_chunked_prefill: Optional[
        bool] = SchedulerConfig.enable_chunked_prefill
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
421

422
423
424
    disable_hybrid_kv_cache_manager: bool = (
        SchedulerConfig.disable_hybrid_kv_cache_manager)

425
426
427
428
429
430
    guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
    guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
    guided_decoding_disable_any_whitespace: bool = \
        DecodingConfig.disable_any_whitespace
    guided_decoding_disable_additional_properties: bool = \
        DecodingConfig.disable_additional_properties
431
432
    logits_processor_pattern: Optional[
        str] = ModelConfig.logits_processor_pattern
433

434
    speculative_config: Optional[Dict[str, Any]] = None
zhuwenwen's avatar
zhuwenwen committed
435
    num_speculative_heads: Optional[int] = None
436

437
    qlora_adapter_name_or_path: Optional[str] = None
438
439
440
441
442
443
    show_hidden_metrics_for_version: Optional[str] = \
        ObservabilityConfig.show_hidden_metrics_for_version
    otlp_traces_endpoint: Optional[str] = \
        ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: Optional[list[DetailedTraceModules]] = \
        ObservabilityConfig.collect_detailed_traces
444
    disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
445
446
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
447

448
449
450
451
    override_neuron_config: dict[str, Any] = \
        get_field(ModelConfig, "override_neuron_config")
    override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
        ModelConfig.override_pooler_config
452
453
    compilation_config: CompilationConfig = \
        get_field(VllmConfig, "compilation_config")
454
455
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
456

457
    kv_transfer_config: Optional[KVTransferConfig] = None
458
    kv_events_config: Optional[KVEventsConfig] = None
459

460
461
462
463
464
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
    override_generation_config: dict[str, Any] = \
        get_field(ModelConfig, "override_generation_config")
    model_impl: str = ModelConfig.model_impl
465
    override_attention_dtype: str = ModelConfig.override_attention_dtype
466

467
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
468

469
470
    additional_config: dict[str, Any] = \
        get_field(VllmConfig, "additional_config")
471
472
473
    enable_reasoning: Optional[bool] = None  # DEPRECATED
    reasoning_parser: str = DecodingConfig.reasoning_backend

474
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
475
    pt_load_map_location: str = LoadConfig.pt_load_map_location
王敏's avatar
王敏 committed
476

477
478
    enable_multimodal_encoder_data_parallel: bool = \
        ParallelConfig.enable_multimodal_encoder_data_parallel
479
480
481
        
    enable_dp_attention: bool = \
        ParallelConfig.enable_dp_attention
王敏's avatar
王敏 committed
482

483
    def __post_init__(self):
        """Normalise loosely typed fields and load general plugins.

        - An ``int`` or ``dict`` ``compilation_config`` is converted into a
          ``CompilationConfig`` via ``CompilationConfig.from_cli``.
        - A ``DeprecationWarning`` is emitted if
          ``qlora_adapter_name_or_path`` is set.
        - ``load_general_plugins()`` is called.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, dict):
            # ``from_cli`` parses CLI text. ``str(dict)`` produces a Python
            # repr (single quotes, ``True``/``None``) that is not valid JSON,
            # so serialise with ``json.dumps`` instead.
            self.compilation_config = CompilationConfig.from_cli(
                json.dumps(self.compilation_config))
        elif isinstance(self.compilation_config, int):
            self.compilation_config = CompilationConfig.from_cli(
                str(self.compilation_config))

        if self.qlora_adapter_name_or_path is not None:
            warnings.warn(
                "The `qlora_adapter_name_or_path` is deprecated "
                "and will be removed in v0.10.0. ",
                DeprecationWarning,
                # 1 = this method, 2 = the dataclass-generated __init__,
                # 3 = the user code that constructed EngineArgs.
                stacklevel=3,
            )
        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
500
501

    @staticmethod
502
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
503
        """Shared CLI arguments for vLLM engine."""
504

505
        # Model arguments
506
507
508
509
510
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
Reid's avatar
Reid committed
511
        if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]):
512
            model_group.add_argument("--model", **model_kwargs["model"])
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
        model_group.add_argument("--task", **model_kwargs["task"])
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode",
                                 **model_kwargs["tokenizer_mode"])
        model_group.add_argument("--trust-remote-code",
                                 **model_kwargs["trust_remote_code"])
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
        model_group.add_argument("--hf-config-path",
                                 **model_kwargs["hf_config_path"])
        model_group.add_argument("--allowed-local-media-path",
                                 **model_kwargs["allowed_local_media_path"])
        model_group.add_argument("--revision", **model_kwargs["revision"])
        model_group.add_argument("--code-revision",
                                 **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling",
                                 **model_kwargs["rope_scaling"])
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
        model_group.add_argument("--tokenizer-revision",
                                 **model_kwargs["tokenizer_revision"])
        model_group.add_argument("--max-model-len",
                                 **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q",
                                 **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager",
                                 **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-seq-len-to-capture",
                                 **model_kwargs["max_seq_len_to_capture"])
        model_group.add_argument("--max-logprobs",
                                 **model_kwargs["max_logprobs"])
        model_group.add_argument("--disable-sliding-window",
                                 **model_kwargs["disable_sliding_window"])
        model_group.add_argument("--disable-cascade-attn",
                                 **model_kwargs["disable_cascade_attn"])
        model_group.add_argument("--skip-tokenizer-init",
                                 **model_kwargs["skip_tokenizer_init"])
549
550
        model_group.add_argument("--enable-prompt-embeds",
                                 **model_kwargs["enable_prompt_embeds"])
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
        model_group.add_argument("--served-model-name",
                                 **model_kwargs["served_model_name"])
        # This one is a special case because it is the
        # opposite of ModelConfig.use_async_output_proc
        model_group.add_argument(
            "--disable-async-output-proc",
            action="store_true",
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
        model_group.add_argument("--config-format",
                                 choices=[f.value for f in ConfigFormat],
                                 **model_kwargs["config_format"])
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
        model_group.add_argument("--hf-token",
                                 type=str,
                                 nargs="?",
                                 const=True,
                                 default=model_kwargs["hf_token"]["default"],
                                 help=model_kwargs["hf_token"]["help"])
        model_group.add_argument("--hf-overrides",
                                 **model_kwargs["hf_overrides"])
        model_group.add_argument("--override-neuron-config",
                                 **model_kwargs["override_neuron_config"])
        model_group.add_argument("--override-pooler-config",
                                 **model_kwargs["override_pooler_config"])
        model_group.add_argument("--logits-processor-pattern",
                                 **model_kwargs["logits_processor_pattern"])
        model_group.add_argument("--generation-config",
                                 **model_kwargs["generation_config"])
        model_group.add_argument("--override-generation-config",
                                 **model_kwargs["override_generation_config"])
        model_group.add_argument("--enable-sleep-mode",
                                 **model_kwargs["enable_sleep_mode"])
        model_group.add_argument("--model-impl",
                                 choices=[f.value for f in ModelImpl],
                                 **model_kwargs["model_impl"])
589
590
        model_group.add_argument("--override-attention-dtype",
                                 **model_kwargs["override_attention_dtype"])
591

592
593
594
595
596
597
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
598
        load_group.add_argument("--load-format",
599
600
                                choices=[f.value for f in LoadFormat],
                                **load_kwargs["load_format"])
601
        load_group.add_argument("--download-dir",
602
                                **load_kwargs["download_dir"])
603
        load_group.add_argument("--model-loader-extra-config",
604
                                **load_kwargs["model_loader_extra_config"])
605
606
607
        load_group.add_argument("--ignore-patterns",
                                **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load",
608
                                **load_kwargs["use_tqdm_on_load"])
609
610
        load_group.add_argument(
            "--qlora-adapter-name-or-path",
611
            type=str,
612
613
614
615
616
            default=None,
            help="The `--qlora-adapter-name-or-path` has no effect, do not set"
            " it, and it  will be removed in v0.10.0.",
            deprecated=True,
        )
617
618
        load_group.add_argument('--pt-load-map-location',
                                **load_kwargs["pt_load_map_location"])
619
620
621
622
623
624
625

        # Guided decoding arguments
        guided_decoding_kwargs = get_kwargs(DecodingConfig)
        guided_decoding_group = parser.add_argument_group(
            title="DecodingConfig",
            description=DecodingConfig.__doc__,
        )
626
627
        guided_decoding_group.add_argument("--guided-decoding-backend",
                                           **guided_decoding_kwargs["backend"])
628
        guided_decoding_group.add_argument(
629
630
            "--guided-decoding-disable-fallback",
            **guided_decoding_kwargs["disable_fallback"])
631
        guided_decoding_group.add_argument(
632
633
634
635
636
            "--guided-decoding-disable-any-whitespace",
            **guided_decoding_kwargs["disable_any_whitespace"])
        guided_decoding_group.add_argument(
            "--guided-decoding-disable-additional-properties",
            **guided_decoding_kwargs["disable_additional_properties"])
637
638
639
        guided_decoding_group.add_argument(
            "--enable-reasoning",
            action=argparse.BooleanOptionalAction,
640
            deprecated=True,
641
            help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as "
642
            "of v0.9.0. Use `--reasoning-parser` to specify the reasoning "
643
            "parser backend instead. This flag (`--enable-reasoning`) will be "
644
645
            "removed in v0.10.0. When `--reasoning-parser` is specified, "
            "reasoning mode is automatically enabled.")
646
647
648
649
650
651
        guided_decoding_group.add_argument(
            "--reasoning-parser",
            # This choices is a special case because it's not static
            choices=list(ReasoningParserManager.reasoning_parsers),
            **guided_decoding_kwargs["reasoning_backend"])

652
        # Parallel arguments
653
654
655
656
657
658
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
659
            "--distributed-executor-backend",
660
661
            **parallel_kwargs["distributed_executor_backend"])
        parallel_group.add_argument(
662
            "--pipeline-parallel-size", "-pp",
663
            **parallel_kwargs["pipeline_parallel_size"])
664
        parallel_group.add_argument("--tensor-parallel-size", "-tp",
665
                                    **parallel_kwargs["tensor_parallel_size"])
666
        parallel_group.add_argument("--data-parallel-size", "-dp",
667
                                    **parallel_kwargs["data_parallel_size"])
668
669
670
671
672
673
        parallel_group.add_argument(
            '--data-parallel-rank',
            '-dpn',
            type=int,
            help='Data parallel rank of this instance. '
            'When set, enables external load balancer mode.')
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
        parallel_group.add_argument('--data-parallel-size-local',
                                    '-dpl',
                                    type=int,
                                    help='Number of data parallel replicas '
                                    'to run on this node.')
        parallel_group.add_argument('--data-parallel-address',
                                    '-dpa',
                                    type=str,
                                    help='Address of data parallel cluster '
                                    'head-node.')
        parallel_group.add_argument('--data-parallel-rpc-port',
                                    '-dpp',
                                    type=int,
                                    help='Port for data parallel RPC '
                                    'communication.')
Rui Qiao's avatar
Rui Qiao committed
689
690
691
692
693
694
        parallel_group.add_argument('--data-parallel-backend',
                                    '-dpb',
                                    type=str,
                                    default='mp',
                                    help='Backend for data parallel, either '
                                    '"mp" or "ray".')
695
        parallel_group.add_argument(
696
            "--enable-expert-parallel",
697
            **parallel_kwargs["enable_expert_parallel"])
698
699
700
701
702
703
704
705
706
707
        parallel_group.add_argument("--enable-eplb",
                                    **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--num-redundant-experts",
                                    **parallel_kwargs["num_redundant_experts"])
        parallel_group.add_argument("--eplb-window-size",
                                    **parallel_kwargs["eplb_window_size"])
        parallel_group.add_argument("--eplb-step-interval",
                                    **parallel_kwargs["eplb_step_interval"])
        parallel_group.add_argument("--eplb-log-balancedness",
                                    **parallel_kwargs["eplb_log_balancedness"])
708
        parallel_group.add_argument(
709
            "--max-parallel-loading-workers",
710
711
            **parallel_kwargs["max_parallel_loading_workers"])
        parallel_group.add_argument(
712
            "--ray-workers-use-nsight",
713
714
            **parallel_kwargs["ray_workers_use_nsight"])
        parallel_group.add_argument(
715
            "--disable-custom-all-reduce",
716
            **parallel_kwargs["disable_custom_all_reduce"])
717
718
719
720
        parallel_group.add_argument("--worker-cls",
                                    **parallel_kwargs["worker_cls"])
        parallel_group.add_argument("--worker-extension-cls",
                                    **parallel_kwargs["worker_extension_cls"])
721
722
723
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
            **parallel_kwargs["enable_multimodal_encoder_data_parallel"])
724
725
726
        parallel_group.add_argument(
            "--enable-dp-attention",
            **parallel_kwargs["enable_dp_attention"])
727

728
729
730
731
732
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
733
        )
734
735
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
        cache_group.add_argument("--gpu-memory-utilization",
736
                                 **cache_kwargs["gpu_memory_utilization"])
737
738
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
        cache_group.add_argument("--kv-cache-dtype",
739
                                 **cache_kwargs["cache_dtype"])
740
        cache_group.add_argument("--num-gpu-blocks-override",
741
742
743
744
745
                                 **cache_kwargs["num_gpu_blocks_override"])
        cache_group.add_argument("--enable-prefix-caching",
                                 **cache_kwargs["enable_prefix_caching"])
        cache_group.add_argument("--prefix-caching-hash-algo",
                                 **cache_kwargs["prefix_caching_hash_algo"])
746
        cache_group.add_argument("--cpu-offload-gb",
747
                                 **cache_kwargs["cpu_offload_gb"])
748
        cache_group.add_argument("--calculate-kv-scales",
749
750
                                 **cache_kwargs["calculate_kv_scales"])

751
752
753
754
755
756
        # Tokenizer arguments
        tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
        tokenizer_group = parser.add_argument_group(
            title="TokenizerPoolConfig",
            description=TokenizerPoolConfig.__doc__,
        )
757
        tokenizer_group.add_argument("--tokenizer-pool-size",
758
                                     **tokenizer_kwargs["pool_size"])
759
        tokenizer_group.add_argument("--tokenizer-pool-type",
760
                                     **tokenizer_kwargs["pool_type"])
761
        tokenizer_group.add_argument("--tokenizer-pool-extra-config",
762
                                     **tokenizer_kwargs["extra_config"])
763
764

        # Multimodal related configs
765
766
767
768
769
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
770
        multimodal_group.add_argument("--limit-mm-per-prompt",
771
                                      **multimodal_kwargs["limit_per_prompt"])
772
773
        multimodal_group.add_argument("--media-io-kwargs",
                                      **multimodal_kwargs["media_io_kwargs"])
774
        multimodal_group.add_argument(
775
            "--mm-processor-kwargs",
776
777
            **multimodal_kwargs["mm_processor_kwargs"])
        multimodal_group.add_argument(
778
            "--disable-mm-preprocessor-cache",
779
            **multimodal_kwargs["disable_mm_preprocessor_cache"])
780

781
        # LoRA related configs
782
783
784
785
786
787
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
788
            "--enable-lora",
789
            action=argparse.BooleanOptionalAction,
790
791
            help="If True, enable handling of LoRA adapters.")
        lora_group.add_argument("--enable-lora-bias",
792
                                **lora_kwargs["bias_enabled"])
793
794
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
        lora_group.add_argument("--max-lora-rank",
795
                                **lora_kwargs["max_lora_rank"])
zhuwenwen's avatar
zhuwenwen committed
796
797
        lora_group.add_argument('--lora-target-modules',
                            **lora_kwargs["lora_target_modules"])
798
        lora_group.add_argument("--lora-extra-vocab-size",
799
800
                                **lora_kwargs["lora_extra_vocab_size"])
        lora_group.add_argument(
801
            "--lora-dtype",
802
803
            **lora_kwargs["lora_dtype"],
        )
804
        lora_group.add_argument("--long-lora-scaling-factors",
805
                                **lora_kwargs["long_lora_scaling_factors"])
806
        lora_group.add_argument("--max-cpu-loras",
807
                                **lora_kwargs["max_cpu_loras"])
808
        lora_group.add_argument("--fully-sharded-loras",
809
810
811
812
813
814
815
816
817
                                **lora_kwargs["fully_sharded_loras"])

        # PromptAdapter related configs
        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
        prompt_adapter_group = parser.add_argument_group(
            title="PromptAdapterConfig",
            description=PromptAdapterConfig.__doc__,
        )
        prompt_adapter_group.add_argument(
818
            "--enable-prompt-adapter",
819
            action=argparse.BooleanOptionalAction,
820
            help="If True, enable handling of PromptAdapters.")
821
        prompt_adapter_group.add_argument(
822
            "--max-prompt-adapters",
823
824
            **prompt_adapter_kwargs["max_prompt_adapters"])
        prompt_adapter_group.add_argument(
825
            "--max-prompt-adapter-token",
826
            **prompt_adapter_kwargs["max_prompt_adapter_token"])
827
828
829
830
831
832
833

        # Device arguments
        device_kwargs = get_kwargs(DeviceConfig)
        device_group = parser.add_argument_group(
            title="DeviceConfig",
            description=DeviceConfig.__doc__,
        )
834
835
836
        device_group.add_argument("--device",
                                  **device_kwargs["device"],
                                  deprecated=True)
837

838
839
840
841
842
843
        # Speculative arguments
        speculative_group = parser.add_argument_group(
            title="SpeculativeConfig",
            description=SpeculativeConfig.__doc__,
        )
        speculative_group.add_argument(
844
            "--speculative-config",
845
846
            type=json.loads,
            default=None,
847
848
            help="The configurations for speculative decoding. Should be a "
            "JSON string.")
zhuwenwen's avatar
zhuwenwen committed
849
850
851
852
853
854
        parser.add_argument(
            '--num-speculative-heads',
            type=int,
            default=EngineArgs.num_speculative_heads,
            help='The number of speculative heads to sample from '
                 'the draft model in speculative decoding.')
855

856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
            **observability_kwargs["show_hidden_metrics_for_version"])
        observability_group.add_argument(
            "--otlp-traces-endpoint",
            **observability_kwargs["otlp_traces_endpoint"])
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
            ",".join(p)
            for p in permutations(get_args(DetailedTraceModules), r=2)
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
            **observability_kwargs["collect_detailed_traces"])
879

880
881
882
883
884
885
886
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
887
            "--max-num-batched-tokens",
888
            **scheduler_kwargs["max_num_batched_tokens"])
889
        scheduler_group.add_argument("--max-num-seqs",
890
891
892
893
894
895
896
                                     **scheduler_kwargs["max_num_seqs"])
        scheduler_group.add_argument(
            "--max-num-partial-prefills",
            **scheduler_kwargs["max_num_partial_prefills"])
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
            **scheduler_kwargs["max_long_partial_prefills"])
897
898
        scheduler_group.add_argument('--cuda-graph-sizes',
                                     **scheduler_kwargs["cuda_graph_sizes"])
899
900
901
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
            **scheduler_kwargs["long_prefill_token_threshold"])
902
        scheduler_group.add_argument("--num-lookahead-slots",
903
                                     **scheduler_kwargs["num_lookahead_slots"])
904
        scheduler_group.add_argument("--scheduler-delay-factor",
905
                                     **scheduler_kwargs["delay_factor"])
906
        scheduler_group.add_argument("--preemption-mode",
907
                                     **scheduler_kwargs["preemption_mode"])
908
        scheduler_group.add_argument("--num-scheduler-steps",
909
                                     **scheduler_kwargs["num_scheduler_steps"])
910
        scheduler_group.add_argument(
911
            "--multi-step-stream-outputs",
912
            **scheduler_kwargs["multi_step_stream_outputs"])
913
        scheduler_group.add_argument("--scheduling-policy",
914
                                     **scheduler_kwargs["policy"])
915
        scheduler_group.add_argument(
916
            "--enable-chunked-prefill",
917
            **scheduler_kwargs["enable_chunked_prefill"])
918
919
920
        scheduler_group.add_argument(
            "--disable-chunked-mm-input",
            **scheduler_kwargs["disable_chunked_mm_input"])
921
922
        scheduler_group.add_argument("--scheduler-cls",
                                     **scheduler_kwargs["scheduler_cls"])
923
924
925
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"])
926
927

        # vLLM arguments
928
        vllm_kwargs = get_kwargs(VllmConfig)
929
930
931
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
932
        )
933
934
935
936
937
938
939
940
        vllm_group.add_argument("--kv-transfer-config",
                                **vllm_kwargs["kv_transfer_config"])
        vllm_group.add_argument('--kv-events-config',
                                **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument("--compilation-config", "-O",
                                **vllm_kwargs["compilation_config"])
        vllm_group.add_argument("--additional-config",
                                **vllm_kwargs["additional_config"])
941

942
943
944
945
        # Other arguments
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            default=True,
946
                            deprecated=True,
947
948
949
950
951
952
953
954
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')
955

956
        return parser
957
958

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from an argparse namespace.

        Each dataclass field declared on ``cls`` is read from ``args`` by
        attribute name and forwarded to the constructor as a keyword
        argument. Extra attributes on ``args`` are ignored. A field with no
        matching attribute raises ``AttributeError`` (from ``getattr``)
        before the instance is created.
        """
        field_names = (field.name for field in dataclasses.fields(cls))
        init_kwargs = {name: getattr(args, name) for name in field_names}
        return cls(**init_kwargs)
965

966
    def create_model_config(self) -> ModelConfig:
        """Build the ModelConfig from the current engine arguments.

        Side effects on ``self`` (applied before the config is built):

        * If ``self.model`` is a GGUF file, ``self.quantization`` and
          ``self.load_format`` are both set to ``"gguf"``.
        * In CI with ``VLLM_CI_USE_S3`` enabled, for a non-async engine whose
          model is listed in ``MODELS_ON_S3`` and whose load format is still
          ``LoadFormat.AUTO``, ``self.model`` is rewritten to point into
          ``MODEL_WEIGHTS_S3_BUCKET`` and ``self.load_format`` becomes
          ``LoadFormat.RUNAI_STREAMER``.
        """
        # A GGUF file needs its dedicated loader and does not go through an
        # HF repo, so both the quantization and load format are pinned.
        if check_gguf_file(self.model):
            self.quantization = "gguf"
            self.load_format = "gguf"

        # NOTE: CI-only path that lets models be loaded from S3. This check
        # runs after the GGUF override, so a GGUF model (load_format "gguf")
        # never takes it.
        load_from_ci_s3 = (not isinstance(self, AsyncEngineArgs)
                           and envs.VLLM_CI_USE_S3
                           and self.model in MODELS_ON_S3
                           and self.load_format == LoadFormat.AUTO)
        if load_from_ci_s3:
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
            self.load_format = LoadFormat.RUNAI_STREAMER

        return ModelConfig(
            # Model / tokenizer identity.
            model=self.model,
            hf_config_path=self.hf_config_path,
            task=self.task,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            # Numerics, versions and HF overrides.
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            # Execution limits and toggles.
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            # Multimodal inputs.
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            media_io_kwargs=self.media_io_kwargs,
            # The CLI exposes the negated flag; ModelConfig takes the positive.
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            # Overrides and generation defaults.
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            # NOTE(review): create_engine_config calls this method before the
            # V0/V1 defaults are applied, so this value may still be None
            # here. Confirm ModelConfig tolerates None.
            enable_chunked_prefill=self.enable_chunked_prefill,
        )
1021

1022
1023
    def create_load_config(self) -> LoadConfig:
        """Build the LoadConfig that controls how model weights are loaded.

        Side effect: when ``self.quantization`` is ``"bitsandbytes"``,
        ``self.load_format`` is overwritten with ``"bitsandbytes"``,
        replacing whatever load format was requested.
        """
        # bitsandbytes quantization is paired with the bitsandbytes load
        # format; the requested format is overridden unconditionally.
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        load_options = dict(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )
        return LoadConfig(**load_options)

1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        `speculative_config` is either parsed from the `--speculative-config`
        JSON CLI argument or supplied directly as a dictionary when the engine
        is constructed programmatically.

        Args:
            target_model_config: ModelConfig of the target model, forwarded
                to SpeculativeConfig as `target_model_config`.
            target_parallel_config: ParallelConfig of the target model,
                forwarded as `target_parallel_config`.
            enable_chunked_prefill: Forwarded as `enable_chunked_prefill`.
            disable_log_stats: Forwarded as `disable_log_stats`.

        Returns:
            The constructed SpeculativeConfig, or None when
            `self.speculative_config` is None.

        The mapping stored in `self.speculative_config` is left unmodified.
        The engine-derived entries are merged into a fresh dict.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # They are merged into a copy instead of `.update()`-ing
        # `self.speculative_config` in place. That dict is owned by the caller
        # and may be shared or serialized later, so it must not end up holding
        # live ModelConfig/ParallelConfig objects. Explicit engine values
        # still take precedence over same-named user keys, as before.
        speculative_kwargs = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        }
        return SpeculativeConfig.from_dict(speculative_kwargs)

1068
1069
1070
1071
1072
1073
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
    ) -> VllmConfig:
        """
        Create the VllmConfig.
1074

1075
1076
1077
        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
1078

1079
1080
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.
1081

1082
1083
1084
        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1085
1086
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()
1087

1088
1089
        device_config = DeviceConfig(
            device=cast(Device, current_platform.device_type))
1090
1091
        model_config = self.create_model_config()

1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
1111
            self._set_default_args_v1(usage_context, model_config)
1112
1113
        else:
            self._set_default_args_v0(model_config)
1114

1115
        assert self.enable_chunked_prefill is not None
1116

1117
1118
1119
1120
1121
        if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
            assert self.enforce_eager, (
                "Cuda graph is not supported with DualChunkFlashAttention. "
                "To run the model in eager mode, set 'enforce_eager=True' "
                "or use '--enforce-eager' in the CLI.")
1122
1123
            assert current_platform.is_cuda() or current_platform.is_rocm(), (
                "DualChunkFlashAttention is supported on CUDA/ROCM platform.")
1124
1125
1126
1127
            assert not use_v1, (
                "DualChunkFlashAttention is not supported on V1 engine. "
                "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")

1128
        cache_config = CacheConfig(
1129
            block_size=self.block_size,
1130
1131
1132
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1133
            is_attention_free=model_config.is_attention_free,
1134
1135
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1136
            enable_prefix_caching=self.enable_prefix_caching,
1137
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1138
            cpu_offload_gb=self.cpu_offload_gb,
1139
            calculate_kv_scales=self.calculate_kv_scales,
1140
        )
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
        data_parallel_external_lb = self.data_parallel_rank is not None
        if data_parallel_external_lb:
            assert self.data_parallel_size_local in (1, None), (
                "data_parallel_size_local must be 1 when data_parallel_rank "
                "is set")
            data_parallel_size_local = 1
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
        else:
            # Local DP size defaults to global DP size if not set.
            data_parallel_size_local = self.data_parallel_size
1164
1165
1166

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
                    "Using host IP %s as ray-based data parallel address",
                    host_ip)
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
                    self.data_parallel_backend)
                data_parallel_address = ParallelConfig.data_parallel_master_ip
        else:
            data_parallel_address = self.data_parallel_address
1181
1182
1183
1184
1185
1186
1187

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
        data_parallel_rpc_port = self.data_parallel_rpc_port if (
            self.data_parallel_rpc_port
            is not None) else ParallelConfig.data_parallel_rpc_port

1188
        parallel_config = ParallelConfig(
1189
1190
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1191
            data_parallel_size=self.data_parallel_size,
1192
1193
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1194
1195
1196
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1197
            data_parallel_backend=self.data_parallel_backend,
1198
            enable_expert_parallel=self.enable_expert_parallel,
1199
1200
1201
1202
1203
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.num_redundant_experts,
            eplb_window_size=self.eplb_window_size,
            eplb_step_interval=self.eplb_step_interval,
            eplb_log_balancedness=self.eplb_log_balancedness,
1204
1205
1206
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1207
            placement_group=placement_group,
1208
1209
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1210
            worker_extension_cls=self.worker_extension_cls,
1211
1212
            enable_multimodal_encoder_data_parallel=self.
            enable_multimodal_encoder_data_parallel,
1213
            enable_dp_attention=self.enable_dp_attention,
1214
        )
1215

1216
        speculative_config = self.create_speculative_config(
1217
1218
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1219
            enable_chunked_prefill=self.enable_chunked_prefill,
王敏's avatar
王敏 committed
1220
            disable_log_stats=self.disable_log_stats,
1221
1222
        )

1223
        # Reminder: Please update docs/features/compatibility_matrix.md
1224
        # If the feature combo become valid
1225
1226
1227
1228
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1229
1230
1231
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1232
1233
1234
1235
1236
1237
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1238
1239
1240
1241
1242
1243
1244
1245

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots
1246
1247
        check_incompatible_config(envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT, envs.USE_FUSED_RMS_QUANT)
        
1248
        scheduler_config = SchedulerConfig(
1249
            runner_type=model_config.runner_type,
1250
1251
1252
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1253
            cuda_graph_sizes=self.cuda_graph_sizes,
1254
            num_lookahead_slots=num_lookahead_slots,
1255
1256
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1257
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1258
            is_multimodal_model=model_config.is_multimodal_model,
1259
            preemption_mode=self.preemption_mode,
1260
            num_scheduler_steps=self.num_scheduler_steps,
1261
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1262
1263
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1264
            policy=self.scheduling_policy,
1265
            scheduler_cls=self.scheduler_cls,
1266
1267
1268
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1269
1270
            disable_hybrid_kv_cache_manager=self.
            disable_hybrid_kv_cache_manager,
1271
        )
1272

1273
        lora_config = LoRAConfig(
1274
            bias_enabled=self.enable_lora_bias,
1275
1276
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1277
            fully_sharded_loras=self.fully_sharded_loras,
1278
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1279
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1280
1281
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
1282
1283
            and self.max_cpu_loras > 0 else None,
            lora_target_modules=self.lora_target_modules) if self.enable_lora else None
1284

1285
1286
1287
1288
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1289
        load_config = self.create_load_config()
1290

1291
1292
1293
1294
1295
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1296
        decoding_config = DecodingConfig(
1297
1298
1299
1300
1301
            backend=self.guided_decoding_backend,
            disable_fallback=self.guided_decoding_disable_fallback,
            disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
            disable_additional_properties=\
                self.guided_decoding_disable_additional_properties,
1302
1303
            reasoning_backend=self.reasoning_parser
        )
1304

1305
        observability_config = ObservabilityConfig(
1306
1307
            show_hidden_metrics_for_version=self.
            show_hidden_metrics_for_version,
1308
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1309
            collect_detailed_traces=self.collect_detailed_traces,
1310
        )
1311

1312
        config = VllmConfig(
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1323
            prompt_adapter_config=prompt_adapter_config,
1324
            compilation_config=self.compilation_config,
1325
            kv_transfer_config=self.kv_transfer_config,
1326
            kv_events_config=self.kv_events_config,
1327
            additional_config=self.additional_config,
1328
        )
1329

1330
1331
        return config

    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        Checks, one by one, the engine arguments and platform properties
        that the V1 engine does not support. For each unsupported feature
        it calls ``_raise_or_fallback`` and returns ``False``. Experimental
        features go through ``_warn_or_fallback`` and only force a fallback
        when that helper says so.

        Args:
            model_config: The resolved model configuration.

        Returns:
            True if the V1 engine can be used, False to fall back to V0.

        Raises:
            NotImplementedError: propagated from ``_raise_or_fallback`` when
                V1 is enabled but an unsupported feature is configured.
        """

        #############################################################
        # Unsupported Feature Flags on V1.

        if self.load_format == LoadFormat.SHARDED_STATE.value:
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

        if self.preemption_mode != SchedulerConfig.preemption_mode:
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

        if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
            _raise_or_fallback(feature_name="--num-scheduler-steps",
                               recommend_to_remove=True)
            return False

        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

        if self.guided_decoding_backend not in get_args(
                GuidedDecodingBackendV1):
            _raise_or_fallback(
                feature_name=
                f"--guided-decoding-backend={self.guided_decoding_backend}",
                recommend_to_remove=False)
            return False

        # Need at least Ampere for now (FA support required).
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
        from vllm.platforms import current_platform
        if current_platform.is_cuda():
            # Query the capability once; a falsy result skips the check.
            device_capability = current_platform.get_device_capability()
            if device_capability and device_capability.major < 8:
                _raise_or_fallback(feature_name="Compute Capability < 8.0",
                                   recommend_to_remove=False)
                return False

        # A non-default KV cache dtype is only accepted on ROCm, for FP8
        # with a FlashAttention build that supports it, or for int8.
        if self.kv_cache_dtype != "auto":
            fp8_attention = self.kv_cache_dtype.startswith("fp8")
            will_use_fa = (
                current_platform.is_cuda()
                and not envs.is_set("VLLM_ATTENTION_BACKEND")
            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
            supported = False
            if current_platform.is_rocm():
                supported = True
            elif fp8_attention and will_use_fa:
                from vllm.attention.utils.fa_utils import (
                    flash_attn_supports_fp8)
                supported = flash_attn_supports_fp8()

            # int8 KV cache dtypes are accepted on every platform.
            if self.kv_cache_dtype.startswith("int8"):
                supported = True

            if not supported:
                _raise_or_fallback(feature_name="--kv-cache-dtype",
                                   recommend_to_remove=False)
                return False

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

        # No text embedding inputs so far.
        if self.enable_prompt_embeds:
            _raise_or_fallback(feature_name="--enable-prompt-embeds",
                               recommend_to_remove=False)
            return False

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            _raise_or_fallback(feature_name=model_config.architectures,
                               recommend_to_remove=False)
            return False

        # V1 mamba models are unoptimized.
        if model_config.has_inner_state and _warn_or_fallback(
                feature_name="Mamba"):
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
                != SchedulerConfig.max_num_partial_prefills
                or self.max_long_partial_prefills
                != SchedulerConfig.max_long_partial_prefills):
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # No OTLP observability so far.
        if (self.otlp_traces_endpoint or self.collect_detailed_traces):
            _raise_or_fallback(feature_name="--otlp-traces-endpoint",
                               recommend_to_remove=False)
            return False

        # V1 supports N-gram, Medusa, and Eagle speculative decoding.
        is_ngram_enabled = False
        is_eagle_enabled = False
        is_medusa_enabled = False
        if self.speculative_config is not None:
            # This is supported but experimental (handled below).
            speculative_method = self.speculative_config.get("method")
            if speculative_method:
                if speculative_method in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
                elif speculative_method == "medusa":
                    is_medusa_enabled = True
                elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
                    is_eagle_enabled = True
            else:
                # No explicit method: fall back to the "model" entry.
                speculative_model = self.speculative_config.get("model")
                if speculative_model in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
            if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
                # Other speculative decoding methods are not supported yet.
                _raise_or_fallback(feature_name="Speculative Decoding",
                                   recommend_to_remove=False)
                return False

        # No XFormers so far.
        V1_BACKENDS = [
            "FLASH_ATTN_VLLM_V1",
            "FLASH_ATTN",
            "PALLAS",
            "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1",
            "TRITON_MLA",
            "CUTLASS_MLA_VLLM_V1",
            "FLASHMLA",
            "FLASHINFER",
            "FLASHINFER_VLLM_V1",
            "ROCM_AITER_MLA",
            "TORCH_SDPA_VLLM_V1",
            "FLEX_ATTENTION",
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        # Platforms must decide if they can support v1 for this model
        if not current_platform.supports_v1(model_config=model_config):
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

        # Signal Handlers requires running in main thread.
        if (threading.current_thread() != threading.main_thread()
                and _warn_or_fallback("Engine in background thread")):
            return False

        if (self.pipeline_parallel_size > 1
                and self.distributed_executor_backend
                not in (ParallelConfig.distributed_executor_backend, "ray",
                        "mp", "external_launcher")):
            name = "Pipeline Parallelism without Ray distributed executor " \
                    "or multiprocessing executor or external launcher"
            _raise_or_fallback(feature_name=name, recommend_to_remove=False)
            return False

        # The platform may be supported on V1, but off by default for now.
        if not current_platform.default_v1(  # noqa: SIM103
                model_config=model_config) and _warn_or_fallback(
                    current_platform.device_name):
            return False

        if (current_platform.is_cpu()
                and model_config.get_sliding_window() is not None):
            _raise_or_fallback(feature_name="sliding window (CPU backend)",
                               recommend_to_remove=False)
            return False

        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        Mutates ``self`` in place:

        - Resolves ``enable_chunked_prefill`` when the user left it unset.
        - Warns about long contexts without chunked prefill.
        - Disables prefix caching for multimodal models.
        - Defaults ``max_num_seqs`` to 256.

        Args:
            model_config: The resolved model configuration.

        Raises:
            ValueError: if chunked prefill is enabled for a pooling model,
                or if prefix caching was requested together with the
                ``sha256`` hash algorithm (not supported by V0).
        """

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                from vllm.platforms import current_platform
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_config is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # Nothing above enabled it: default to off.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            # Note the trailing space after "cause": the adjacent literals
            # are concatenated, and previously rendered as "causeOOM".
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # if using prefix caching, we must set a hash algo
        if self.enable_prefix_caching:
            # Disable prefix caching for multimodal models for VLLM_V0.
            if model_config.is_multimodal_model:
                logger.warning(
                    "--enable-prefix-caching is not supported for multimodal "
                    "models in V0 and has been disabled.")
                self.enable_prefix_caching = False

            # VLLM_V0 only supports builtin hash algo for prefix caching.
            if self.prefix_caching_hash_algo == "sha256":
                raise ValueError(
                    "sha256 is not supported for prefix caching in V0 engine. "
                    "Please use 'builtin'.")

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

    def _set_default_args_v1(self, usage_context: UsageContext,
                             model_config: ModelConfig) -> None:
        """Set Default Arguments for V1 Engine.

        Mutates ``self`` in place:

        - Resolves the chunked-prefill and prefix-caching defaults.
        - Switches the V0 default scheduler class to the V1 scheduler.
        - Fills in ``max_num_batched_tokens`` and ``max_num_seqs`` when they
          are unset. Values come from per-usage-context tables selected by
          device memory, device name, and TPU/CPU platform.

        Args:
            usage_context: How the engine is being used; selects the row of
                the default tables. May be None, in which case no table
                default is applied.
            model_config: The resolved model configuration.
        """

        # Non-pooling tasks always use chunked prefill (unless the model
        # config explicitly disables it) and default to prefix caching.
        # Pooling tasks enable both by default only when the pooling type
        # is "last".
        if model_config.runner_type != "pooling":
            self.enable_chunked_prefill = True
            # The model config may explicitly veto chunked prefill.
            if model_config.enable_chunked_prefill is False:
                self.enable_chunked_prefill = False

            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = True
        else:

            pooling_type = model_config.pooler_config.pooling_type

            # TODO: when encoder models are supported we'll have to
            # check for causal attention here.
            incremental_prefill_supported = (pooling_type is not None and
                                             pooling_type.lower() == "last")

            action = "Enabling" if \
                incremental_prefill_supported else "Disabling"

            # The model config may explicitly veto chunked prefill.
            if model_config.enable_chunked_prefill is False:
                self.enable_chunked_prefill = False

            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = incremental_prefill_supported
                logger.info("(%s) chunked prefill by default", action)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = incremental_prefill_supported
                logger.info("(%s) prefix caching by default", action)

        # Without chunked prefill a whole prompt must fit in one batch.
        if not self.enable_chunked_prefill:
            self.max_num_batched_tokens = model_config.max_model_len

        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # When no user override, set the default values based on the usage
        # context.
        # Use different default values for different hardware.

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        from vllm.platforms import current_platform
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # This is only used to set default_max_num_batched_tokens.
            # Reset both values: if only get_device_name() failed,
            # device_name would otherwise be unbound below.
            device_memory = 0
            device_name = ""

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
            # For GPUs like H100 and MI300x, use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 1024,
                UsageContext.OPENAI_API_SERVER: 1024,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 10240,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256,
                UsageContext.OPENAI_API_SERVER: 256,
            }

        # tpu specific default values.
        if current_platform.is_tpu():
            default_max_num_batched_tokens_tpu = {
                UsageContext.LLM_CLASS: {
                    'V6E': 2048,
                    'V5E': 1024,
                    'V5P': 512,
                },
                UsageContext.OPENAI_API_SERVER: {
                    'V6E': 1024,
                    'V5E': 512,
                    'V5P': 256,
                }
            }

        # cpu specific default values.
        if current_platform.is_cpu():
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 4096,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 128,
                UsageContext.OPENAI_API_SERVER: 32,
            }

        use_context_value = usage_context.value if usage_context else None
        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
            if current_platform.is_tpu():
                # Prefer a per-chip TPU value; fall back to the generic one.
                chip_name = current_platform.get_device_name()
                if chip_name in default_max_num_batched_tokens_tpu[
                        usage_context]:
                    self.max_num_batched_tokens = \
                        default_max_num_batched_tokens_tpu[
                            usage_context][chip_name]
                else:
                    self.max_num_batched_tokens = \
                        default_max_num_batched_tokens[usage_context]
            else:
                self.max_num_batched_tokens = default_max_num_batched_tokens[
                    usage_context]
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, use_context_value)

        if (self.max_num_seqs is None
                and usage_context in default_max_num_seqs):
            self.max_num_seqs = default_max_num_seqs[usage_context]

            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, use_context_value)


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # When True, the async engine does not log incoming requests.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI arguments on *parser*.

        Args:
            parser: The argument parser to extend.
            async_args_only: If True, only the async-specific arguments are
                added; otherwise all ``EngineArgs`` arguments are added first.

        Returns:
            The extended parser. When ``async_args_only`` is False, this is
            the parser returned by ``EngineArgs.add_cli_args``.
        """
        # Initialize plugins so they can update the parser. For example, a
        # plugin may add a new quantization method to the --quantization
        # argument or a new device to the --device argument.
        load_general_plugins()

        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        # Let the active platform adjust the parser after registration.
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update(parser)
        return parser


def _raise_or_fallback(feature_name: str, recommend_to_remove: bool) -> None:
    """Reject an unsupported V1 feature, or warn that V0 will be used.

    Args:
        feature_name: Human-readable name of the unsupported feature/flag.
        recommend_to_remove: If True, the warning also suggests removing
            the feature from the config in favor of the V1 engine.

    Raises:
        NotImplementedError: if ``envs.VLLM_USE_V1`` is truthy.
    """
    # NOTE: only the value of VLLM_USE_V1 is checked here. A previous
    # version also required the variable to be explicitly set
    # (`envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1`).
    if envs.VLLM_USE_V1:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")
    msg = f"{feature_name} is not supported by the V1 Engine. "
    msg += "Falling back to V0. "
    if recommend_to_remove:
        msg += f"We recommend to remove {feature_name} from your config "
        msg += "in favor of the V1 Engine."
    logger.warning(msg)


def _warn_or_fallback(feature_name: str) -> bool:
    """Decide whether an experimental V1 feature forces a fallback to V0.

    Args:
        feature_name: Human-readable name of the experimental feature.

    Returns:
        False (after logging a warning) when ``envs.VLLM_USE_V1`` is truthy,
        meaning "keep using V1". True (after logging at info level)
        otherwise, meaning "fall back to V0".
    """
    # NOTE: only the value of VLLM_USE_V1 is checked here. A previous
    # version also required the variable to be explicitly set
    # (`envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1`).
    if envs.VLLM_USE_V1:
        logger.warning(
            "Detected VLLM_USE_V1=1 with %s. Usage should "
            "be considered experimental. Please report any "
            "issues on Github.", feature_name)
        should_exit = False
    else:
        logger.info(
            "%s is experimental on VLLM_USE_V1=1. "
            "Falling back to V0 Engine.", feature_name)
        should_exit = True
    return should_exit


def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.

    Suffix rules:

    - Lowercase suffixes are decimal multipliers (k=10**3, m=10**6,
      g=10**9, t=10**12). The number may have a fractional part. The
      product is computed exactly and truncated toward zero.
    - Uppercase suffixes are binary multipliers (K=2**10, M=2**20,
      G=2**30, T=2**40). The number must be an integer.
    - A value without a suffix is parsed with ``int``.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '1T' -> 1,099,511,627,776

    Raises:
        argparse.ArgumentTypeError: if a fractional number is combined with
            a binary suffix (e.g. '1.5K').
        ValueError: if *value* is neither a suffixed number nor an integer.
    """
    from decimal import Decimal

    value = value.strip()
    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
    if match:
        # Every suffix accepted by the regex must appear in one of these
        # tables; otherwise it would fall through to int() and fail.
        decimal_multiplier = {
            'k': 10**3,
            'm': 10**6,
            'g': 10**9,
            't': 10**12,
        }
        binary_multiplier = {
            'K': 2**10,
            'M': 2**20,
            'G': 2**30,
            'T': 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            # Exact decimal arithmetic: with binary floats, inputs such as
            # '1.005k' evaluate to 1004.999... and truncate to 1004.
            return int(Decimal(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?") from e

    # Regular plain number.
    return int(value)


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with all ``EngineArgs`` CLI args."""
    return EngineArgs.add_cli_args(FlexibleArgumentParser())


def _async_engine_args_parser():
    """Return a fresh parser with only the async-engine-specific CLI args."""
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                        async_args_only=True)