arg_utils.py 82.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# yapf: disable
5
import argparse
6
import copy
7
import dataclasses
8
import functools
9
import json
10
import sys
11
import threading
12
import warnings
13
from dataclasses import MISSING, dataclass, fields, is_dataclass
14
from itertools import permutations
15
16
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
                    Type, TypeVar, Union, cast, get_args, get_origin)
17

18
import regex as re
19
import torch
20
from pydantic import TypeAdapter, ValidationError
21
from typing_extensions import TypeIs, deprecated
22

23
import vllm.envs as envs
24
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
25
26
27
28
29
30
                         ConfigFormat, ConfigType, DecodingConfig,
                         DetailedTraceModules, Device, DeviceConfig,
                         DistributedExecutorBackend, GuidedDecodingBackend,
                         GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
31
32
33
34
35
                         ObservabilityConfig, ParallelConfig, PoolerConfig,
                         PrefixCachingHashAlgo, PromptAdapterConfig,
                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
                         TaskOption, TokenizerMode, TokenizerPoolConfig,
                         VllmConfig, get_attr_docs, get_field)
36
from vllm.executor.executor_base import ExecutorBase
37
from vllm.logger import init_logger
38
from vllm.model_executor.layers.quantization import QuantizationMethods
39
from vllm.platforms import CpuArchEnum, current_platform
40
from vllm.plugins import load_general_plugins
41
from vllm.reasoning import ReasoningParserManager
42
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
43
from vllm.transformers_utils.utils import check_gguf_file
44
from vllm.usage.usage_lib import UsageContext
45
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
Rui Qiao's avatar
Rui Qiao committed
46
                        GiB_bytes, get_ip, is_in_ray_actor)
47
48

# yapf: enable
49

50
51
logger = init_logger(__name__)

52
53
54
55
56
# object is used to allow for special typing forms
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]

57

58
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
59

60
    def _parse_type(val: str) -> T:
61
        try:
62
63
            if return_type is json.loads and not re.match(
                    r"(?s)^\s*{.*}\s*$", val):
64
65
66
67
68
                return cast(T, nullable_kvs(val))
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}.") from e
69

70
71
72
73
74
75
76
77
78
79
80
    return _parse_type


