arg_utils.py 83.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# yapf: disable
5
import argparse
6
import copy
7
import dataclasses
8
import functools
9
import json
10
import sys
11
import threading
12
import warnings
13
from dataclasses import MISSING, dataclass, fields, is_dataclass
14
from itertools import permutations
15
16
17
from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List,
                    Literal, Optional, Type, TypeVar, Union, cast, get_args,
                    get_origin)
18

19
import regex as re
20
import torch
21
from pydantic import TypeAdapter, ValidationError
22
from typing_extensions import TypeIs, deprecated
23

24
import vllm.envs as envs
25
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
26
27
28
29
30
31
                         ConfigFormat, ConfigType, DecodingConfig,
                         DetailedTraceModules, Device, DeviceConfig,
                         DistributedExecutorBackend, GuidedDecodingBackend,
                         GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
32
33
34
                         ObservabilityConfig, ParallelConfig, PoolerConfig,
                         PrefixCachingHashAlgo, PromptAdapterConfig,
                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
35
36
                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
                         get_field)
37
from vllm.logger import init_logger
38
from vllm.platforms import CpuArchEnum, current_platform
39
from vllm.plugins import load_general_plugins
40
from vllm.reasoning import ReasoningParserManager
41
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
42
from vllm.transformers_utils.utils import check_gguf_file
43
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
Rui Qiao's avatar
Rui Qiao committed
44
                        GiB_bytes, get_ip, is_in_ray_actor)
45
46

# yapf: enable
47

48
49
50
51
52
53
54
55
56
if TYPE_CHECKING:
    from vllm.executor.executor_base import ExecutorBase
    from vllm.model_executor.layers.quantization import QuantizationMethods
    from vllm.usage.usage_lib import UsageContext
else:
    ExecutorBase = Any
    QuantizationMethods = Any
    UsageContext = Any

57
58
logger = init_logger(__name__)

59
60
61
62
63
# object is used to allow for special typing forms
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]

64

65
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap a converter so that it can be used as an argparse ``type``.

    Args:
        return_type: Callable that converts the raw CLI string.

    Returns:
        A callable that applies ``return_type`` to the string and re-raises
        any ``ValueError`` as ``argparse.ArgumentTypeError``. When
        ``return_type`` is ``json.loads`` and the string is not wrapped in
        braces, the string is handed to ``nullable_kvs`` instead.
    """

    def _parse_type(val: str) -> T:
        try:
            expects_json = return_type is json.loads
            if expects_json and re.match(r"(?s)^\s*{.*}\s*$", val) is None:
                # Not a `{...}` object: fall back to the KEY=VALUE parser.
                legacy_pairs = nullable_kvs(val)
                return cast(T, legacy_pairs)
            converted = return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}.") from e
        return converted

    return _parse_type


def optional_type(
        return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Wrap a converter so that empty and ``"None"`` strings become ``None``.

    Args:
        return_type: Callable that converts the raw CLI string.

    Returns:
        A callable that returns ``None`` for ``""`` or ``"None"`` and
        otherwise delegates to ``parse_type(return_type)``.
    """

    def _optional_type(val: str) -> Optional[T]:
        if val in ("", "None"):
            return None
        converter = parse_type(return_type)
        return converter(val)

    return _optional_type
89
90


91
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
    """Parse a CLI value that may be either a plain string or a JSON object.

    Args:
        val: Raw CLI string.

    Returns:
        The string unchanged when it is not wrapped in braces; otherwise the
        result of ``optional_type(json.loads)`` applied to it.
    """
    looks_like_json_object = re.match(r"(?s)^\s*{.*}\s*$", val) is not None
    if looks_like_json_object:
        return optional_type(json.loads)(val)
    return str(val)
95
96


97
98
99
100
101
102
@deprecated(
    "Passing a JSON argument as a string containing comma separated key=value "
    "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
    "string instead.")