def optional_type(
        return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:

    def _optional_type(val: str) -> Optional[T]:
        if val == "" or val == "None":
            return None
        return parse_type(return_type)(val)

81
    return _optional_type
82
83


84
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
85
    if not re.match(r"(?s)^\s*{.*}\s*$", val):
86
        return str(val)
87
    return optional_type(json.loads)(val)
88
89


90
91
92
93
94
95
@deprecated(
    "Passing a JSON argument as a string containing comma separated key=value "
    "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
    "string instead.")
def nullable_kvs(val: str) -> dict[str, int]:
    """Parses a string containing comma separate key [str] to value [int]
96
97
98
99
100
101
102
103
    pairs into a dictionary.

    Args:
        val: String value to be parsed.

    Returns:
        Dictionary with parsed values.
    """
104
    out_dict: dict[str, int] = {}
105
    for item in val.split(","):
106
107
108
109
110
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts
111
112

        try:
113
            parsed_value = int(value)
114
115
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
116
117
118
119
120
121
            raise argparse.ArgumentTypeError(msg) from exc

        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value
122
123
124
125

    return out_dict


126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Check if the type hint is a specific type."""
    return type_hint is type or get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Check if the type hints contain a specific type."""
    return any(is_type(type_hint, type) for type_hint in type_hints)


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Get the specific type from the type hints."""
    return next((th for th in type_hints if is_type(th, type)), None)


141
142
143
144
145
146
147
148
149
150
151
152
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Convert Literal type hints to argparse kwargs."""
    type_hint = get_type(type_hints, Literal)
    choices = get_args(type_hint)
    choice_type = type(choices[0])
    if not all(isinstance(choice, choice_type) for choice in choices):
        raise ValueError(
            "All choices must be of the same type. "
            f"Got {choices} with types {[type(c) for c in choices]}")
    return {"type": choice_type, "choices": sorted(choices)}


153
154
155
156
157
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Check if the class is not a built-in type."""
    return type_hint.__module__ != "builtins"


158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints."""
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
    elif origin is Union:
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


175
176
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
177
178
179
    cls_docs = get_attr_docs(cls)
    kwargs = {}
    for field in fields(cls):
180
        # Get the set of possible types for the field
181
        type_hints: set[TypeHint] = get_type_hints(field.type)
182
183
184
185
186

        # If the field is a dataclass, we can use the model_validate_json
        generator = (th for th in type_hints if is_dataclass(th))
        dataclass_cls = next(generator, None)

187
        # Get the default value of the field
188
189
190
        if field.default is not MISSING:
            default = field.default
        elif field.default_factory is not MISSING:
191
            default = field.default_factory()
192
193
194

        # Get the help text for the field
        name = field.name
195
        help = cls_docs[name].strip()
196
197
198
199
200
201
202
        # Escape % for argparse
        help = help.replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        kwargs[name] = {"default": default, "help": help}

        # Set other kwargs based on the type hints
203
204
205
206
        json_tip = """\n\nShould either be a valid JSON string or JSON keys
        passed individually. For example, the following sets of arguments are
        equivalent:\n\n
        - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
207
208
209
210
        - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n
        Additionally, list elements can be passed individually using '+':
        - `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
        - `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n"""
211
        if dataclass_cls is not None:
212
213
214
215
216
217
218
219
220
221

            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    if hasattr(cls, "from_cli"):
                        return cls.from_cli(val)
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            kwargs[name]["type"] = parse_dataclass
222
223
            kwargs[name]["help"] += json_tip
        elif contains_type(type_hints, bool):
224
225
226
            # Creates --no-<name> and --<name> flags
            kwargs[name]["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
227
            kwargs[name].update(literal_to_kwargs(type_hints))
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            types = get_args(type_hint)
            tuple_type = types[0]
            assert all(t is tuple_type for t in types if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {types}.")
            kwargs[name]["type"] = tuple_type
            kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            types = get_args(type_hint)
            assert len(types) == 1, (
                "List type must have exactly one type. Got "
                f"{type_hint} with types {types}")
            kwargs[name]["type"] = types[0]
            kwargs[name]["nargs"] = "+"
        elif contains_type(type_hints, int):
            kwargs[name]["type"] = int
247
            # Special case for large integers
248
            if name in {"max_model_len", "max_num_batched_tokens"}:
249
                kwargs[name]["type"] = human_readable_int
250
251
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
252
253
254
        elif (contains_type(type_hints, dict)
              and (contains_type(type_hints, str)
                   or any(is_not_builtin(th) for th in type_hints))):
255
            kwargs[name]["type"] = union_dict_and_str
256
        elif contains_type(type_hints, dict):
257
            kwargs[name]["type"] = parse_type(json.loads)
258
            kwargs[name]["help"] += json_tip
259
260
261
262
263
264
265
        elif (contains_type(type_hints, str)
              or any(is_not_builtin(th) for th in type_hints)):
            kwargs[name]["type"] = str
        else:
            raise ValueError(
                f"Unsupported type {type_hints} for argument {name}.")

266
267
268
269
270
        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(kwargs[name].get("type")) is Literal:
            kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))

271
272
273
274
275
276
277
        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            kwargs[name]["type"] = optional_type(kwargs[name]["type"])
            if kwargs[name].get("choices"):
                kwargs[name]["choices"].append("None")
    return kwargs
278
279


280
281
282
283
284
285
286
287
288
289
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Return argparse kwargs for the given Config dataclass.

    The heavy computation is cached via functools.lru_cache, and a deep copy
    is returned so callers can mutate the dictionary without affecting the
    cached version.
    """
    return copy.deepcopy(_compute_kwargs(cls))


290
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
291
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
292
    """Arguments for vLLM engine."""
293
294
295
296
297
298
299
    model: str = ModelConfig.model
    served_model_name: Optional[Union[
        str, List[str]]] = ModelConfig.served_model_name
    tokenizer: Optional[str] = ModelConfig.tokenizer
    hf_config_path: Optional[str] = ModelConfig.hf_config_path
    task: TaskOption = ModelConfig.task
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
300
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
301
302
303
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
304
305
    download_dir: Optional[str] = LoadConfig.download_dir
    load_format: str = LoadConfig.load_format
306
307
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
308
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
309
310
    seed: Optional[int] = ModelConfig.seed
    max_model_len: Optional[int] = ModelConfig.max_model_len
311
312
    cuda_graph_sizes: list[int] = get_field(SchedulerConfig,
                                            "cuda_graph_sizes")
313
314
315
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
316
    distributed_executor_backend: Optional[Union[
317
318
        DistributedExecutorBackend,
        Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
319
    # number of P/D disaggregation (or other disaggregation) workers
320
321
322
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
    data_parallel_size: int = ParallelConfig.data_parallel_size
323
    data_parallel_rank: Optional[int] = None
324
325
326
    data_parallel_size_local: Optional[int] = None
    data_parallel_address: Optional[str] = None
    data_parallel_rpc_port: Optional[int] = None
Rui Qiao's avatar
Rui Qiao committed
327
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
328
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
329
330
331
332
333
    enable_eplb: bool = ParallelConfig.enable_eplb
    num_redundant_experts: int = ParallelConfig.num_redundant_experts
    eplb_window_size: int = ParallelConfig.eplb_window_size
    eplb_step_interval: int = ParallelConfig.eplb_step_interval
    eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness
334
335
    max_parallel_loading_workers: Optional[
        int] = ParallelConfig.max_parallel_loading_workers
336
337
338
339
    block_size: Optional[BlockSize] = CacheConfig.block_size
    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
    prefix_caching_hash_algo: PrefixCachingHashAlgo = \
        CacheConfig.prefix_caching_hash_algo
340
341
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
342
    use_v2_block_manager: bool = True
343
344
345
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
346
347
348
349
350
351
352
    max_num_batched_tokens: Optional[
        int] = SchedulerConfig.max_num_batched_tokens
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
    long_prefill_token_threshold: int = \
        SchedulerConfig.long_prefill_token_threshold
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
353
    max_logprobs: int = ModelConfig.max_logprobs
354
    disable_log_stats: bool = False
355
356
357
358
359
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
    rope_theta: Optional[float] = ModelConfig.rope_theta
    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
360
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
361
362
363
364
    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
    enforce_eager: bool = ModelConfig.enforce_eager
    max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
365
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
366
367
368
    # The following three fields are deprecated and will be removed in a future
    # release. Setting them will have no effect. Please remove them from your
    # configurations.
369
    tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
370
371
    tokenizer_pool_type: str = TokenizerPoolConfig.pool_type
    tokenizer_pool_extra_config: dict = \
372
        get_field(TokenizerPoolConfig, "extra_config")
373
    limit_mm_per_prompt: dict[str, int] = \
374
        get_field(MultiModalConfig, "limit_per_prompt")
375
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
376
377
378
    media_io_kwargs: dict[str, dict[str,
                                    Any]] = get_field(MultiModalConfig,
                                                      "media_io_kwargs")
379
380
381
382
    mm_processor_kwargs: Optional[Dict[str, Any]] = \
        MultiModalConfig.mm_processor_kwargs
    disable_mm_preprocessor_cache: bool = \
        MultiModalConfig.disable_mm_preprocessor_cache
383
    # LoRA fields
384
    enable_lora: bool = False
385
386
387
388
389
390
391
392
393
394
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
        LoRAConfig.long_lora_scaling_factors
    # PromptAdapter fields
395
    enable_prompt_adapter: bool = False
396
397
398
399
    max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
    max_prompt_adapter_token: int = \
        PromptAdapterConfig.max_prompt_adapter_token

400
    device: Device = DeviceConfig.device
401
402
    num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
    multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
403
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
404
405
    num_gpu_blocks_override: Optional[
        int] = CacheConfig.num_gpu_blocks_override
406
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
407
408
    model_loader_extra_config: dict = \
        get_field(LoadConfig, "model_loader_extra_config")
409
410
    ignore_patterns: Optional[Union[str,
                                    List[str]]] = LoadConfig.ignore_patterns
411
    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
412

413
414
415
416
    scheduler_delay_factor: float = SchedulerConfig.delay_factor
    enable_chunked_prefill: Optional[
        bool] = SchedulerConfig.enable_chunked_prefill
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
417

418
419
420
    disable_hybrid_kv_cache_manager: bool = (
        SchedulerConfig.disable_hybrid_kv_cache_manager)

421
422
423
424
425
426
    guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
    guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
    guided_decoding_disable_any_whitespace: bool = \
        DecodingConfig.disable_any_whitespace
    guided_decoding_disable_additional_properties: bool = \
        DecodingConfig.disable_additional_properties
427
428
    logits_processor_pattern: Optional[
        str] = ModelConfig.logits_processor_pattern
429

430
    speculative_config: Optional[Dict[str, Any]] = None
431

432
    qlora_adapter_name_or_path: Optional[str] = None
433
434
435
436
437
438
    show_hidden_metrics_for_version: Optional[str] = \
        ObservabilityConfig.show_hidden_metrics_for_version
    otlp_traces_endpoint: Optional[str] = \
        ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: Optional[list[DetailedTraceModules]] = \
        ObservabilityConfig.collect_detailed_traces
439
    disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
440
441
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
442

443
444
445
446
    override_neuron_config: dict[str, Any] = \
        get_field(ModelConfig, "override_neuron_config")
    override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
        ModelConfig.override_pooler_config
447
448
    compilation_config: CompilationConfig = \
        get_field(VllmConfig, "compilation_config")
449
450
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
451

452
    kv_transfer_config: Optional[KVTransferConfig] = None
453
    kv_events_config: Optional[KVEventsConfig] = None
454

455
456
457
458
459
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
    override_generation_config: dict[str, Any] = \
        get_field(ModelConfig, "override_generation_config")
    model_impl: str = ModelConfig.model_impl
460
    override_attention_dtype: str = ModelConfig.override_attention_dtype
461

462
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
463

464
465
    additional_config: dict[str, Any] = \
        get_field(VllmConfig, "additional_config")
466
467
468
    enable_reasoning: Optional[bool] = None  # DEPRECATED
    reasoning_parser: str = DecodingConfig.reasoning_backend

469
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
470
    pt_load_map_location: str = LoadConfig.pt_load_map_location
471

472
473
474
    enable_multimodal_encoder_data_parallel: bool = \
        ParallelConfig.enable_multimodal_encoder_data_parallel

475
    def __post_init__(self):
476
477
478
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
479
        if isinstance(self.compilation_config, (int, dict)):
480
481
            self.compilation_config = CompilationConfig.from_cli(
                str(self.compilation_config))
482
483
484
485
486
487
488
        if self.qlora_adapter_name_or_path is not None:
            warnings.warn(
                "The `qlora_adapter_name_or_path` is deprecated "
                "and will be removed in v0.10.0. ",
                DeprecationWarning,
                stacklevel=2,
            )
489
        # Setup plugins
490
491
        from vllm.plugins import load_general_plugins
        load_general_plugins()
492
493

    @staticmethod
494
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
495
        """Shared CLI arguments for vLLM engine."""
496

497
        # Model arguments
498
499
500
501
502
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
Reid's avatar
Reid committed
503
        if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]):
504
            model_group.add_argument("--model", **model_kwargs["model"])
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
        model_group.add_argument("--task", **model_kwargs["task"])
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode",
                                 **model_kwargs["tokenizer_mode"])
        model_group.add_argument("--trust-remote-code",
                                 **model_kwargs["trust_remote_code"])
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
        model_group.add_argument("--hf-config-path",
                                 **model_kwargs["hf_config_path"])
        model_group.add_argument("--allowed-local-media-path",
                                 **model_kwargs["allowed_local_media_path"])
        model_group.add_argument("--revision", **model_kwargs["revision"])
        model_group.add_argument("--code-revision",
                                 **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling",
                                 **model_kwargs["rope_scaling"])
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
        model_group.add_argument("--tokenizer-revision",
                                 **model_kwargs["tokenizer_revision"])
        model_group.add_argument("--max-model-len",
                                 **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q",
                                 **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager",
                                 **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-seq-len-to-capture",
                                 **model_kwargs["max_seq_len_to_capture"])
        model_group.add_argument("--max-logprobs",
                                 **model_kwargs["max_logprobs"])
        model_group.add_argument("--disable-sliding-window",
                                 **model_kwargs["disable_sliding_window"])
        model_group.add_argument("--disable-cascade-attn",
                                 **model_kwargs["disable_cascade_attn"])
        model_group.add_argument("--skip-tokenizer-init",
                                 **model_kwargs["skip_tokenizer_init"])
541
542
        model_group.add_argument("--enable-prompt-embeds",
                                 **model_kwargs["enable_prompt_embeds"])
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
        model_group.add_argument("--served-model-name",
                                 **model_kwargs["served_model_name"])
        # This one is a special case because it is the
        # opposite of ModelConfig.use_async_output_proc
        model_group.add_argument(
            "--disable-async-output-proc",
            action="store_true",
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
        model_group.add_argument("--config-format",
                                 choices=[f.value for f in ConfigFormat],
                                 **model_kwargs["config_format"])
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
        model_group.add_argument("--hf-token",
                                 type=str,
                                 nargs="?",
                                 const=True,
                                 default=model_kwargs["hf_token"]["default"],
                                 help=model_kwargs["hf_token"]["help"])
        model_group.add_argument("--hf-overrides",
                                 **model_kwargs["hf_overrides"])
        model_group.add_argument("--override-neuron-config",
                                 **model_kwargs["override_neuron_config"])
        model_group.add_argument("--override-pooler-config",
                                 **model_kwargs["override_pooler_config"])
        model_group.add_argument("--logits-processor-pattern",
                                 **model_kwargs["logits_processor_pattern"])
        model_group.add_argument("--generation-config",
                                 **model_kwargs["generation_config"])
        model_group.add_argument("--override-generation-config",
                                 **model_kwargs["override_generation_config"])
        model_group.add_argument("--enable-sleep-mode",
                                 **model_kwargs["enable_sleep_mode"])
        model_group.add_argument("--model-impl",
                                 choices=[f.value for f in ModelImpl],
                                 **model_kwargs["model_impl"])
581
582
        model_group.add_argument("--override-attention-dtype",
                                 **model_kwargs["override_attention_dtype"])
583

584
585
586
587
588
589
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
590
        load_group.add_argument("--load-format",
591
592
                                choices=[f.value for f in LoadFormat],
                                **load_kwargs["load_format"])
593
        load_group.add_argument("--download-dir",
594
                                **load_kwargs["download_dir"])
595
        load_group.add_argument("--model-loader-extra-config",
596
                                **load_kwargs["model_loader_extra_config"])
597
598
599
        load_group.add_argument("--ignore-patterns",
                                **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load",
600
                                **load_kwargs["use_tqdm_on_load"])
601
602
603
604
605
606
607
608
        load_group.add_argument(
            "--qlora-adapter-name-or-path",
            type=str,
            default=None,
            help="The `--qlora-adapter-name-or-path` has no effect, do not set"
            " it, and it  will be removed in v0.10.0.",
            deprecated=True,
        )
609
610
        load_group.add_argument('--pt-load-map-location',
                                **load_kwargs["pt_load_map_location"])
611

612
613
614
615
616
617
        # Guided decoding arguments
        guided_decoding_kwargs = get_kwargs(DecodingConfig)
        guided_decoding_group = parser.add_argument_group(
            title="DecodingConfig",
            description=DecodingConfig.__doc__,
        )
618
619
        guided_decoding_group.add_argument("--guided-decoding-backend",
                                           **guided_decoding_kwargs["backend"])
620
        guided_decoding_group.add_argument(
621
622
623
624
625
626
627
628
            "--guided-decoding-disable-fallback",
            **guided_decoding_kwargs["disable_fallback"])
        guided_decoding_group.add_argument(
            "--guided-decoding-disable-any-whitespace",
            **guided_decoding_kwargs["disable_any_whitespace"])
        guided_decoding_group.add_argument(
            "--guided-decoding-disable-additional-properties",
            **guided_decoding_kwargs["disable_additional_properties"])
629
630
631
        guided_decoding_group.add_argument(
            "--enable-reasoning",
            action=argparse.BooleanOptionalAction,
632
            deprecated=True,
633
            help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as "
634
            "of v0.9.0. Use `--reasoning-parser` to specify the reasoning "
635
            "parser backend instead. This flag (`--enable-reasoning`) will be "
636
637
            "removed in v0.10.0. When `--reasoning-parser` is specified, "
            "reasoning mode is automatically enabled.")
638
639
640
641
642
643
        guided_decoding_group.add_argument(
            "--reasoning-parser",
            # This choices is a special case because it's not static
            choices=list(ReasoningParserManager.reasoning_parsers),
            **guided_decoding_kwargs["reasoning_backend"])

644
        # Parallel arguments
645
646
647
648
649
650
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
651
            "--distributed-executor-backend",
652
653
            **parallel_kwargs["distributed_executor_backend"])
        parallel_group.add_argument(
654
            "--pipeline-parallel-size", "-pp",
655
            **parallel_kwargs["pipeline_parallel_size"])
656
        parallel_group.add_argument("--tensor-parallel-size", "-tp",
657
                                    **parallel_kwargs["tensor_parallel_size"])
658
        parallel_group.add_argument("--data-parallel-size", "-dp",
659
                                    **parallel_kwargs["data_parallel_size"])
660
661
662
663
664
665
        parallel_group.add_argument(
            '--data-parallel-rank',
            '-dpn',
            type=int,
            help='Data parallel rank of this instance. '
            'When set, enables external load balancer mode.')
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
        parallel_group.add_argument('--data-parallel-size-local',
                                    '-dpl',
                                    type=int,
                                    help='Number of data parallel replicas '
                                    'to run on this node.')
        parallel_group.add_argument('--data-parallel-address',
                                    '-dpa',
                                    type=str,
                                    help='Address of data parallel cluster '
                                    'head-node.')
        parallel_group.add_argument('--data-parallel-rpc-port',
                                    '-dpp',
                                    type=int,
                                    help='Port for data parallel RPC '
                                    'communication.')
Rui Qiao's avatar
Rui Qiao committed
681
682
683
684
685
686
        parallel_group.add_argument('--data-parallel-backend',
                                    '-dpb',
                                    type=str,
                                    default='mp',
                                    help='Backend for data parallel, either '
                                    '"mp" or "ray".')
687
        parallel_group.add_argument(
688
            "--enable-expert-parallel",
689
            **parallel_kwargs["enable_expert_parallel"])
690
691
692
693
694
695
696
697
698
699
        parallel_group.add_argument("--enable-eplb",
                                    **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--num-redundant-experts",
                                    **parallel_kwargs["num_redundant_experts"])
        parallel_group.add_argument("--eplb-window-size",
                                    **parallel_kwargs["eplb_window_size"])
        parallel_group.add_argument("--eplb-step-interval",
                                    **parallel_kwargs["eplb_step_interval"])
        parallel_group.add_argument("--eplb-log-balancedness",
                                    **parallel_kwargs["eplb_log_balancedness"])
700
        parallel_group.add_argument(
701
            "--max-parallel-loading-workers",
702
703
            **parallel_kwargs["max_parallel_loading_workers"])
        parallel_group.add_argument(
704
            "--ray-workers-use-nsight",
705
706
            **parallel_kwargs["ray_workers_use_nsight"])
        parallel_group.add_argument(
707
            "--disable-custom-all-reduce",
708
            **parallel_kwargs["disable_custom_all_reduce"])
709
710
711
712
        parallel_group.add_argument("--worker-cls",
                                    **parallel_kwargs["worker_cls"])
        parallel_group.add_argument("--worker-extension-cls",
                                    **parallel_kwargs["worker_extension_cls"])
713
714
715
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
            **parallel_kwargs["enable_multimodal_encoder_data_parallel"])
716

717
718
719
720
721
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
722
        )
723
724
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
        cache_group.add_argument("--gpu-memory-utilization",
725
                                 **cache_kwargs["gpu_memory_utilization"])
726
727
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
        cache_group.add_argument("--kv-cache-dtype",
728
                                 **cache_kwargs["cache_dtype"])
729
        cache_group.add_argument("--num-gpu-blocks-override",
730
731
732
733
734
                                 **cache_kwargs["num_gpu_blocks_override"])
        cache_group.add_argument("--enable-prefix-caching",
                                 **cache_kwargs["enable_prefix_caching"])
        cache_group.add_argument("--prefix-caching-hash-algo",
                                 **cache_kwargs["prefix_caching_hash_algo"])
735
        cache_group.add_argument("--cpu-offload-gb",
736
                                 **cache_kwargs["cpu_offload_gb"])
737
        cache_group.add_argument("--calculate-kv-scales",
738
739
                                 **cache_kwargs["calculate_kv_scales"])

740
741
742
743
744
745
        # Tokenizer arguments
        tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
        tokenizer_group = parser.add_argument_group(
            title="TokenizerPoolConfig",
            description=TokenizerPoolConfig.__doc__,
        )
746
        tokenizer_group.add_argument("--tokenizer-pool-size",
747
                                     **tokenizer_kwargs["pool_size"])
748
        tokenizer_group.add_argument("--tokenizer-pool-type",
749
                                     **tokenizer_kwargs["pool_type"])
750
        tokenizer_group.add_argument("--tokenizer-pool-extra-config",
751
                                     **tokenizer_kwargs["extra_config"])
752
753

        # Multimodal related configs
754
755
756
757
758
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
759
        multimodal_group.add_argument("--limit-mm-per-prompt",
760
                                      **multimodal_kwargs["limit_per_prompt"])
761
762
        multimodal_group.add_argument("--media-io-kwargs",
                                      **multimodal_kwargs["media_io_kwargs"])
763
        multimodal_group.add_argument(
764
            "--mm-processor-kwargs",
765
766
            **multimodal_kwargs["mm_processor_kwargs"])
        multimodal_group.add_argument(
767
            "--disable-mm-preprocessor-cache",
768
            **multimodal_kwargs["disable_mm_preprocessor_cache"])
769
770
771
        multimodal_group.add_argument(
            "--interleave-mm-strings",
            **multimodal_kwargs["interleave_mm_strings"])
772

773
        # LoRA related configs
774
775
776
777
778
779
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
780
            "--enable-lora",
781
            action=argparse.BooleanOptionalAction,
782
783
            help="If True, enable handling of LoRA adapters.")
        lora_group.add_argument("--enable-lora-bias",
784
                                **lora_kwargs["bias_enabled"])
785
786
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
        lora_group.add_argument("--max-lora-rank",
787
                                **lora_kwargs["max_lora_rank"])
788
        lora_group.add_argument("--lora-extra-vocab-size",
789
790
                                **lora_kwargs["lora_extra_vocab_size"])
        lora_group.add_argument(
791
            "--lora-dtype",
792
793
            **lora_kwargs["lora_dtype"],
        )
794
        lora_group.add_argument("--long-lora-scaling-factors",
795
                                **lora_kwargs["long_lora_scaling_factors"])
796
        lora_group.add_argument("--max-cpu-loras",
797
                                **lora_kwargs["max_cpu_loras"])
798
        lora_group.add_argument("--fully-sharded-loras",
799
800
801
802
803
804
805
806
807
                                **lora_kwargs["fully_sharded_loras"])

        # PromptAdapter related configs
        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
        prompt_adapter_group = parser.add_argument_group(
            title="PromptAdapterConfig",
            description=PromptAdapterConfig.__doc__,
        )
        prompt_adapter_group.add_argument(
808
            "--enable-prompt-adapter",
809
            action=argparse.BooleanOptionalAction,
810
            help="If True, enable handling of PromptAdapters.")
811
        prompt_adapter_group.add_argument(
812
            "--max-prompt-adapters",
813
814
            **prompt_adapter_kwargs["max_prompt_adapters"])
        prompt_adapter_group.add_argument(
815
            "--max-prompt-adapter-token",
816
            **prompt_adapter_kwargs["max_prompt_adapter_token"])
817
818
819
820
821
822
823

        # Device arguments
        device_kwargs = get_kwargs(DeviceConfig)
        device_group = parser.add_argument_group(
            title="DeviceConfig",
            description=DeviceConfig.__doc__,
        )
824
825
826
        device_group.add_argument("--device",
                                  **device_kwargs["device"],
                                  deprecated=True)
827

828
829
830
831
832
833
        # Speculative arguments
        speculative_group = parser.add_argument_group(
            title="SpeculativeConfig",
            description=SpeculativeConfig.__doc__,
        )
        speculative_group.add_argument(
834
            "--speculative-config",
835
836
            type=json.loads,
            default=None,
837
838
            help="The configurations for speculative decoding. Should be a "
            "JSON string.")
839

840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
            **observability_kwargs["show_hidden_metrics_for_version"])
        observability_group.add_argument(
            "--otlp-traces-endpoint",
            **observability_kwargs["otlp_traces_endpoint"])
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
            ",".join(p)
            for p in permutations(get_args(DetailedTraceModules), r=2)
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
            **observability_kwargs["collect_detailed_traces"])
863

864
865
866
867
868
869
870
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
871
            "--max-num-batched-tokens",
872
            **scheduler_kwargs["max_num_batched_tokens"])
873
        scheduler_group.add_argument("--max-num-seqs",
874
875
876
877
878
879
880
                                     **scheduler_kwargs["max_num_seqs"])
        scheduler_group.add_argument(
            "--max-num-partial-prefills",
            **scheduler_kwargs["max_num_partial_prefills"])
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
            **scheduler_kwargs["max_long_partial_prefills"])
881
882
        scheduler_group.add_argument('--cuda-graph-sizes',
                                     **scheduler_kwargs["cuda_graph_sizes"])
883
884
885
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
            **scheduler_kwargs["long_prefill_token_threshold"])
886
        scheduler_group.add_argument("--num-lookahead-slots",
887
                                     **scheduler_kwargs["num_lookahead_slots"])
888
        scheduler_group.add_argument("--scheduler-delay-factor",
889
                                     **scheduler_kwargs["delay_factor"])
890
        scheduler_group.add_argument("--preemption-mode",
891
                                     **scheduler_kwargs["preemption_mode"])
892
        scheduler_group.add_argument("--num-scheduler-steps",
893
                                     **scheduler_kwargs["num_scheduler_steps"])
894
        scheduler_group.add_argument(
895
            "--multi-step-stream-outputs",
896
            **scheduler_kwargs["multi_step_stream_outputs"])
897
        scheduler_group.add_argument("--scheduling-policy",
898
                                     **scheduler_kwargs["policy"])
899
        scheduler_group.add_argument(
900
            "--enable-chunked-prefill",
901
            **scheduler_kwargs["enable_chunked_prefill"])
902
903
904
        scheduler_group.add_argument(
            "--disable-chunked-mm-input",
            **scheduler_kwargs["disable_chunked_mm_input"])
905
906
        scheduler_group.add_argument("--scheduler-cls",
                                     **scheduler_kwargs["scheduler_cls"])
907
908
909
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"])
910
911

        # vLLM arguments
912
        vllm_kwargs = get_kwargs(VllmConfig)
913
914
915
916
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
917
918
919
920
921
922
923
924
        vllm_group.add_argument("--kv-transfer-config",
                                **vllm_kwargs["kv_transfer_config"])
        vllm_group.add_argument('--kv-events-config',
                                **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument("--compilation-config", "-O",
                                **vllm_kwargs["compilation_config"])
        vllm_group.add_argument("--additional-config",
                                **vllm_kwargs["additional_config"])
925

926
927
928
929
        # Other arguments
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            default=True,
930
                            deprecated=True,
931
932
933
934
935
936
937
938
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')
939

940
        return parser
941
942

    @classmethod
943
    def from_cli_args(cls, args: argparse.Namespace):
944
945
946
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
Zhuohan Li's avatar
Zhuohan Li committed
947
948
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args
949

950
    def create_model_config(self) -> ModelConfig:
951
952
953
954
955
956
957
958
959
960
961
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
                and self.model in MODELS_ON_S3
                and self.load_format == LoadFormat.AUTO):  # noqa: E501
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
            self.load_format = LoadFormat.RUNAI_STREAMER

962
        return ModelConfig(
963
            model=self.model,
964
            hf_config_path=self.hf_config_path,
965
            task=self.task,
966
            tokenizer=self.tokenizer,
967
968
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
969
            allowed_local_media_path=self.allowed_local_media_path,
970
971
972
973
974
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
975
            rope_theta=self.rope_theta,
976
            hf_token=self.hf_token,
977
            hf_overrides=self.hf_overrides,
978
979
980
981
982
983
984
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
985
            disable_cascade_attn=self.disable_cascade_attn,
986
            skip_tokenizer_init=self.skip_tokenizer_init,
987
            enable_prompt_embeds=self.enable_prompt_embeds,
988
            served_model_name=self.served_model_name,
989
            limit_mm_per_prompt=self.limit_mm_per_prompt,
990
            interleave_mm_strings=self.interleave_mm_strings,
991
            media_io_kwargs=self.media_io_kwargs,
992
            use_async_output_proc=not self.disable_async_output_proc,
993
            config_format=self.config_format,
994
            mm_processor_kwargs=self.mm_processor_kwargs,
995
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
996
997
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
998
            logits_processor_pattern=self.logits_processor_pattern,
999
            generation_config=self.generation_config,
1000
            override_generation_config=self.override_generation_config,
1001
            enable_sleep_mode=self.enable_sleep_mode,
1002
            model_impl=self.model_impl,
1003
            override_attention_dtype=self.override_attention_dtype,
1004
        )
1005

1006
1007
1008
1009
1010
1011
1012
    def validate_tensorizer_args(self):
        from vllm.model_executor.model_loader.tensorizer import (
            TensorizerConfig)
        for key in self.model_loader_extra_config:
            if key in TensorizerConfig._fields:
                self.model_loader_extra_config["tensorizer_config"][
                    key] = self.model_loader_extra_config[key]
1013

1014
1015
    def create_load_config(self) -> LoadConfig:

1016
1017
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"
1018

1019
1020
1021
1022
1023
1024
1025
1026
        if self.load_format == "tensorizer":
            if hasattr(self.model_loader_extra_config, "to_serializable"):
                self.model_loader_extra_config = (
                    self.model_loader_extra_config.to_serializable())
            self.model_loader_extra_config["tensorizer_config"] = {}
            self.model_loader_extra_config["tensorizer_config"][
                "tensorizer_dir"] = self.model
            self.validate_tensorizer_args()
1027

1028
1029
1030
1031
1032
        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
1033
            use_tqdm_on_load=self.use_tqdm_on_load,
1034
            pt_load_map_location=self.pt_load_map_location,
1035
        )
1036

1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        This function utilizes `speculative_config` to create a
        SpeculativeConfig object. The `speculative_config` can either be
        provided as a JSON string input via CLI arguments or directly as a
1050
        dictionary from the engine.
1051
1052
        """
        if self.speculative_config is None:
1053
1054
            return None

1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        self.speculative_config.update({
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        })
        speculative_config = SpeculativeConfig.from_dict(
            self.speculative_config)

        return speculative_config

1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1079

1080
1081
1082
1083
1084
1085
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1086
        current_platform.pre_register_and_update()
1087

1088
1089
        device_config = DeviceConfig(
            device=cast(Device, current_platform.device_type))
1090
1091
        model_config = self.create_model_config()

1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
1111
            self._set_default_args_v1(usage_context, model_config)
1112
1113
1114
1115
1116
1117
1118
1119
            # Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1
            if current_platform.is_cpu(
            ) and current_platform.get_cpu_architecture() in (
                    CpuArchEnum.POWERPC, CpuArchEnum.ARM):
                logger.info(
                    "Chunked prefill is not supported for ARM and POWER CPUs; "
                    "disabling it for V1 backend.")
                self.enable_chunked_prefill = False
1120
1121
        else:
            self._set_default_args_v0(model_config)
1122
1123
        assert self.enable_chunked_prefill is not None

1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
        if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
            assert self.enforce_eager, (
                "Cuda graph is not supported with DualChunkFlashAttention. "
                "To run the model in eager mode, set 'enforce_eager=True' "
                "or use '--enforce-eager' in the CLI.")
            assert current_platform.is_cuda(), (
                "DualChunkFlashAttention is only supported on CUDA platform.")
            assert not use_v1, (
                "DualChunkFlashAttention is not supported on V1 engine. "
                "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")

1135
        cache_config = CacheConfig(
1136
            block_size=self.block_size,
1137
1138
1139
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1140
            is_attention_free=model_config.is_attention_free,
1141
1142
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1143
            enable_prefix_caching=self.enable_prefix_caching,
1144
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1145
            cpu_offload_gb=self.cpu_offload_gb,
1146
            calculate_kv_scales=self.calculate_kv_scales,
1147
        )
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
        data_parallel_external_lb = self.data_parallel_rank is not None
        if data_parallel_external_lb:
            assert self.data_parallel_size_local in (1, None), (
                "data_parallel_size_local must be 1 when data_parallel_rank "
                "is set")
            data_parallel_size_local = 1
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
        else:
            # Local DP size defaults to global DP size if not set.
            data_parallel_size_local = self.data_parallel_size
1171
1172
1173

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
                    "Using host IP %s as ray-based data parallel address",
                    host_ip)
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
                    self.data_parallel_backend)
                data_parallel_address = ParallelConfig.data_parallel_master_ip
        else:
            data_parallel_address = self.data_parallel_address
1188
1189
1190
1191
1192
1193
1194

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
        data_parallel_rpc_port = self.data_parallel_rpc_port if (
            self.data_parallel_rpc_port
            is not None) else ParallelConfig.data_parallel_rpc_port

1195
        parallel_config = ParallelConfig(
1196
1197
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1198
            data_parallel_size=self.data_parallel_size,
1199
1200
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1201
1202
1203
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1204
            data_parallel_backend=self.data_parallel_backend,
1205
            enable_expert_parallel=self.enable_expert_parallel,
1206
1207
1208
1209
1210
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.num_redundant_experts,
            eplb_window_size=self.eplb_window_size,
            eplb_step_interval=self.eplb_step_interval,
            eplb_log_balancedness=self.eplb_log_balancedness,
1211
1212
1213
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1214
            placement_group=placement_group,
1215
1216
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1217
            worker_extension_cls=self.worker_extension_cls,
1218
1219
            enable_multimodal_encoder_data_parallel=self.
            enable_multimodal_encoder_data_parallel,
1220
        )
1221

1222
        speculative_config = self.create_speculative_config(
1223
1224
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1225
            enable_chunked_prefill=self.enable_chunked_prefill,
1226
            disable_log_stats=self.disable_log_stats,
1227
1228
        )

1229
        # Reminder: Please update docs/features/compatibility_matrix.md