def nullable_kvs(val: str) -> dict[str, int]:
    """Parse comma separated ``KEY=VALUE`` pairs into a ``str -> int`` dict.

    Keys and values are lower-cased and stripped of surrounding whitespace
    before the value is converted to ``int``.

    Args:
        val: String value to be parsed.

    Returns:
        Dictionary with parsed values.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, a value is not an integer, or the same key is
            given two different values.
    """
    parsed: dict[str, int] = {}
    for raw_item in val.split(","):
        pieces = [piece.lower().strip() for piece in raw_item.split("=")]
        if len(pieces) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, raw_value = pieces

        try:
            number = int(raw_value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={raw_value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key is fine as long as the value agrees.
        if parsed.get(key, number) != number:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        parsed[key] = number

    return parsed


133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return whether ``type_hint`` is ``type`` itself or a generic of it.

    Args:
        type_hint: The hint to inspect.
        type: The type (or special typing form) to compare against.

    Returns:
        ``True`` if ``type_hint`` is ``type`` or its ``get_origin`` is
        ``type``; ``False`` otherwise.
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return whether any hint in ``type_hints`` matches ``type``.

    Args:
        type_hints: The set of hints to search.
        type: The type (or special typing form) to look for.

    Returns:
        ``True`` as soon as one hint satisfies ``is_type``; ``False`` if none
        does.
    """
    for candidate in type_hints:
        if is_type(candidate, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return the first hint in ``type_hints`` that matches ``type``.

    Args:
        type_hints: The set of hints to search.
        type: The type (or special typing form) to look for.

    Returns:
        The first hint for which ``is_type`` is true, or ``None`` when no
        hint matches.
    """
    for candidate in type_hints:
        if is_type(candidate, type):
            return candidate
    return None


148
149
150
151
152
153
154
155
156
157
158
159
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Turn the ``Literal`` hint in ``type_hints`` into argparse kwargs.

    Args:
        type_hints: Set of hints containing a ``Literal[...]`` hint.

    Returns:
        A dict with ``type`` (the Python type of the first choice) and
        ``choices`` (the literal values, sorted).

    Raises:
        ValueError: If the literal values are not all of the same type.
    """
    literal_hint = get_type(type_hints, Literal)
    choices = get_args(literal_hint)
    choice_type = type(choices[0])
    for choice in choices:
        if not isinstance(choice, choice_type):
            raise ValueError(
                "All choices must be of the same type. "
                f"Got {choices} with types {[type(c) for c in choices]}")
    ordered_choices = sorted(choices)
    return {"type": choice_type, "choices": ordered_choices}


160
161
162
163
164
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return whether ``type_hint`` is defined outside the ``builtins`` module.

    Args:
        type_hint: Object whose ``__module__`` attribute is inspected.

    Returns:
        ``True`` when ``type_hint.__module__`` is anything but
        ``"builtins"``.
    """
    defining_module = type_hint.__module__
    return defining_module != "builtins"


165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Flatten ``Annotated`` and ``Union`` wrappers into a set of hints.

    Args:
        type_hint: The (possibly nested) hint to unwrap.

    Returns:
        The set of hints that remain once every ``Annotated`` has been
        replaced by its first argument and every ``Union`` by its members.
        Any other hint is returned as a one-element set.
    """
    origin = get_origin(type_hint)

    if origin is Annotated:
        # Only the annotated type matters; the metadata is dropped.
        annotated_type = get_args(type_hint)[0]
        return set(get_type_hints(annotated_type))

    if origin is Union:
        flattened: set[TypeHint] = set()
        for member in get_args(type_hint):
            flattened |= get_type_hints(member)
        return flattened

    return {type_hint}


182
183
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Build the argparse keyword arguments for every field of ``cls``.

    For each dataclass field this derives ``default`` and ``help`` and then,
    from the field's type hints, the remaining ``add_argument`` options
    (``type``, ``choices``, ``nargs``, ``action``).

    The result is memoised by ``functools.lru_cache`` and the cached object
    itself is returned, so callers must not mutate it.

    Args:
        cls: Config dataclass whose fields are turned into CLI arguments.

    Returns:
        Mapping of field name to the kwargs for ``add_argument``.

    Raises:
        ValueError: If a field's type hints match none of the supported
            cases.
    """
    # Help suffix for arguments that accept JSON. It does not depend on the
    # field, so it is built once instead of on every loop iteration.
    json_tip = """Should either be a valid JSON string or JSON keys
passed individually. For example, the following sets of arguments are
equivalent:

- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`

Additionally, list elements can be passed individually using `+`:

- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`"""

    cls_docs = get_attr_docs(cls)
    kwargs = {}
    for field in fields(cls):
        name = field.name

        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, it is parsed from a JSON string
        dataclass_cls = next((th for th in type_hints if is_dataclass(th)),
                             None)

        # Get the default value of the field. It is resolved afresh for every
        # field: without the final branch a field that has neither a default
        # nor a default factory would silently inherit the previous field's
        # default (or raise UnboundLocalError if it is the first field).
        if field.default is not MISSING:
            default = field.default
        elif field.default_factory is not MISSING:
            default = field.default_factory()
        else:
            default = None

        # Get the help text for the field, escaping % for argparse
        help_text = cls_docs[name].strip().replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        arg_kwargs: dict[str, Any] = {"default": default, "help": help_text}
        kwargs[name] = arg_kwargs

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:

            # `cls` is bound as a default so each closure keeps its own class.
            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    if hasattr(cls, "from_cli"):
                        return cls.from_cli(val)
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            arg_kwargs["type"] = parse_dataclass
            arg_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            arg_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            arg_kwargs.update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            types = get_args(type_hint)
            tuple_type = types[0]
            assert all(t is tuple_type for t in types if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {types}.")
            arg_kwargs["type"] = tuple_type
            arg_kwargs["nargs"] = "+" if Ellipsis in types else len(types)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            types = get_args(type_hint)
            assert len(types) == 1, (
                "List type must have exactly one type. Got "
                f"{type_hint} with types {types}")
            arg_kwargs["type"] = types[0]
            arg_kwargs["nargs"] = "+"
        elif contains_type(type_hints, int):
            arg_kwargs["type"] = int
            # Special case for large integers
            if name in {"max_model_len", "max_num_batched_tokens"}:
                arg_kwargs["type"] = human_readable_int
        elif contains_type(type_hints, float):
            arg_kwargs["type"] = float
        elif (contains_type(type_hints, dict)
              and (contains_type(type_hints, str)
                   or any(is_not_builtin(th) for th in type_hints))):
            arg_kwargs["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            arg_kwargs["type"] = parse_type(json.loads)
            arg_kwargs["help"] += f"\n\n{json_tip}"
        elif (contains_type(type_hints, str)
              or any(is_not_builtin(th) for th in type_hints)):
            arg_kwargs["type"] = str
        else:
            raise ValueError(
                f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(arg_kwargs.get("type")) is Literal:
            arg_kwargs.update(literal_to_kwargs({arg_kwargs["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            arg_kwargs["type"] = optional_type(arg_kwargs["type"])
            if arg_kwargs.get("choices"):
                arg_kwargs["choices"].append("None")
    return kwargs
288
289


290
291
292
293
294
295
296
297
298
299
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Return a private copy of the argparse kwargs for a Config dataclass.

    The kwargs come from the ``lru_cache``-backed ``_compute_kwargs``. A deep
    copy is handed back so that mutations made by the caller never reach the
    cached dictionary.

    Args:
        cls: Config dataclass to build argparse kwargs for.

    Returns:
        A deep copy of the cached kwargs mapping.
    """
    cached_kwargs = _compute_kwargs(cls)
    return copy.deepcopy(cached_kwargs)


300
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
301
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
302
    """Arguments for vLLM engine."""
303
304
305
306
307
308
309
    model: str = ModelConfig.model
    served_model_name: Optional[Union[
        str, List[str]]] = ModelConfig.served_model_name
    tokenizer: Optional[str] = ModelConfig.tokenizer
    hf_config_path: Optional[str] = ModelConfig.hf_config_path
    task: TaskOption = ModelConfig.task
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
310
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
311
312
313
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
314
315
    download_dir: Optional[str] = LoadConfig.download_dir
    load_format: str = LoadConfig.load_format
316
317
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
318
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
319
320
    seed: Optional[int] = ModelConfig.seed
    max_model_len: Optional[int] = ModelConfig.max_model_len
321
322
    cuda_graph_sizes: list[int] = get_field(SchedulerConfig,
                                            "cuda_graph_sizes")
323
324
325
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
326
    distributed_executor_backend: Optional[Union[
327
328
        DistributedExecutorBackend,
        Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
329
    # number of P/D disaggregation (or other disaggregation) workers
330
331
332
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
    data_parallel_size: int = ParallelConfig.data_parallel_size
333
    data_parallel_rank: Optional[int] = None
334
335
336
    data_parallel_size_local: Optional[int] = None
    data_parallel_address: Optional[str] = None
    data_parallel_rpc_port: Optional[int] = None
Rui Qiao's avatar
Rui Qiao committed
337
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
338
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
339
340
341
342
343
    enable_eplb: bool = ParallelConfig.enable_eplb
    num_redundant_experts: int = ParallelConfig.num_redundant_experts
    eplb_window_size: int = ParallelConfig.eplb_window_size
    eplb_step_interval: int = ParallelConfig.eplb_step_interval
    eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness
344
345
    max_parallel_loading_workers: Optional[
        int] = ParallelConfig.max_parallel_loading_workers
346
347
348
349
    block_size: Optional[BlockSize] = CacheConfig.block_size
    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
    prefix_caching_hash_algo: PrefixCachingHashAlgo = \
        CacheConfig.prefix_caching_hash_algo
350
351
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
352
    use_v2_block_manager: bool = True
353
354
355
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
356
357
358
359
360
361
362
    max_num_batched_tokens: Optional[
        int] = SchedulerConfig.max_num_batched_tokens
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
    long_prefill_token_threshold: int = \
        SchedulerConfig.long_prefill_token_threshold
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
363
    max_logprobs: int = ModelConfig.max_logprobs
364
    disable_log_stats: bool = False
365
366
367
368
369
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
    rope_theta: Optional[float] = ModelConfig.rope_theta
    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
370
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
371
372
373
374
    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
    enforce_eager: bool = ModelConfig.enforce_eager
    max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
375
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
376
    limit_mm_per_prompt: dict[str, int] = \
377
        get_field(MultiModalConfig, "limit_per_prompt")
378
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
379
380
381
    media_io_kwargs: dict[str, dict[str,
                                    Any]] = get_field(MultiModalConfig,
                                                      "media_io_kwargs")
382
383
384
385
    mm_processor_kwargs: Optional[Dict[str, Any]] = \
        MultiModalConfig.mm_processor_kwargs
    disable_mm_preprocessor_cache: bool = \
        MultiModalConfig.disable_mm_preprocessor_cache
386
    # LoRA fields
387
    enable_lora: bool = False
388
389
390
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
391
392
    default_mm_loras: Optional[Dict[str, str]] = \
        LoRAConfig.default_mm_loras
393
394
395
396
397
398
399
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
        LoRAConfig.long_lora_scaling_factors
    # PromptAdapter fields
400
    enable_prompt_adapter: bool = False
401
402
403
404
    max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
    max_prompt_adapter_token: int = \
        PromptAdapterConfig.max_prompt_adapter_token

405
    device: Device = DeviceConfig.device
406
407
    num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
    multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
408
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
409
410
    num_gpu_blocks_override: Optional[
        int] = CacheConfig.num_gpu_blocks_override
411
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
412
413
    model_loader_extra_config: dict = \
        get_field(LoadConfig, "model_loader_extra_config")
414
415
    ignore_patterns: Optional[Union[str,
                                    List[str]]] = LoadConfig.ignore_patterns
416
    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
417

418
419
420
421
    scheduler_delay_factor: float = SchedulerConfig.delay_factor
    enable_chunked_prefill: Optional[
        bool] = SchedulerConfig.enable_chunked_prefill
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
422

423
424
425
    disable_hybrid_kv_cache_manager: bool = (
        SchedulerConfig.disable_hybrid_kv_cache_manager)

426
427
428
429
430
431
    guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
    guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
    guided_decoding_disable_any_whitespace: bool = \
        DecodingConfig.disable_any_whitespace
    guided_decoding_disable_additional_properties: bool = \
        DecodingConfig.disable_additional_properties
432
433
    logits_processor_pattern: Optional[
        str] = ModelConfig.logits_processor_pattern
434

435
    speculative_config: Optional[Dict[str, Any]] = None
436

437
    qlora_adapter_name_or_path: Optional[str] = None
438
439
440
441
442
443
    show_hidden_metrics_for_version: Optional[str] = \
        ObservabilityConfig.show_hidden_metrics_for_version
    otlp_traces_endpoint: Optional[str] = \
        ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: Optional[list[DetailedTraceModules]] = \
        ObservabilityConfig.collect_detailed_traces
444
    disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
445
446
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
447

448
449
450
451
    override_neuron_config: dict[str, Any] = \
        get_field(ModelConfig, "override_neuron_config")
    override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
        ModelConfig.override_pooler_config
452
453
    compilation_config: CompilationConfig = \
        get_field(VllmConfig, "compilation_config")
454
455
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
456

457
    kv_transfer_config: Optional[KVTransferConfig] = None
458
    kv_events_config: Optional[KVEventsConfig] = None
459

460
461
462
463
464
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
    override_generation_config: dict[str, Any] = \
        get_field(ModelConfig, "override_generation_config")
    model_impl: str = ModelConfig.model_impl
465
    override_attention_dtype: str = ModelConfig.override_attention_dtype
466

467
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
468

469
470
    additional_config: dict[str, Any] = \
        get_field(VllmConfig, "additional_config")
471
472
473
    enable_reasoning: Optional[bool] = None  # DEPRECATED
    reasoning_parser: str = DecodingConfig.reasoning_backend

474
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
475
    pt_load_map_location: str = LoadConfig.pt_load_map_location
476

477
478
479
    enable_multimodal_encoder_data_parallel: bool = \
        ParallelConfig.enable_multimodal_encoder_data_parallel

480
481
    async_scheduling: bool = SchedulerConfig.async_scheduling

482
    def __post_init__(self):
        """Normalise fields after dataclass initialisation and load plugins.

        Converts an ``int`` or ``dict`` ``compilation_config`` through
        ``CompilationConfig.from_cli``, emits a ``DeprecationWarning`` when
        ``qlora_adapter_name_or_path`` is set, and calls
        ``load_general_plugins``.
        """
        # Allow `EngineArgs(compilation_config={...})` so that callers do not
        # have to build a CompilationConfig object themselves.
        # NOTE(review): str() of a dict is a Python repr, not JSON; confirm
        # that `CompilationConfig.from_cli` accepts that form.
        raw_compilation_config = self.compilation_config
        if isinstance(raw_compilation_config, (int, dict)):
            self.compilation_config = CompilationConfig.from_cli(
                str(raw_compilation_config))

        if self.qlora_adapter_name_or_path is not None:
            deprecation_message = (
                "The `qlora_adapter_name_or_path` is deprecated "
                "and will be removed in v0.10.0. ")
            warnings.warn(
                deprecation_message,
                DeprecationWarning,
                stacklevel=2,
            )

        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
499
500

    @staticmethod
501
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
502
        """Shared CLI arguments for vLLM engine."""
503

504
        # Model arguments
505
506
507
508
509
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
Reid's avatar
Reid committed
510
        if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]):
511
            model_group.add_argument("--model", **model_kwargs["model"])
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
        model_group.add_argument("--task", **model_kwargs["task"])
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode",
                                 **model_kwargs["tokenizer_mode"])
        model_group.add_argument("--trust-remote-code",
                                 **model_kwargs["trust_remote_code"])
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
        model_group.add_argument("--hf-config-path",
                                 **model_kwargs["hf_config_path"])
        model_group.add_argument("--allowed-local-media-path",
                                 **model_kwargs["allowed_local_media_path"])
        model_group.add_argument("--revision", **model_kwargs["revision"])
        model_group.add_argument("--code-revision",
                                 **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling",
                                 **model_kwargs["rope_scaling"])
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
        model_group.add_argument("--tokenizer-revision",
                                 **model_kwargs["tokenizer_revision"])
        model_group.add_argument("--max-model-len",
                                 **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q",
                                 **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager",
                                 **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-seq-len-to-capture",
                                 **model_kwargs["max_seq_len_to_capture"])
        model_group.add_argument("--max-logprobs",
                                 **model_kwargs["max_logprobs"])
        model_group.add_argument("--disable-sliding-window",
                                 **model_kwargs["disable_sliding_window"])
        model_group.add_argument("--disable-cascade-attn",
                                 **model_kwargs["disable_cascade_attn"])
        model_group.add_argument("--skip-tokenizer-init",
                                 **model_kwargs["skip_tokenizer_init"])
548
549
        model_group.add_argument("--enable-prompt-embeds",
                                 **model_kwargs["enable_prompt_embeds"])
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
        model_group.add_argument("--served-model-name",
                                 **model_kwargs["served_model_name"])
        # This one is a special case because it is the
        # opposite of ModelConfig.use_async_output_proc
        model_group.add_argument(
            "--disable-async-output-proc",
            action="store_true",
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
        model_group.add_argument("--config-format",
                                 choices=[f.value for f in ConfigFormat],
                                 **model_kwargs["config_format"])
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
        model_group.add_argument("--hf-token",
                                 type=str,
                                 nargs="?",
                                 const=True,
                                 default=model_kwargs["hf_token"]["default"],
                                 help=model_kwargs["hf_token"]["help"])
        model_group.add_argument("--hf-overrides",
                                 **model_kwargs["hf_overrides"])
        model_group.add_argument("--override-neuron-config",
                                 **model_kwargs["override_neuron_config"])
        model_group.add_argument("--override-pooler-config",
                                 **model_kwargs["override_pooler_config"])
        model_group.add_argument("--logits-processor-pattern",
                                 **model_kwargs["logits_processor_pattern"])
        model_group.add_argument("--generation-config",
                                 **model_kwargs["generation_config"])
        model_group.add_argument("--override-generation-config",
                                 **model_kwargs["override_generation_config"])
        model_group.add_argument("--enable-sleep-mode",
                                 **model_kwargs["enable_sleep_mode"])
        model_group.add_argument("--model-impl",
                                 choices=[f.value for f in ModelImpl],
                                 **model_kwargs["model_impl"])
588
589
        model_group.add_argument("--override-attention-dtype",
                                 **model_kwargs["override_attention_dtype"])
590

591
592
593
594
595
596
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
597
        load_group.add_argument("--load-format",
598
599
                                choices=[f.value for f in LoadFormat],
                                **load_kwargs["load_format"])
600
        load_group.add_argument("--download-dir",
601
                                **load_kwargs["download_dir"])
602
        load_group.add_argument("--model-loader-extra-config",
603
                                **load_kwargs["model_loader_extra_config"])
604
605
606
        load_group.add_argument("--ignore-patterns",
                                **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load",
607
                                **load_kwargs["use_tqdm_on_load"])
608
609
610
611
612
613
614
615
        load_group.add_argument(
            "--qlora-adapter-name-or-path",
            type=str,
            default=None,
            help="The `--qlora-adapter-name-or-path` has no effect, do not set"
            " it, and it  will be removed in v0.10.0.",
            deprecated=True,
        )
616
617
        load_group.add_argument('--pt-load-map-location',
                                **load_kwargs["pt_load_map_location"])
618

619
620
621
622
623
624
        # Guided decoding arguments
        guided_decoding_kwargs = get_kwargs(DecodingConfig)
        guided_decoding_group = parser.add_argument_group(
            title="DecodingConfig",
            description=DecodingConfig.__doc__,
        )
625
626
        guided_decoding_group.add_argument("--guided-decoding-backend",
                                           **guided_decoding_kwargs["backend"])
627
        guided_decoding_group.add_argument(
628
629
630
631
632
633
634
635
            "--guided-decoding-disable-fallback",
            **guided_decoding_kwargs["disable_fallback"])
        guided_decoding_group.add_argument(
            "--guided-decoding-disable-any-whitespace",
            **guided_decoding_kwargs["disable_any_whitespace"])
        guided_decoding_group.add_argument(
            "--guided-decoding-disable-additional-properties",
            **guided_decoding_kwargs["disable_additional_properties"])
636
637
638
        guided_decoding_group.add_argument(
            "--enable-reasoning",
            action=argparse.BooleanOptionalAction,
639
            deprecated=True,
640
            help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as "
641
            "of v0.9.0. Use `--reasoning-parser` to specify the reasoning "
642
            "parser backend instead. This flag (`--enable-reasoning`) will be "
643
644
            "removed in v0.10.0. When `--reasoning-parser` is specified, "
            "reasoning mode is automatically enabled.")
645
646
647
648
649
650
        guided_decoding_group.add_argument(
            "--reasoning-parser",
            # This choices is a special case because it's not static
            choices=list(ReasoningParserManager.reasoning_parsers),
            **guided_decoding_kwargs["reasoning_backend"])

651
        # Parallel arguments
652
653
654
655
656
657
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
658
            "--distributed-executor-backend",
659
660
            **parallel_kwargs["distributed_executor_backend"])
        parallel_group.add_argument(
661
            "--pipeline-parallel-size", "-pp",
662
            **parallel_kwargs["pipeline_parallel_size"])
663
        parallel_group.add_argument("--tensor-parallel-size", "-tp",
664
                                    **parallel_kwargs["tensor_parallel_size"])
665
        parallel_group.add_argument("--data-parallel-size", "-dp",
666
                                    **parallel_kwargs["data_parallel_size"])
667
668
669
670
671
672
        parallel_group.add_argument(
            '--data-parallel-rank',
            '-dpn',
            type=int,
            help='Data parallel rank of this instance. '
            'When set, enables external load balancer mode.')
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
        parallel_group.add_argument('--data-parallel-size-local',
                                    '-dpl',
                                    type=int,
                                    help='Number of data parallel replicas '
                                    'to run on this node.')
        parallel_group.add_argument('--data-parallel-address',
                                    '-dpa',
                                    type=str,
                                    help='Address of data parallel cluster '
                                    'head-node.')
        parallel_group.add_argument('--data-parallel-rpc-port',
                                    '-dpp',
                                    type=int,
                                    help='Port for data parallel RPC '
                                    'communication.')
Rui Qiao's avatar
Rui Qiao committed
688
689
690
691
692
693
        parallel_group.add_argument('--data-parallel-backend',
                                    '-dpb',
                                    type=str,
                                    default='mp',
                                    help='Backend for data parallel, either '
                                    '"mp" or "ray".')
694
        parallel_group.add_argument(
695
            "--enable-expert-parallel",
696
            **parallel_kwargs["enable_expert_parallel"])
697
698
699
700
701
702
703
704
705
706
        parallel_group.add_argument("--enable-eplb",
                                    **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--num-redundant-experts",
                                    **parallel_kwargs["num_redundant_experts"])
        parallel_group.add_argument("--eplb-window-size",
                                    **parallel_kwargs["eplb_window_size"])
        parallel_group.add_argument("--eplb-step-interval",
                                    **parallel_kwargs["eplb_step_interval"])
        parallel_group.add_argument("--eplb-log-balancedness",
                                    **parallel_kwargs["eplb_log_balancedness"])
707
        parallel_group.add_argument(
708
            "--max-parallel-loading-workers",
709
710
            **parallel_kwargs["max_parallel_loading_workers"])
        parallel_group.add_argument(
711
            "--ray-workers-use-nsight",
712
713
            **parallel_kwargs["ray_workers_use_nsight"])
        parallel_group.add_argument(
714
            "--disable-custom-all-reduce",
715
            **parallel_kwargs["disable_custom_all_reduce"])
716
717
718
719
        parallel_group.add_argument("--worker-cls",
                                    **parallel_kwargs["worker_cls"])
        parallel_group.add_argument("--worker-extension-cls",
                                    **parallel_kwargs["worker_extension_cls"])
720
721
722
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
            **parallel_kwargs["enable_multimodal_encoder_data_parallel"])
723

724
725
726
727
728
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
729
        )
730
731
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
        cache_group.add_argument("--gpu-memory-utilization",
732
                                 **cache_kwargs["gpu_memory_utilization"])
733
734
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
        cache_group.add_argument("--kv-cache-dtype",
735
                                 **cache_kwargs["cache_dtype"])
736
        cache_group.add_argument("--num-gpu-blocks-override",
737
738
739
740
741
                                 **cache_kwargs["num_gpu_blocks_override"])
        cache_group.add_argument("--enable-prefix-caching",
                                 **cache_kwargs["enable_prefix_caching"])
        cache_group.add_argument("--prefix-caching-hash-algo",
                                 **cache_kwargs["prefix_caching_hash_algo"])
742
        cache_group.add_argument("--cpu-offload-gb",
743
                                 **cache_kwargs["cpu_offload_gb"])
744
        cache_group.add_argument("--calculate-kv-scales",
745
746
                                 **cache_kwargs["calculate_kv_scales"])

747
        # Multimodal related configs
748
749
750
751
752
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
753
        multimodal_group.add_argument("--limit-mm-per-prompt",
754
                                      **multimodal_kwargs["limit_per_prompt"])
755
756
        multimodal_group.add_argument("--media-io-kwargs",
                                      **multimodal_kwargs["media_io_kwargs"])
757
        multimodal_group.add_argument(
758
            "--mm-processor-kwargs",
759
760
            **multimodal_kwargs["mm_processor_kwargs"])
        multimodal_group.add_argument(
761
            "--disable-mm-preprocessor-cache",
762
            **multimodal_kwargs["disable_mm_preprocessor_cache"])
763
764
765
        multimodal_group.add_argument(
            "--interleave-mm-strings",
            **multimodal_kwargs["interleave_mm_strings"])
766

767
        # LoRA related configs
768
769
770
771
772
773
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
774
            "--enable-lora",
775
            action=argparse.BooleanOptionalAction,
776
777
            help="If True, enable handling of LoRA adapters.")
        lora_group.add_argument("--enable-lora-bias",
778
                                **lora_kwargs["bias_enabled"])
779
780
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
        lora_group.add_argument("--max-lora-rank",
781
                                **lora_kwargs["max_lora_rank"])
782
        lora_group.add_argument("--lora-extra-vocab-size",
783
784
                                **lora_kwargs["lora_extra_vocab_size"])
        lora_group.add_argument(
785
            "--lora-dtype",
786
787
            **lora_kwargs["lora_dtype"],
        )
788
        lora_group.add_argument("--long-lora-scaling-factors",
789
                                **lora_kwargs["long_lora_scaling_factors"])
790
        lora_group.add_argument("--max-cpu-loras",
791
                                **lora_kwargs["max_cpu_loras"])
792
        lora_group.add_argument("--fully-sharded-loras",
793
                                **lora_kwargs["fully_sharded_loras"])
794
795
        lora_group.add_argument("--default-mm-loras",
                                **lora_kwargs["default_mm_loras"])
796
797
798
799
800
801
802
803

        # PromptAdapter related configs
        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
        prompt_adapter_group = parser.add_argument_group(
            title="PromptAdapterConfig",
            description=PromptAdapterConfig.__doc__,
        )
        prompt_adapter_group.add_argument(
804
            "--enable-prompt-adapter",
805
            action=argparse.BooleanOptionalAction,
806
            help="If True, enable handling of PromptAdapters.")
807
        prompt_adapter_group.add_argument(
808
            "--max-prompt-adapters",
809
810
            **prompt_adapter_kwargs["max_prompt_adapters"])
        prompt_adapter_group.add_argument(
811
            "--max-prompt-adapter-token",
812
            **prompt_adapter_kwargs["max_prompt_adapter_token"])
813
814
815
816
817
818
819

        # Device arguments
        device_kwargs = get_kwargs(DeviceConfig)
        device_group = parser.add_argument_group(
            title="DeviceConfig",
            description=DeviceConfig.__doc__,
        )
820
821
822
        device_group.add_argument("--device",
                                  **device_kwargs["device"],
                                  deprecated=True)
823

824
825
826
827
828
829
        # Speculative arguments
        speculative_group = parser.add_argument_group(
            title="SpeculativeConfig",
            description=SpeculativeConfig.__doc__,
        )
        speculative_group.add_argument(
830
            "--speculative-config",
831
832
            type=json.loads,
            default=None,
833
834
            help="The configurations for speculative decoding. Should be a "
            "JSON string.")
835

836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
            **observability_kwargs["show_hidden_metrics_for_version"])
        observability_group.add_argument(
            "--otlp-traces-endpoint",
            **observability_kwargs["otlp_traces_endpoint"])
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
            ",".join(p)
            for p in permutations(get_args(DetailedTraceModules), r=2)
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
            **observability_kwargs["collect_detailed_traces"])
859

860
861
862
863
864
865
866
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
867
            "--max-num-batched-tokens",
868
            **scheduler_kwargs["max_num_batched_tokens"])
869
        scheduler_group.add_argument("--max-num-seqs",
870
871
872
873
874
875
876
                                     **scheduler_kwargs["max_num_seqs"])
        scheduler_group.add_argument(
            "--max-num-partial-prefills",
            **scheduler_kwargs["max_num_partial_prefills"])
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
            **scheduler_kwargs["max_long_partial_prefills"])
877
878
        scheduler_group.add_argument('--cuda-graph-sizes',
                                     **scheduler_kwargs["cuda_graph_sizes"])
879
880
881
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
            **scheduler_kwargs["long_prefill_token_threshold"])
882
        scheduler_group.add_argument("--num-lookahead-slots",
883
                                     **scheduler_kwargs["num_lookahead_slots"])
884
        scheduler_group.add_argument("--scheduler-delay-factor",
885
                                     **scheduler_kwargs["delay_factor"])
886
        scheduler_group.add_argument("--preemption-mode",
887
                                     **scheduler_kwargs["preemption_mode"])
888
        scheduler_group.add_argument("--num-scheduler-steps",
889
                                     **scheduler_kwargs["num_scheduler_steps"])
890
        scheduler_group.add_argument(
891
            "--multi-step-stream-outputs",
892
            **scheduler_kwargs["multi_step_stream_outputs"])
893
        scheduler_group.add_argument("--scheduling-policy",
894
                                     **scheduler_kwargs["policy"])
895
        scheduler_group.add_argument(
896
            "--enable-chunked-prefill",
897
            **scheduler_kwargs["enable_chunked_prefill"])
898
899
900
        scheduler_group.add_argument(
            "--disable-chunked-mm-input",
            **scheduler_kwargs["disable_chunked_mm_input"])
901
902
        scheduler_group.add_argument("--scheduler-cls",
                                     **scheduler_kwargs["scheduler_cls"])
903
904
905
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"])
906
907
        scheduler_group.add_argument("--async-scheduling",
                                     **scheduler_kwargs["async_scheduling"])
908
909

        # vLLM arguments
910
        vllm_kwargs = get_kwargs(VllmConfig)
911
912
913
914
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
915
916
917
918
919
920
921
922
        vllm_group.add_argument("--kv-transfer-config",
                                **vllm_kwargs["kv_transfer_config"])
        vllm_group.add_argument('--kv-events-config',
                                **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument("--compilation-config", "-O",
                                **vllm_kwargs["compilation_config"])
        vllm_group.add_argument("--additional-config",
                                **vllm_kwargs["additional_config"])
923

924
925
926
927
        # Other arguments
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            default=True,
928
                            deprecated=True,
929
930
931
932
933
934
935
936
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')
937

938
        return parser
939
940

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of this dataclass from parsed CLI arguments.

        Every dataclass field of ``cls`` is looked up on ``args`` by name, so
        ``args`` must carry one attribute per field (``getattr`` raises
        ``AttributeError`` otherwise). Attributes of ``args`` that are not
        fields of ``cls`` are ignored.

        Args:
            args: Namespace holding one attribute per dataclass field.

        Returns:
            A new ``cls`` instance populated from ``args``.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        return cls(**{attr: getattr(args, attr) for attr in attrs})
947

948
    def create_model_config(self) -> ModelConfig:
        """Build the `ModelConfig` from the fields of this object.

        Note that this method may rewrite `self.quantization`,
        `self.load_format` and `self.model` before the config is built
        (see the inline comments below).

        Returns:
            The `ModelConfig` constructed from the (possibly adjusted) fields.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
                and self.model in MODELS_ON_S3
                and self.load_format == LoadFormat.AUTO):  # noqa: E501
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
            self.load_format = LoadFormat.RUNAI_STREAMER

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            task=self.task,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            # The engine arg is a "disable" switch; the config takes the
            # positive form.
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
        )
1003

1004
1005
1006
1007
1008
1009
1010
    def validate_tensorizer_args(self):
        """Copy tensorizer options into the nested ``"tensorizer_config"``.

        Every top-level key of `self.model_loader_extra_config` that is listed
        in `TensorizerConfig._fields` is copied (key and value) into
        `self.model_loader_extra_config["tensorizer_config"]`. The top-level
        entries themselves are left in place.
        """
        from vllm.model_executor.model_loader.tensorizer import (
            TensorizerConfig)

        extra_config = self.model_loader_extra_config
        # Collect the matches first so the dict is never resized while it is
        # being iterated.
        tensorizer_overrides = {
            key: value
            for key, value in extra_config.items()
            if key in TensorizerConfig._fields
        }
        if tensorizer_overrides:
            # Create the nested dict on demand instead of raising KeyError
            # when this method is called before "tensorizer_config" exists.
            extra_config.setdefault("tensorizer_config",
                                    {}).update(tensorizer_overrides)
1011

1012
1013
    def create_load_config(self) -> LoadConfig:
        """Build the `LoadConfig` from the fields of this object.

        Side effects on `self`:

        * `load_format` is forced to ``"bitsandbytes"`` when `quantization`
          is ``"bitsandbytes"``.
        * For the ``"tensorizer"`` load format,
          `model_loader_extra_config` is converted with `to_serializable()`
          when it offers that method, its ``"tensorizer_config"`` entry is
          replaced by a fresh dict whose ``"tensorizer_dir"`` is `self.model`,
          and `validate_tensorizer_args()` is then called.

        Returns:
            The `LoadConfig` constructed from the (possibly adjusted) fields.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            if hasattr(self.model_loader_extra_config, "to_serializable"):
                self.model_loader_extra_config = (
                    self.model_loader_extra_config.to_serializable())
            # Any previous "tensorizer_config" entry is discarded here.
            self.model_loader_extra_config["tensorizer_config"] = {
                "tensorizer_dir": self.model,
            }
            self.validate_tensorizer_args()

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )
1034

1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        This function utilizes `speculative_config` to create a
        SpeculativeConfig object. The `speculative_config` can either be
        provided as a JSON string input via CLI arguments or directly as a
        dictionary from the engine.

        Args:
            target_model_config: Passed through under the
                ``"target_model_config"`` key.
            target_parallel_config: Passed through under the
                ``"target_parallel_config"`` key.
            enable_chunked_prefill: Passed through under the
                ``"enable_chunked_prefill"`` key.
            disable_log_stats: Passed through under the
                ``"disable_log_stats"`` key.

        Returns:
            ``None`` when `self.speculative_config` is ``None``; otherwise the
            result of `SpeculativeConfig.from_dict` on the merged settings.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # They are merged into a fresh dict so that the caller-supplied
        # `self.speculative_config` is not mutated with live config objects.
        speculative_kwargs = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        }
        return SpeculativeConfig.from_dict(speculative_kwargs)

1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1077

1078
1079
1080
1081
1082
1083
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1084
        current_platform.pre_register_and_update()
1085

1086
1087
        device_config = DeviceConfig(
            device=cast(Device, current_platform.device_type))
1088
1089
        model_config = self.create_model_config()

1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
1109
            self._set_default_args_v1(usage_context, model_config)
1110
1111
1112
1113
1114
1115
1116
1117
            # Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1
            if current_platform.is_cpu(
            ) and current_platform.get_cpu_architecture() in (
                    CpuArchEnum.POWERPC, CpuArchEnum.ARM):
                logger.info(
                    "Chunked prefill is not supported for ARM and POWER CPUs; "
                    "disabling it for V1 backend.")
                self.enable_chunked_prefill = False
1118
1119
        else:
            self._set_default_args_v0(model_config)
1120
1121
        assert self.enable_chunked_prefill is not None

1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
        if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
            assert self.enforce_eager, (
                "Cuda graph is not supported with DualChunkFlashAttention. "
                "To run the model in eager mode, set 'enforce_eager=True' "
                "or use '--enforce-eager' in the CLI.")
            assert current_platform.is_cuda(), (
                "DualChunkFlashAttention is only supported on CUDA platform.")
            assert not use_v1, (
                "DualChunkFlashAttention is not supported on V1 engine. "
                "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")

1133
        cache_config = CacheConfig(
1134
            block_size=self.block_size,
1135
1136
1137
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1138
            is_attention_free=model_config.is_attention_free,
1139
1140
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1141
            enable_prefix_caching=self.enable_prefix_caching,
1142
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1143
            cpu_offload_gb=self.cpu_offload_gb,
1144
            calculate_kv_scales=self.calculate_kv_scales,
1145
        )
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
        data_parallel_external_lb = self.data_parallel_rank is not None
        if data_parallel_external_lb:
            assert self.data_parallel_size_local in (1, None), (
                "data_parallel_size_local must be 1 when data_parallel_rank "
                "is set")
            data_parallel_size_local = 1
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
        else:
            # Local DP size defaults to global DP size if not set.
            data_parallel_size_local = self.data_parallel_size
1169
1170
1171

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
                    "Using host IP %s as ray-based data parallel address",
                    host_ip)
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
                    self.data_parallel_backend)
                data_parallel_address = ParallelConfig.data_parallel_master_ip
        else:
            data_parallel_address = self.data_parallel_address
1186
1187
1188
1189
1190
1191
1192

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
        data_parallel_rpc_port = self.data_parallel_rpc_port if (
            self.data_parallel_rpc_port
            is not None) else ParallelConfig.data_parallel_rpc_port

1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
        if self.async_scheduling:
            # Async scheduling does not work with the uniprocess backend.
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "mp"
                logger.info("Using mp-based distributed executor backend "
                            "for async scheduling.")
            if self.distributed_executor_backend == "uni":
                raise ValueError("Async scheduling is not supported with "
                                 "uni-process backend.")
            if self.pipeline_parallel_size > 1:
                raise ValueError("Async scheduling is not supported with "
                                 "pipeline-parallel-size > 1.")

            # Currently, async scheduling does not support speculative decoding.
            # TODO(woosuk): Support it.
            if self.speculative_config is not None:
                raise ValueError(
                    "Currently, speculative decoding is not supported with "
                    "async scheduling.")

1213
        parallel_config = ParallelConfig(
1214
1215
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1216
            data_parallel_size=self.data_parallel_size,
1217
1218
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1219
1220
1221
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1222
            data_parallel_backend=self.data_parallel_backend,
1223
            enable_expert_parallel=self.enable_expert_parallel,
1224
1225
1226
1227
1228
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.num_redundant_experts,
            eplb_window_size=self.eplb_window_size,
            eplb_step_interval=self.eplb_step_interval,
            eplb_log_balancedness=self.eplb_log_balancedness,
1229
1230
1231
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1232
            placement_group=placement_group,
1233
1234
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1235
            worker_extension_cls=self.worker_extension_cls,
1236
1237
            enable_multimodal_encoder_data_parallel=self.
            enable_multimodal_encoder_data_parallel,
1238
        )
1239

1240
        speculative_config = self.create_speculative_config(
1241
1242
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1243
            enable_chunked_prefill=self.enable_chunked_prefill,
1244
            disable_log_stats=self.disable_log_stats,
1245
1246
        )

1247
        # Reminder: Please update docs/features/compatibility_matrix.md
1248
        # If the feature combo become valid
1249
1250
1251
1252
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1253
1254
1255
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1256
1257
1258
1259
1260
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1261
1262
1263
1264
1265
1266
1267
1268
1269

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1270
        scheduler_config = SchedulerConfig(
1271
            runner_type=model_config.runner_type,
1272
1273
1274
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1275
            cuda_graph_sizes=self.cuda_graph_sizes,
1276
            num_lookahead_slots=num_lookahead_slots,
1277
1278
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1279
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1280
            is_multimodal_model=model_config.is_multimodal_model,
1281
            preemption_mode=self.preemption_mode,
1282
            num_scheduler_steps=self.num_scheduler_steps,
1283
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1284
1285
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1286
            policy=self.scheduling_policy,
1287
            scheduler_cls=self.scheduler_cls,
1288
1289
1290
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1291
1292
            disable_hybrid_kv_cache_manager=self.
            disable_hybrid_kv_cache_manager,
1293
            async_scheduling=self.async_scheduling,
1294
        )
1295

1296
1297
1298
1299
1300
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
                "non multimodal model")

1301
        lora_config = LoRAConfig(
1302
            bias_enabled=self.enable_lora_bias,
1303
1304
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1305
            default_mm_loras=self.default_mm_loras,
1306
            fully_sharded_loras=self.fully_sharded_loras,
1307
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1308
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1309
1310
1311
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
1312

1313
1314
1315
1316
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1317
        load_config = self.create_load_config()
1318

1319
1320
1321
1322
1323
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1324
        decoding_config = DecodingConfig(
1325
1326
1327
1328
1329
            backend=self.guided_decoding_backend,
            disable_fallback=self.guided_decoding_disable_fallback,
            disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
            disable_additional_properties=\
                self.guided_decoding_disable_additional_properties,
1330
1331
            reasoning_backend=self.reasoning_parser
        )
1332

1333
        observability_config = ObservabilityConfig(
1334
1335
            show_hidden_metrics_for_version=self.
            show_hidden_metrics_for_version,
1336
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1337
            collect_detailed_traces=self.collect_detailed_traces,
1338
        )
1339

1340
        config = VllmConfig(
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1351
            prompt_adapter_config=prompt_adapter_config,
1352
            compilation_config=self.compilation_config,
1353
            kv_transfer_config=self.kv_transfer_config,
1354
            kv_events_config=self.kv_events_config,
1355
            additional_config=self.additional_config,
1356
        )
1357

1358
1359
        return config

1360
1361
1362
1363
1364
1365
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Decide whether this configuration can run on the V1 engine.

        The checks run in a fixed order and stop at the first feature that
        V1 cannot serve. Hard incompatibilities are reported through
        `_raise_or_fallback` (which raises `NotImplementedError` when
        VLLM_USE_V1=1 is set explicitly, and otherwise logs a warning).
        Experimental features go through `_warn_or_fallback`, which lets an
        explicit VLLM_USE_V1=1 opt in.

        Args:
            model_config: Configuration of the model about to be served.

        Returns:
            True if every check passes, False if the caller should fall
            back to the V0 engine.
        """

        def _unsupported(feature, recommend_to_remove: bool = False) -> bool:
            # Report the offending feature, then answer "not V1".
            _raise_or_fallback(feature_name=feature,
                               recommend_to_remove=recommend_to_remove)
            return False

        #############################################################
        # Unsupported Feature Flags on V1.

        if self.load_format == LoadFormat.SHARDED_STATE.value:
            return _unsupported(f"--load_format {self.load_format}")

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            return _unsupported("--logits-processor-pattern")

        if self.preemption_mode != SchedulerConfig.preemption_mode:
            return _unsupported("--preemption-mode",
                                recommend_to_remove=True)

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            return _unsupported("--disable-async-output-proc",
                                recommend_to_remove=True)

        if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
            return _unsupported("--num-scheduler-steps",
                                recommend_to_remove=True)

        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
            return _unsupported("--scheduler-delay-factor",
                                recommend_to_remove=True)

        if self.guided_decoding_backend not in get_args(
                GuidedDecodingBackendV1):
            return _unsupported(
                f"--guided-decoding-backend={self.guided_decoding_backend}")

        # Need at least Ampere for now (FA support required).
        # The check is skipped on non-GPU platforms and when the device
        # capability is unavailable (e.g. in a Ray actor without GPUs).
        if (current_platform.is_cuda()
                and current_platform.get_device_capability()
                and current_platform.get_device_capability().major < 8):
            return _unsupported("Compute Capability < 8.0")

        # A non-default KV cache dtype is only accepted on ROCm, on CUDA
        # devices of capability 100, or for fp8 when FlashAttention reports
        # fp8 support.
        if self.kv_cache_dtype != "auto":
            is_fp8 = self.kv_cache_dtype.startswith("fp8")
            fa_selected = (
                current_platform.is_cuda()
                and not envs.is_set("VLLM_ATTENTION_BACKEND")
            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"

            if current_platform.is_rocm() or (
                    current_platform.is_cuda()
                    and current_platform.is_device_capability(100)):
                kv_dtype_ok = True
            elif is_fp8 and fa_selected:
                from vllm.attention.utils.fa_utils import (
                    flash_attn_supports_fp8)
                kv_dtype_ok = flash_attn_supports_fp8()
            else:
                kv_dtype_ok = False

            if not kv_dtype_ok:
                return _unsupported("--kv-cache-dtype")

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            return _unsupported("--enable-prompt-adapter")

        # No text embedding inputs so far.
        if self.enable_prompt_embeds:
            return _unsupported("--enable-prompt-embeds")

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            return _unsupported(model_config.architectures)

        # V1 mamba models are unoptimized.
        if model_config.has_inner_state and _warn_or_fallback(
                feature_name="Mamba"):
            return False

        # No Concurrent Partial Prefills so far.
        uses_concurrent_prefills = (
            self.max_num_partial_prefills
            != SchedulerConfig.max_num_partial_prefills
            or self.max_long_partial_prefills
            != SchedulerConfig.max_long_partial_prefills)
        if uses_concurrent_prefills:
            return _unsupported("Concurrent Partial Prefill")

        # No OTLP observability so far.
        if self.otlp_traces_endpoint or self.collect_detailed_traces:
            return _unsupported("--otlp-traces-endpoint")

        # V1 supports N-gram, Medusa, and Eagle speculative decoding.
        spec_config = self.speculative_config
        if spec_config is not None:
            ngram_aliases = ("ngram", "[ngram]")
            spec_method = spec_config.get("method")
            if spec_method:
                spec_ok = (spec_method in ngram_aliases
                           or spec_method == "medusa"
                           or spec_method in ("eagle", "eagle3",
                                              "deepseek_mtp"))
            else:
                # No explicit method: only an n-gram "model" is recognised.
                spec_ok = spec_config.get("model") in ngram_aliases
            if not spec_ok:
                # Other speculative decoding methods are not supported yet.
                return _unsupported("Speculative Decoding")

        # No XFormers so far.
        V1_BACKENDS = [
            "FLASH_ATTN_VLLM_V1",
            "FLASH_ATTN",
            "PALLAS",
            "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1",
            "TRITON_MLA",
            "CUTLASS_MLA_VLLM_V1",
            "FLASHMLA",
            "FLASHINFER",
            "FLASHINFER_VLLM_V1",
            "ROCM_AITER_MLA",
            "TORCH_SDPA_VLLM_V1",
            "FLEX_ATTENTION",
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            return _unsupported(
                f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}",
                recommend_to_remove=True)

        # Platforms must decide if they can support v1 for this model
        if not current_platform.supports_v1(model_config=model_config):
            return _unsupported(
                f"device type={current_platform.device_type}")

        #############################################################
        # Experimental Features - allow users to opt in.

        # Signal Handlers requires running in main thread.
        if (threading.current_thread() != threading.main_thread()
                and _warn_or_fallback("Engine in background thread")):
            return False

        pp_capable_backends = (ParallelConfig.distributed_executor_backend,
                               "ray", "mp", "external_launcher")
        if (self.pipeline_parallel_size > 1
                and self.distributed_executor_backend
                not in pp_capable_backends):
            return _unsupported(
                "Pipeline Parallelism without Ray distributed executor "
                "or multiprocessing executor or external launcher")

        # The platform may be supported on V1, but off by default for now.
        if not current_platform.default_v1(  # noqa: SIM103
                model_config=model_config) and _warn_or_fallback(
                    current_platform.device_name):
            return False

        if (current_platform.is_cpu()
                and model_config.get_sliding_window() is not None):
            return _unsupported("sliding window (CPU backend)")

        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        Resolves `enable_chunked_prefill` when the user left it unset,
        validates the prefix-caching options, and defaults `max_num_seqs`
        to 256.

        Args:
            model_config: Configuration of the model about to be served.

        Raises:
            ValueError: If chunked prefill is enabled for a pooling model,
                or if prefix caching is requested with the "sha256" hash
                algorithm.
        """

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_config is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # Anything not decided above ends up disabled.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            # The trailing space after "cause" keeps the concatenated
            # message from reading "causeOOM".
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # if using prefix caching, we must set a hash algo
        if self.enable_prefix_caching:
            # Disable prefix caching for multimodal models for VLLM_V0.
            if model_config.is_multimodal_model:
                logger.warning(
                    "--enable-prefix-caching is not supported for multimodal "
                    "models in V0 and has been disabled.")
                self.enable_prefix_caching = False

            # VLLM_V0 only supports builtin hash algo for prefix caching.
            # NOTE: this check still runs when prefix caching was just
            # disabled above for a multimodal model.
            if self.prefix_caching_hash_algo == "sha256":
                raise ValueError(
                    "sha256 is not supported for prefix caching in V0 engine. "
                    "Please use 'builtin'.")

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

1626
1627
    def _set_default_args_v1(self, usage_context: UsageContext,
                             model_config: ModelConfig) -> None:
        """Fill in V1 engine defaults for options the user left unset.

        Covers chunked prefill, prefix caching, the scheduler class, and
        the `max_num_batched_tokens` / `max_num_seqs` limits. The limits
        are looked up by usage context from tables chosen per platform.

        Args:
            usage_context: Usage context used as the key into the default
                tables; a falsy or unknown value leaves the limits as-is.
            model_config: Configuration of the model about to be served.
        """

        if model_config.runner_type == "pooling":
            pooling_type = model_config.pooler_config.pooling_type

            # TODO: when encoder models are supported we'll have to
            # check for causal attention here.
            supports_incremental_prefill = (
                pooling_type is not None and pooling_type.lower() == "last")

            if supports_incremental_prefill:
                action = "Enabling"
            else:
                action = "Disabling"

            # Only fill in what the user did not set explicitly.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = supports_incremental_prefill
                logger.info("(%s) chunked prefill by default", action)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = supports_incremental_prefill
                logger.info("(%s) prefix caching by default", action)
        else:
            # Non-pooling tasks always use chunked prefill on V1; prefix
            # caching is switched on only when it was left unset.
            self.enable_chunked_prefill = True
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = True

        if not self.enable_chunked_prefill:
            self.max_num_batched_tokens = model_config.max_model_len

        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # When no user override, set the default values based on the usage
        # context. Use different default values for different hardware.

        # Querying the device can fail when the platform that imports vLLM
        # is not the one vLLM runs on (e.g. scaling vLLM with Ray) and has
        # no GPUs. A memory of 0 then selects the smaller defaults and
        # short-circuits the device-name test below.
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # This is only used to set default_max_num_batched_tokens
            device_memory = 0

        from vllm.usage.usage_lib import UsageContext
        llm_ctx = UsageContext.LLM_CLASS
        server_ctx = UsageContext.OPENAI_API_SERVER

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
            # For GPUs like H100 and MI300x, use larger default values.
            default_max_num_batched_tokens = {llm_ctx: 16384, server_ctx: 8192}
            default_max_num_seqs = {llm_ctx: 1024, server_ctx: 1024}
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {llm_ctx: 8192, server_ctx: 2048}
            default_max_num_seqs = {llm_ctx: 256, server_ctx: 256}

        # tpu specific default values.
        if current_platform.is_tpu():
            default_max_num_batched_tokens_tpu = {
                llm_ctx: {
                    'V6E': 2048,
                    'V5E': 1024,
                    'V5P': 512,
                },
                server_ctx: {
                    'V6E': 1024,
                    'V5E': 512,
                    'V5P': 256,
                },
            }

        # cpu specific default values.
        if current_platform.is_cpu():
            default_max_num_batched_tokens = {llm_ctx: 4096, server_ctx: 2048}
            default_max_num_seqs = {llm_ctx: 128, server_ctx: 32}

        use_context_value = usage_context.value if usage_context else None

        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
            token_budget = default_max_num_batched_tokens[usage_context]
            if current_platform.is_tpu():
                # A chip-specific value wins; unknown chips keep the
                # generic default.
                per_chip = default_max_num_batched_tokens_tpu[usage_context]
                token_budget = per_chip.get(
                    current_platform.get_device_name(), token_budget)
            self.max_num_batched_tokens = token_budget
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, use_context_value)

        if (self.max_num_seqs is None
                and usage_context in default_max_num_seqs):
            self.max_num_seqs = default_max_num_seqs[usage_context]
            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, use_context_value)
1757

1758

1759
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments extended with options for the asynchronous engine."""
    # CLI counterpart: --disable-log-requests.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI options on `parser`.

        Args:
            parser: Parser to extend.
            async_args_only: When True, skip the base `EngineArgs` options
                and add only the async-specific ones.

        Returns:
            The parser with the options registered.
        """
        # Plugins are loaded first because they may update the parser, for
        # example by adding a new quantization method to --quantization or
        # a new device to --device.
        load_general_plugins()
        target = (parser if async_args_only else
                  EngineArgs.add_cli_args(parser))
        target.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        current_platform.pre_register_and_update(target)
        return target
1778
1779


1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Report a feature the V1 engine cannot serve.

    Args:
        feature_name: Name of the unsupported feature, used in the message.
        recommend_to_remove: If True, the warning also advises dropping the
            feature from the config.

    Raises:
        NotImplementedError: If VLLM_USE_V1 is explicitly set and truthy,
            since falling back would contradict the user's request.
    """
    v1_forced = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_forced:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")

    pieces = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        pieces.append(
            f"We recommend to remove {feature_name} from your config ")
        pieces.append("in favor of the V1 Engine.")
    logger.warning("".join(pieces))


def _warn_or_fallback(feature_name: str) -> bool:
    """Gate an experimental V1 feature on an explicit VLLM_USE_V1 opt-in.

    Args:
        feature_name: Name of the experimental feature, used in the log.

    Returns:
        False if VLLM_USE_V1 is explicitly set and truthy (stay on V1 with
        a warning), True if the caller should fall back to V0.
    """
    if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
        logger.warning(
            "Detected VLLM_USE_V1=1 with %s. Usage should "
            "be considered experimental. Please report any "
            "issues on Github.", feature_name)
        return False

    logger.info(
        "%s is experimental on VLLM_USE_V1=1. "
        "Falling back to V0 Engine.", feature_name)
    return True


1807
1808
1809
def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.

    Lowercase suffixes (k, m, g, t) are decimal multipliers (powers of
    1000) and may follow a fractional number. Uppercase suffixes
    (K, M, G, T) are binary multipliers (powers of 1024) and require a
    whole number.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '1t' -> 1,000,000,000,000
    - '1T' -> 1,099,511,627,776

    Args:
        value: The string to parse; surrounding whitespace is ignored.

    Returns:
        The parsed integer.

    Raises:
        argparse.ArgumentTypeError: If a fractional number is combined
            with a binary suffix.
        ValueError: If the string is neither a plain integer nor a number
            followed by a recognised suffix.
    """
    value = value.strip()
    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
    if match:
        # Every suffix accepted by the pattern above must have an entry in
        # one of these two tables.
        decimal_multiplier = {
            'k': 10**3,
            'm': 10**6,
            'g': 10**9,
            't': 10**12,
        }
        binary_multiplier = {
            'K': 2**10,
            'M': 2**20,
            'G': 2**30,
            'T': 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            return int(float(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError("Decimals are not allowed " \
                f"with binary suffixes like {suffix}. Did you mean to use " \
                f"{number}{suffix.lower()} instead?") from e

    # Regular plain number.
    return int(value)


1848
1849
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with the `EngineArgs` CLI options."""
    parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(parser)
1851
1852
1853


def _async_engine_args_parser():
    """Return a fresh parser holding only the async-specific CLI options."""
    parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(parser, async_args_only=True)