1230
        # If the feature combo become valid
1231
1232
1233
1234
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1235
1236
1237
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1238
1239
1240
1241
1242
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1243
1244
1245
1246
1247
1248
1249
1250
1251

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1252
        scheduler_config = SchedulerConfig(
1253
            runner_type=model_config.runner_type,
1254
1255
1256
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1257
            cuda_graph_sizes=self.cuda_graph_sizes,
1258
            num_lookahead_slots=num_lookahead_slots,
1259
1260
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1261
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1262
            is_multimodal_model=model_config.is_multimodal_model,
1263
            preemption_mode=self.preemption_mode,
1264
            num_scheduler_steps=self.num_scheduler_steps,
1265
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1266
1267
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1268
            policy=self.scheduling_policy,
1269
            scheduler_cls=self.scheduler_cls,
1270
1271
1272
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1273
1274
            disable_hybrid_kv_cache_manager=self.
            disable_hybrid_kv_cache_manager,
1275
        )
1276

1277
        lora_config = LoRAConfig(
1278
            bias_enabled=self.enable_lora_bias,
1279
1280
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1281
            fully_sharded_loras=self.fully_sharded_loras,
1282
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1283
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1284
1285
1286
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
1287

1288
1289
1290
1291
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1292
        load_config = self.create_load_config()
1293

1294
1295
1296
1297
1298
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1299
        decoding_config = DecodingConfig(
1300
1301
1302
1303
1304
            backend=self.guided_decoding_backend,
            disable_fallback=self.guided_decoding_disable_fallback,
            disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
            disable_additional_properties=\
                self.guided_decoding_disable_additional_properties,
1305
1306
            reasoning_backend=self.reasoning_parser
        )
1307

1308
        observability_config = ObservabilityConfig(
1309
1310
            show_hidden_metrics_for_version=self.
            show_hidden_metrics_for_version,
1311
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1312
            collect_detailed_traces=self.collect_detailed_traces,
1313
        )
1314

1315
        config = VllmConfig(
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1326
            prompt_adapter_config=prompt_adapter_config,
1327
            compilation_config=self.compilation_config,
1328
            kv_transfer_config=self.kv_transfer_config,
1329
            kv_events_config=self.kv_events_config,
1330
            additional_config=self.additional_config,
1331
        )
1332

1333
1334
        return config

1335
1336
1337
1338
1339
1340
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default."""

        #############################################################
        # Unsupported Feature Flags on V1.

1341
        if self.load_format == LoadFormat.SHARDED_STATE.value:
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

1353
        if self.preemption_mode != SchedulerConfig.preemption_mode:
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

1364
        if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
1365
1366
1367
1368
            _raise_or_fallback(feature_name="--num-scheduler-steps",
                               recommend_to_remove=True)
            return False

1369
        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
1370
1371
1372
1373
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

1374
1375
        if self.guided_decoding_backend not in get_args(
                GuidedDecodingBackendV1):
1376
1377
1378
1379
            _raise_or_fallback(
                feature_name=
                f"--guided-decoding-backend={self.guided_decoding_backend}",
                recommend_to_remove=False)
1380
1381
1382
            return False

        # Need at least Ampere for now (FA support required).
1383
1384
1385
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
1386
        if (current_platform.is_cuda()
1387
                and current_platform.get_device_capability()
1388
1389
1390
1391
1392
1393
1394
                and current_platform.get_device_capability().major < 8):
            _raise_or_fallback(feature_name="Compute Capability < 8.0",
                               recommend_to_remove=False)
            return False

        # No Fp8 KV cache so far.
        if self.kv_cache_dtype != "auto":
1395
1396
1397
1398
1399
1400
            fp8_attention = self.kv_cache_dtype.startswith("fp8")
            will_use_fa = (
                current_platform.is_cuda()
                and not envs.is_set("VLLM_ATTENTION_BACKEND")
            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
            supported = False
1401
1402
1403
            if current_platform.is_rocm():
                supported = True
            elif fp8_attention and will_use_fa:
1404
                from vllm.attention.utils.fa_utils import (
1405
1406
1407
1408
1409
1410
                    flash_attn_supports_fp8)
                supported = flash_attn_supports_fp8()
            if not supported:
                _raise_or_fallback(feature_name="--kv-cache-dtype",
                                   recommend_to_remove=False)
                return False
1411
1412
1413
1414
1415
1416
1417

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

1418
1419
1420
1421
1422
1423
        # No text embedding inputs so far.
        if self.enable_prompt_embeds:
            _raise_or_fallback(feature_name="--enable-prompt-embeds",
                               recommend_to_remove=False)
            return False

1424
        # No Mamba or Encoder-Decoder so far.
1425
1426
1427
1428
1429
        if not model_config.is_v1_compatible:
            _raise_or_fallback(feature_name=model_config.architectures,
                               recommend_to_remove=False)
            return False

Chen Zhang's avatar
Chen Zhang committed
1430
1431
1432
1433
1434
        # V1 mamba models are unoptimized.
        if model_config.has_inner_state and _warn_or_fallback(
                feature_name="Mamba"):
            return False

1435
1436
        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
1437
                != SchedulerConfig.max_num_partial_prefills
1438
                or self.max_long_partial_prefills
1439
                != SchedulerConfig.max_long_partial_prefills):
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # No OTLP observability so far.
        if (self.otlp_traces_endpoint or self.collect_detailed_traces):
            _raise_or_fallback(feature_name="--otlp-traces-endpoint",
                               recommend_to_remove=False)
            return False

1450
        # V1 supports N-gram, Medusa, and Eagle speculative decoding.
1451
        is_ngram_enabled = False
1452
        is_eagle_enabled = False
1453
        is_medusa_enabled = False
1454
        if self.speculative_config is not None:
1455
            # This is supported but experimental (handled below).
1456
1457
1458
1459
            speculative_method = self.speculative_config.get("method")
            if speculative_method:
                if speculative_method in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
1460
1461
                elif speculative_method == "medusa":
                    is_medusa_enabled = True
Jiayi Yao's avatar
Jiayi Yao committed
1462
                elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
1463
                    is_eagle_enabled = True
1464
            else:
1465
1466
1467
                speculative_model = self.speculative_config.get("model")
                if speculative_model in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
1468
            if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
1469
                # Other speculative decoding methods are not supported yet.
1470
1471
1472
1473
                _raise_or_fallback(feature_name="Speculative Decoding",
                                   recommend_to_remove=False)
                return False

1474
        # No XFormers so far.
1475
        V1_BACKENDS = [
1476
1477
1478
1479
1480
1481
            "FLASH_ATTN_VLLM_V1",
            "FLASH_ATTN",
            "PALLAS",
            "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1",
            "TRITON_MLA",
1482
            "CUTLASS_MLA_VLLM_V1",
1483
1484
1485
            "FLASHMLA",
            "FLASHINFER",
            "FLASHINFER_VLLM_V1",
1486
            "ROCM_AITER_MLA",
1487
            "TORCH_SDPA_VLLM_V1",
1488
            "FLEX_ATTENTION",
1489
1490
1491
1492
1493
1494
1495
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

1496
1497
        # Platforms must decide if they can support v1 for this model
        if not current_platform.supports_v1(model_config=model_config):
1498
1499
1500
1501
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
            return False
1502
1503
1504
        #############################################################
        # Experimental Features - allow users to opt in.

1505
1506
1507
1508
1509
        # Signal Handlers requires running in main thread.
        if (threading.current_thread() != threading.main_thread()
                and _warn_or_fallback("Engine in background thread")):
            return False

1510
        if (self.pipeline_parallel_size > 1
1511
                and self.distributed_executor_backend
1512
1513
                not in (ParallelConfig.distributed_executor_backend, "ray",
                        "mp", "external_launcher")):
1514
            name = "Pipeline Parallelism without Ray distributed executor " \
1515
                    "or multiprocessing executor or external launcher"
1516
            _raise_or_fallback(feature_name=name, recommend_to_remove=False)
1517
1518
            return False

1519
1520
1521
1522
        # The platform may be supported on V1, but off by default for now.
        if not current_platform.default_v1(  # noqa: SIM103
                model_config=model_config) and _warn_or_fallback(
                    current_platform.device_name):
1523
            return False
1524
1525
1526
1527
1528
1529
1530

        if (current_platform.is_cpu()
                and model_config.get_sliding_window() is not None):
            _raise_or_fallback(feature_name="sliding window (CPU backend)",
                               recommend_to_remove=False)
            return False

1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine."""

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                from vllm.platforms import current_platform
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
1552
                use_spec_decode = self.speculative_config is not None
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause"
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
        # if using prefix caching, we must set a hash algo
        if self.enable_prefix_caching:
            # Disable prefix caching for multimodal models for VLLM_V0.
            if model_config.is_multimodal_model:
                logger.warning(
                    "--enable-prefix-caching is not supported for multimodal "
                    "models in V0 and has been disabled.")
                self.enable_prefix_caching = False

            # VLLM_V0 only supports builtin hash algo for prefix caching.
1590
            if self.prefix_caching_hash_algo == "sha256":
1591
1592
1593
                raise ValueError(
                    "sha256 is not supported for prefix caching in V0 engine. "
                    "Please use 'builtin'.")
1594
1595
1596
1597
1598

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

1599
1600
    def _set_default_args_v1(self, usage_context: UsageContext,
                             model_config: ModelConfig) -> None:
1601
        """Set Default Arguments for V1 Engine."""
1602

1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
        # V1 always uses chunked prefills and prefix caching
        # for non-pooling tasks.
        # For pooling tasks the default is False
        if model_config.runner_type != "pooling":
            self.enable_chunked_prefill = True
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = True
        else:

            pooling_type = model_config.pooler_config.pooling_type

            # TODO: when encoder models are supported we'll have to
            # check for causal attention here.
            incremental_prefill_supported = (pooling_type is not None and
                                             pooling_type.lower() == "last")
1618

1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
            action = "Enabling" if \
                incremental_prefill_supported else "Disabling"

            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = incremental_prefill_supported
                logger.info("(%s) chunked prefill by default", action)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = incremental_prefill_supported
                logger.info("(%s) prefix caching by default", action)

        if not self.enable_chunked_prefill:
            self.max_num_batched_tokens = model_config.max_model_len
1631

1632
1633
1634
        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default
        if self.scheduler_cls == EngineArgs.scheduler_cls:
1635
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
1636

1637
1638
        # When no user override, set the default values based on the usage
        # context.
1639
        # Use different default values for different hardware.
1640
1641
1642
1643
1644
1645
1646

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
1647
            device_memory = current_platform.get_device_total_memory()
1648
            device_name = current_platform.get_device_name().lower()
1649
1650
        except Exception:
            # This is only used to set default_max_num_batched_tokens
1651
            device_memory = 0
1652

1653
1654
1655
1656
        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
1657
            # For GPUs like H100 and MI300x, use larger default values.
1658
1659
1660
1661
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
1662
1663
1664
1665
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 1024,
                UsageContext.OPENAI_API_SERVER: 1024,
            }
1666
1667
1668
1669
1670
1671
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
1672
1673
1674
1675
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256,
                UsageContext.OPENAI_API_SERVER: 256,
            }
1676

1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
        # tpu specific default values.
        if current_platform.is_tpu():
            default_max_num_batched_tokens_tpu = {
                UsageContext.LLM_CLASS: {
                    'V6E': 2048,
                    'V5E': 1024,
                    'V5P': 512,
                },
                UsageContext.OPENAI_API_SERVER: {
                    'V6E': 1024,
                    'V5E': 512,
                    'V5P': 256,
                }
            }

1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
        # cpu specific default values.
        if current_platform.is_cpu():
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 4096,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 128,
                UsageContext.OPENAI_API_SERVER: 32,
            }

1703
        use_context_value = usage_context.value if usage_context else None
1704
1705
        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
            if current_platform.is_tpu():
                chip_name = current_platform.get_device_name()
                if chip_name in default_max_num_batched_tokens_tpu[
                        usage_context]:
                    self.max_num_batched_tokens = \
                        default_max_num_batched_tokens_tpu[
                            usage_context][chip_name]
                else:
                    self.max_num_batched_tokens = \
                        default_max_num_batched_tokens[usage_context]
            else:
                self.max_num_batched_tokens = default_max_num_batched_tokens[
                    usage_context]
1719
            logger.debug(
1720
                "Setting max_num_batched_tokens to %d for %s usage context.",
1721
                self.max_num_batched_tokens, use_context_value)
1722

1723
1724
1725
        if (self.max_num_seqs is None
                and usage_context in default_max_num_seqs):
            self.max_num_seqs = default_max_num_seqs[usage_context]
1726
1727
1728

            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, use_context_value)
1729

1730

1731
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
1732
class AsyncEngineArgs(EngineArgs):
Woosuk Kwon's avatar
Woosuk Kwon committed
1733
    """Arguments for asynchronous vLLM engine."""
1734
    disable_log_requests: bool = False
1735
1736

    @staticmethod
1737
1738
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
1739
1740
1741
1742
        # Initialize plugin to update the parser, for example, The plugin may
        # adding a new kind of quantization method to --quantization argument or
        # a new device to --device argument.
        load_general_plugins()
1743
1744
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
1745
1746
        parser.add_argument('--disable-log-requests',
                            action='store_true',
1747
                            help='Disable logging requests.')
1748
        current_platform.pre_register_and_update(parser)
1749
        return parser
1750
1751


1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")
    msg = f"{feature_name} is not supported by the V1 Engine. "
    msg += "Falling back to V0. "
    if recommend_to_remove:
        msg += f"We recommend to remove {feature_name} from your config "
        msg += "in favor of the V1 Engine."
    logger.warning(msg)


def _warn_or_fallback(feature_name: str) -> bool:
    if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
        logger.warning(
            "Detected VLLM_USE_V1=1 with %s. Usage should "
            "be considered experimental. Please report any "
            "issues on Github.", feature_name)
        should_exit = False
    else:
        logger.info(
            "%s is experimental on VLLM_USE_V1=1. "
            "Falling back to V0 Engine.", feature_name)
        should_exit = True
    return should_exit


1779
1780
1781
def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.
1782

1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    """
    value = value.strip()
    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
    if match:
        decimal_multiplier = {
            'k': 10**3,
            'm': 10**6,
            'g': 10**9,
        }
        binary_multiplier = {
            'K': 2**10,
            'M': 2**20,
            'G': 2**30,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            return int(float(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError("Decimals are not allowed " \
                f"with binary suffixes like {suffix}. Did you mean to use " \
                f"{number}{suffix.lower()} instead?") from e

    # Regular plain number.
    return int(value)


1820
1821
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
1822
    return EngineArgs.add_cli_args(FlexibleArgumentParser())
1823
1824
1825


def _async_engine_args_parser():
1826
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
1827
                                        async_args_only=True)