arg_utils.py 69.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
# yapf: disable
4
import argparse
5
import dataclasses
6
import json
7
import re
8
import threading
9
from dataclasses import MISSING, dataclass, fields
10
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
11
                    TypeVar, Union, cast, get_args, get_origin)
12

13
import torch
14
from typing_extensions import TypeIs, deprecated
15

16
import vllm.envs as envs
17
from vllm import version
18
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
19
                         ConfigFormat, ConfigType, DecodingConfig, Device,
20
                         DeviceConfig, DistributedExecutorBackend,
21
22
                         GuidedDecodingBackend, GuidedDecodingBackendV1,
                         HfOverrides, KVTransferConfig, LoadConfig, LoadFormat,
23
24
25
26
27
28
29
                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
                         PoolerConfig, PrefixCachingHashAlgo,
                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
                         SpeculativeConfig, TaskOption, TokenizerMode,
                         TokenizerPoolConfig, VllmConfig, get_attr_docs,
                         get_field)
30
from vllm.executor.executor_base import ExecutorBase
31
from vllm.logger import init_logger
32
from vllm.model_executor.layers.quantization import QuantizationMethods
33
from vllm.plugins import load_general_plugins
34
from vllm.reasoning import ReasoningParserManager
35
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
36
from vllm.transformers_utils.utils import check_gguf_file
37
from vllm.usage.usage_lib import UsageContext
38
from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
39
40

# yapf: enable
41

42
43
logger = init_logger(__name__)

# Module names accepted for detailed tracing. NOTE(review): presumably the
# values validated for `collect_detailed_traces`; the code that consumes this
# list is outside the visible part of the file -- confirm.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# Type aliases used by the type-hint helpers below.
# object is used to allow for special typing forms
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]

51

52
53
def optional_type(
        return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Wrap an argparse converter so it also accepts "no value".

    Args:
        return_type: Callable that converts the raw command-line string.

    Returns:
        A converter that maps ``""`` and ``"None"`` to ``None`` and otherwise
        delegates to ``return_type``. When ``return_type`` is ``json.loads``
        and the value is not wrapped in braces, the value is parsed with the
        deprecated ``nullable_kvs`` ``KEY=VALUE`` parser instead.

    Raises:
        argparse.ArgumentTypeError: Raised by the returned converter when the
            conversion raises ``ValueError``.
    """

    def _optional_type(val: str) -> Optional[T]:
        if val == "" or val == "None":
            return None
        try:
            if return_type is json.loads:
                # Decide on the surrounding braces only, ignoring outer
                # whitespace, so that multi-line (pretty-printed) JSON
                # objects are sent to json.loads rather than to the legacy
                # key=value parser.
                stripped = val.strip()
                looks_like_json_object = (stripped.startswith("{")
                                          and stripped.endswith("}"))
                if not looks_like_json_object:
                    return cast(T, nullable_kvs(val))
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}.") from e

    return _optional_type
67
68


69
70
71
72
73
74
@deprecated(
    "Passing a JSON argument as a string containing comma separated key=value "
    "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
    "string instead.")
def nullable_kvs(val: str) -> dict[str, int]:
    """Parse comma-separated ``KEY=VALUE`` pairs into a dictionary.

    Keys and values are lower-cased and stripped of surrounding whitespace
    before use; every value must be an integer.

    Args:
        val: String value to be parsed, e.g. ``"image=2,video=1"``.

    Returns:
        Dictionary mapping each key to its integer value.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, a value is not an integer, or the same key is
            given two different values.
    """
    out_dict: dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key is tolerated only when the value is the same.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return True if `type_hint` is `type` itself or a generic alias of it.

    For example both ``list`` and ``list[int]`` match ``list``.
    """
    if type_hint is type:
        return True
    # Parameterized forms such as list[int] or Literal["a"] expose the
    # unparameterized form through their origin.
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return True if any hint in `type_hints` matches `type` (see is_type)."""
    for candidate in type_hints:
        if is_type(candidate, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return the first hint in `type_hints` that matches `type`.

    Returns None when no hint matches.
    """
    for candidate in type_hints:
        if is_type(candidate, type):
            return candidate
    return None


120
121
122
123
124
125
126
127
128
129
130
131
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Turn the Literal hint found in `type_hints` into argparse kwargs.

    The result has a ``type`` (the common type of the literal values) and
    ``choices`` (the literal values, sorted).

    Raises:
        ValueError: If the literal values are not all of the same type.
    """
    literal_hint = get_type(type_hints, Literal)
    options = get_args(literal_hint)
    # The first value decides the type every other value must share.
    expected_type = type(options[0])
    mismatch = any(not isinstance(option, expected_type)
                   for option in options)
    if mismatch:
        raise ValueError(
            "All choices must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}")
    return dict(type=expected_type, choices=sorted(options))


132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True if `type_hint` is defined outside the builtins module."""
    defining_module = type_hint.__module__
    return not (defining_module == "builtins")


def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Derive ``add_argument`` keyword arguments for each field of a config.

    For every dataclass field of ``cls`` a dictionary is built that can be
    splatted into ``parser.add_argument``. ``default`` comes from the field's
    default (or default factory) and ``help`` from the entry that
    ``get_attr_docs`` returns for the field. The remaining keys (``action``,
    ``type``, ``choices``, ``nargs``) are chosen from the field's annotation.

    Args:
        cls: Config dataclass whose fields are turned into CLI arguments.

    Returns:
        Mapping from field name to the keyword arguments for that field.

    Raises:
        ValueError: If a field's annotation matches none of the supported
            types, or a ``Literal`` mixes values of different types.
    """
    cls_docs = get_attr_docs(cls)
    kwargs: dict[str, Any] = {}
    for field in fields(cls):
        # Get the default value of the field
        default = field.default
        if field.default_factory is not MISSING:
            default = field.default_factory()

        # Get the help text for the field. argparse %-formats help strings,
        # so literal percent signs have to be escaped.
        name = field.name
        help_text = cls_docs[name].replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        field_kwargs: dict[str, Any] = {"default": default, "help": help_text}
        kwargs[name] = field_kwargs

        # Get the set of possible types for the field
        type_hints: set[TypeHint] = set()
        if get_origin(field.type) is Union:
            type_hints.update(get_args(field.type))
        else:
            type_hints.add(field.type)

        # Set other kwargs based on the type hints. The order of the checks
        # matters: the first matching kind wins.
        if contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            field_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            field_kwargs.update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            types = get_args(type_hint)
            tuple_type = types[0]
            assert all(t is tuple_type for t in types if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {types}.")
            field_kwargs["type"] = tuple_type
            # tuple[X, ...] accepts one or more values; a fixed-size tuple
            # requires exactly that many.
            field_kwargs["nargs"] = "+" if Ellipsis in types else len(types)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            types = get_args(type_hint)
            assert len(types) == 1, (
                "List type must have exactly one type. Got "
                f"{type_hint} with types {types}")
            field_kwargs["type"] = types[0]
            field_kwargs["nargs"] = "+"
        elif contains_type(type_hints, int):
            field_kwargs["type"] = int
            # Special case for large integers
            if name in {"max_model_len"}:
                field_kwargs["type"] = human_readable_int
        elif contains_type(type_hints, float):
            field_kwargs["type"] = float
        elif contains_type(type_hints, dict):
            # Dict arguments will always be optional
            field_kwargs["type"] = optional_type(json.loads)
        elif (contains_type(type_hints, str)
              or any(is_not_builtin(th) for th in type_hints)):
            field_kwargs["type"] = str
        else:
            raise ValueError(
                f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(field_kwargs.get("type")) is Literal:
            field_kwargs.update(literal_to_kwargs({field_kwargs["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            field_kwargs["type"] = optional_type(field_kwargs["type"])
            if field_kwargs.get("choices"):
                field_kwargs["choices"].append("None")
    return kwargs
214
215


216
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
217
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
218
    """Arguments for vLLM engine."""
219
220
221
222
223
224
225
226
227
228
    model: str = ModelConfig.model
    served_model_name: Optional[Union[
        str, List[str]]] = ModelConfig.served_model_name
    tokenizer: Optional[str] = ModelConfig.tokenizer
    hf_config_path: Optional[str] = ModelConfig.hf_config_path
    task: TaskOption = ModelConfig.task
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
229
230
    download_dir: Optional[str] = LoadConfig.download_dir
    load_format: str = LoadConfig.load_format
231
232
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
233
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
234
235
    seed: Optional[int] = ModelConfig.seed
    max_model_len: Optional[int] = ModelConfig.max_model_len
236
237
238
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
239
    distributed_executor_backend: Optional[Union[
240
241
        DistributedExecutorBackend,
        Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
242
    # number of P/D disaggregation (or other disaggregation) workers
243
244
245
246
247
248
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
    data_parallel_size: int = ParallelConfig.data_parallel_size
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
    max_parallel_loading_workers: Optional[
        int] = ParallelConfig.max_parallel_loading_workers
249
250
251
252
    block_size: Optional[BlockSize] = CacheConfig.block_size
    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
    prefix_caching_hash_algo: PrefixCachingHashAlgo = \
        CacheConfig.prefix_caching_hash_algo
253
254
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
255
    use_v2_block_manager: bool = True
256
257
258
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
259
260
261
262
263
264
265
    max_num_batched_tokens: Optional[
        int] = SchedulerConfig.max_num_batched_tokens
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
    long_prefill_token_threshold: int = \
        SchedulerConfig.long_prefill_token_threshold
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
266
    max_logprobs: int = ModelConfig.max_logprobs
267
    disable_log_stats: bool = False
268
269
270
271
272
273
274
275
276
277
278
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
    rope_theta: Optional[float] = ModelConfig.rope_theta
    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
    hf_overrides: Optional[HfOverrides] = \
        get_field(ModelConfig, "hf_overrides")
    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
    enforce_eager: bool = ModelConfig.enforce_eager
    max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
279
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
280
281
282
    # The following three fields are deprecated and will be removed in a future
    # release. Setting them will have no effect. Please remove them from your
    # configurations.
283
    tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
284
285
    tokenizer_pool_type: str = TokenizerPoolConfig.pool_type
    tokenizer_pool_extra_config: dict = \
286
        get_field(TokenizerPoolConfig, "extra_config")
287
    limit_mm_per_prompt: dict[str, int] = \
288
        get_field(MultiModalConfig, "limit_per_prompt")
289
290
291
292
    mm_processor_kwargs: Optional[Dict[str, Any]] = \
        MultiModalConfig.mm_processor_kwargs
    disable_mm_preprocessor_cache: bool = \
        MultiModalConfig.disable_mm_preprocessor_cache
293
    # LoRA fields
294
    enable_lora: bool = False
295
296
297
298
299
300
301
302
303
304
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
        LoRAConfig.long_lora_scaling_factors
    # PromptAdapter fields
305
    enable_prompt_adapter: bool = False
306
307
308
309
    max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
    max_prompt_adapter_token: int = \
        PromptAdapterConfig.max_prompt_adapter_token

310
    device: Device = DeviceConfig.device
311
312
    num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
    multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
313
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
314
315
    num_gpu_blocks_override: Optional[
        int] = CacheConfig.num_gpu_blocks_override
316
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
317
318
    model_loader_extra_config: dict = \
        get_field(LoadConfig, "model_loader_extra_config")
319
320
    ignore_patterns: Optional[Union[str,
                                    List[str]]] = LoadConfig.ignore_patterns
321
    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
322

323
324
325
326
    scheduler_delay_factor: float = SchedulerConfig.delay_factor
    enable_chunked_prefill: Optional[
        bool] = SchedulerConfig.enable_chunked_prefill
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
327

328
329
330
331
332
333
    guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
    guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
    guided_decoding_disable_any_whitespace: bool = \
        DecodingConfig.disable_any_whitespace
    guided_decoding_disable_additional_properties: bool = \
        DecodingConfig.disable_additional_properties
334
335
    logits_processor_pattern: Optional[
        str] = ModelConfig.logits_processor_pattern
336

337
    speculative_config: Optional[Dict[str, Any]] = None
338

339
    qlora_adapter_name_or_path: Optional[str] = None
340
    show_hidden_metrics_for_version: Optional[str] = None
341
    otlp_traces_endpoint: Optional[str] = None
342
    collect_detailed_traces: Optional[str] = None
343
    disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
344
345
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
346

347
348
349
350
    override_neuron_config: dict[str, Any] = \
        get_field(ModelConfig, "override_neuron_config")
    override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
        ModelConfig.override_pooler_config
351
    compilation_config: Optional[CompilationConfig] = None
352
353
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
354

355
356
    kv_transfer_config: Optional[KVTransferConfig] = None

357
358
359
360
361
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
    override_generation_config: dict[str, Any] = \
        get_field(ModelConfig, "override_generation_config")
    model_impl: str = ModelConfig.model_impl
362

363
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
364

365
    additional_config: Optional[Dict[str, Any]] = None
366
    enable_reasoning: Optional[bool] = None
367
    reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend
368
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
369

370
    def __post_init__(self):
        """Normalize ``compilation_config`` and load the general plugins.

        Runs automatically after the dataclass-generated ``__init__``.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object: an int or dict is stringified and handed
        # to `CompilationConfig.from_cli`.
        if isinstance(self.compilation_config, (int, dict)):
            self.compilation_config = CompilationConfig.from_cli(
                str(self.compilation_config))

        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
381
382

    @staticmethod
383
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
384
        """Shared CLI arguments for vLLM engine."""
385

386
        # Model arguments
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
        model_group.add_argument("--model", **model_kwargs["model"])
        model_group.add_argument("--task", **model_kwargs["task"])
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode",
                                 **model_kwargs["tokenizer_mode"])
        model_group.add_argument("--trust-remote-code",
                                 **model_kwargs["trust_remote_code"])
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
        model_group.add_argument("--hf-config-path",
                                 **model_kwargs["hf_config_path"])
        model_group.add_argument("--allowed-local-media-path",
                                 **model_kwargs["allowed_local_media_path"])
        model_group.add_argument("--revision", **model_kwargs["revision"])
        model_group.add_argument("--code-revision",
                                 **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling",
                                 **model_kwargs["rope_scaling"])
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
        model_group.add_argument("--tokenizer-revision",
                                 **model_kwargs["tokenizer_revision"])
        model_group.add_argument("--max-model-len",
                                 **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q",
                                 **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager",
                                 **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-seq-len-to-capture",
                                 **model_kwargs["max_seq_len_to_capture"])
        model_group.add_argument("--max-logprobs",
                                 **model_kwargs["max_logprobs"])
        model_group.add_argument("--disable-sliding-window",
                                 **model_kwargs["disable_sliding_window"])
        model_group.add_argument("--disable-cascade-attn",
                                 **model_kwargs["disable_cascade_attn"])
        model_group.add_argument("--skip-tokenizer-init",
                                 **model_kwargs["skip_tokenizer_init"])
        model_group.add_argument("--served-model-name",
                                 **model_kwargs["served_model_name"])
        # This one is a special case because it is the
        # opposite of ModelConfig.use_async_output_proc
        model_group.add_argument(
            "--disable-async-output-proc",
            action="store_true",
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
        model_group.add_argument("--config-format",
                                 choices=[f.value for f in ConfigFormat],
                                 **model_kwargs["config_format"])
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
        model_group.add_argument("--hf-token",
                                 type=str,
                                 nargs="?",
                                 const=True,
                                 default=model_kwargs["hf_token"]["default"],
                                 help=model_kwargs["hf_token"]["help"])
        model_group.add_argument("--hf-overrides",
                                 **model_kwargs["hf_overrides"])
        model_group.add_argument("--override-neuron-config",
                                 **model_kwargs["override_neuron_config"])
        model_group.add_argument("--override-pooler-config",
                                 **model_kwargs["override_pooler_config"])
        model_group.add_argument("--logits-processor-pattern",
                                 **model_kwargs["logits_processor_pattern"])
        model_group.add_argument("--generation-config",
                                 **model_kwargs["generation_config"])
        model_group.add_argument("--override-generation-config",
                                 **model_kwargs["override_generation_config"])
        model_group.add_argument("--enable-sleep-mode",
                                 **model_kwargs["enable_sleep_mode"])
        model_group.add_argument("--model-impl",
                                 choices=[f.value for f in ModelImpl],
                                 **model_kwargs["model_impl"])

468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
        load_group.add_argument('--load-format',
                                choices=[f.value for f in LoadFormat],
                                **load_kwargs["load_format"])
        load_group.add_argument('--download-dir',
                                **load_kwargs["download_dir"])
        load_group.add_argument('--model-loader-extra-config',
                                **load_kwargs["model_loader_extra_config"])
        load_group.add_argument('--use-tqdm-on-load',
                                **load_kwargs["use_tqdm_on_load"])

484
485
486
487
488
489
        # Guided decoding arguments
        guided_decoding_kwargs = get_kwargs(DecodingConfig)
        guided_decoding_group = parser.add_argument_group(
            title="DecodingConfig",
            description=DecodingConfig.__doc__,
        )
490
491
        guided_decoding_group.add_argument("--guided-decoding-backend",
                                           **guided_decoding_kwargs["backend"])
492
        guided_decoding_group.add_argument(
493
494
495
496
497
498
499
500
            "--guided-decoding-disable-fallback",
            **guided_decoding_kwargs["disable_fallback"])
        guided_decoding_group.add_argument(
            "--guided-decoding-disable-any-whitespace",
            **guided_decoding_kwargs["disable_any_whitespace"])
        guided_decoding_group.add_argument(
            "--guided-decoding-disable-additional-properties",
            **guided_decoding_kwargs["disable_additional_properties"])
501
502
503
504
505
506
        guided_decoding_group.add_argument(
            "--reasoning-parser",
            # This choices is a special case because it's not static
            choices=list(ReasoningParserManager.reasoning_parsers),
            **guided_decoding_kwargs["reasoning_backend"])

507
        # Parallel arguments
508
509
510
511
512
513
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
514
            '--distributed-executor-backend',
515
516
517
518
519
520
521
522
523
            **parallel_kwargs["distributed_executor_backend"])
        parallel_group.add_argument(
            '--pipeline-parallel-size', '-pp',
            **parallel_kwargs["pipeline_parallel_size"])
        parallel_group.add_argument('--tensor-parallel-size', '-tp',
                                    **parallel_kwargs["tensor_parallel_size"])
        parallel_group.add_argument('--data-parallel-size', '-dp',
                                    **parallel_kwargs["data_parallel_size"])
        parallel_group.add_argument(
524
            '--enable-expert-parallel',
525
526
            **parallel_kwargs["enable_expert_parallel"])
        parallel_group.add_argument(
527
            '--max-parallel-loading-workers',
528
529
            **parallel_kwargs["max_parallel_loading_workers"])
        parallel_group.add_argument(
530
            '--ray-workers-use-nsight',
531
532
533
534
            **parallel_kwargs["ray_workers_use_nsight"])
        parallel_group.add_argument(
            '--disable-custom-all-reduce',
            **parallel_kwargs["disable_custom_all_reduce"])
535

536
537
538
539
540
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
541
        )
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
        cache_group.add_argument('--block-size', **cache_kwargs["block_size"])
        cache_group.add_argument('--gpu-memory-utilization',
                                 **cache_kwargs["gpu_memory_utilization"])
        cache_group.add_argument('--swap-space', **cache_kwargs["swap_space"])
        cache_group.add_argument('--kv-cache-dtype',
                                 **cache_kwargs["cache_dtype"])
        cache_group.add_argument('--num-gpu-blocks-override',
                                 **cache_kwargs["num_gpu_blocks_override"])
        cache_group.add_argument("--enable-prefix-caching",
                                 **cache_kwargs["enable_prefix_caching"])
        cache_group.add_argument("--prefix-caching-hash-algo",
                                 **cache_kwargs["prefix_caching_hash_algo"])
        cache_group.add_argument('--cpu-offload-gb',
                                 **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument('--calculate-kv-scales',
                                 **cache_kwargs["calculate_kv_scales"])

559
560
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
561
                            default=True,
562
563
564
565
566
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
567

568
569
        parser.add_argument('--disable-log-stats',
                            action='store_true',
570
                            help='Disable logging statistics.')
571
572
573
574
575
576
577
578
579
580
581
582
583

        # Tokenizer arguments
        tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
        tokenizer_group = parser.add_argument_group(
            title="TokenizerPoolConfig",
            description=TokenizerPoolConfig.__doc__,
        )
        tokenizer_group.add_argument('--tokenizer-pool-size',
                                     **tokenizer_kwargs["pool_size"])
        tokenizer_group.add_argument('--tokenizer-pool-type',
                                     **tokenizer_kwargs["pool_type"])
        tokenizer_group.add_argument('--tokenizer-pool-extra-config',
                                     **tokenizer_kwargs["extra_config"])
584
585

        # Multimodal related configs
586
587
588
589
590
591
592
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
        multimodal_group.add_argument('--limit-mm-per-prompt',
                                      **multimodal_kwargs["limit_per_prompt"])
593
        multimodal_group.add_argument(
594
            '--mm-processor-kwargs',
595
596
            **multimodal_kwargs["mm_processor_kwargs"])
        multimodal_group.add_argument(
597
            '--disable-mm-preprocessor-cache',
598
            **multimodal_kwargs["disable_mm_preprocessor_cache"])
599

600
        # LoRA related configs
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
            '--enable-lora',
            action=argparse.BooleanOptionalAction,
            help='If True, enable handling of LoRA adapters.')
        lora_group.add_argument('--enable-lora-bias',
                                **lora_kwargs["bias_enabled"])
        lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
        lora_group.add_argument('--max-lora-rank',
                                **lora_kwargs["max_lora_rank"])
        lora_group.add_argument('--lora-extra-vocab-size',
                                **lora_kwargs["lora_extra_vocab_size"])
        lora_group.add_argument(
618
            '--lora-dtype',
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
            **lora_kwargs["lora_dtype"],
        )
        lora_group.add_argument('--long-lora-scaling-factors',
                                **lora_kwargs["long_lora_scaling_factors"])
        lora_group.add_argument('--max-cpu-loras',
                                **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument('--fully-sharded-loras',
                                **lora_kwargs["fully_sharded_loras"])

        # PromptAdapter related configs
        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
        prompt_adapter_group = parser.add_argument_group(
            title="PromptAdapterConfig",
            description=PromptAdapterConfig.__doc__,
        )
        prompt_adapter_group.add_argument(
            '--enable-prompt-adapter',
            action=argparse.BooleanOptionalAction,
            help='If True, enable handling of PromptAdapters.')
        prompt_adapter_group.add_argument(
            '--max-prompt-adapters',
            **prompt_adapter_kwargs["max_prompt_adapters"])
        prompt_adapter_group.add_argument(
            '--max-prompt-adapter-token',
            **prompt_adapter_kwargs["max_prompt_adapter_token"])
644
645
646
647
648
649
650
651
652

        # Device arguments
        device_kwargs = get_kwargs(DeviceConfig)
        device_group = parser.add_argument_group(
            title="DeviceConfig",
            description=DeviceConfig.__doc__,
        )
        device_group.add_argument("--device", **device_kwargs["device"])

653
654
655
656
657
658
659
660
661
662
663
664
        # Speculative arguments
        speculative_group = parser.add_argument_group(
            title="SpeculativeConfig",
            description=SpeculativeConfig.__doc__,
        )
        speculative_group.add_argument(
            '--speculative-config',
            type=json.loads,
            default=None,
            help='The configurations for speculative decoding.'
            ' Should be a JSON string.')

665
666
667
668
669
670
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
671
            "Default to `original/**/*` to avoid repeated loading of llama's "
672
            "checkpoints.")
673

674
675
676
677
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
678

679
680
681
682
683
684
685
686
687
688
689
690
        parser.add_argument('--show-hidden-metrics-for-version',
                            type=str,
                            default=None,
                            help='Enable deprecated Prometheus metrics that '
                            'have been hidden since the specified version. '
                            'For example, if a previously deprecated metric '
                            'has been hidden since the v0.7.0 release, you '
                            'use --show-hidden-metrics-for-version=0.7 as a '
                            'temporary escape hatch while you migrate to new '
                            'metrics. The metric is likely to be removed '
                            'completely in an upcoming release.')

691
692
693
694
695
        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
696
697
698
699
700
701
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
702
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
703
704
705
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
706

707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
            '--max-num-batched-tokens',
            **scheduler_kwargs["max_num_batched_tokens"])
        scheduler_group.add_argument('--max-num-seqs',
                                     **scheduler_kwargs["max_num_seqs"])
        scheduler_group.add_argument(
            "--max-num-partial-prefills",
            **scheduler_kwargs["max_num_partial_prefills"])
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
            **scheduler_kwargs["max_long_partial_prefills"])
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
            **scheduler_kwargs["long_prefill_token_threshold"])
        scheduler_group.add_argument('--num-lookahead-slots',
                                     **scheduler_kwargs["num_lookahead_slots"])
        scheduler_group.add_argument('--scheduler-delay-factor',
                                     **scheduler_kwargs["delay_factor"])
731
732
733
734
        scheduler_group.add_argument('--preemption-mode',
                                     **scheduler_kwargs["preemption_mode"])
        scheduler_group.add_argument('--num-scheduler-steps',
                                     **scheduler_kwargs["num_scheduler_steps"])
735
736
737
738
739
        scheduler_group.add_argument(
            '--multi-step-stream-outputs',
            **scheduler_kwargs["multi_step_stream_outputs"])
        scheduler_group.add_argument('--scheduling-policy',
                                     **scheduler_kwargs["policy"])
740
741
742
        scheduler_group.add_argument(
            '--enable-chunked-prefill',
            **scheduler_kwargs["enable_chunked_prefill"])
743
744
745
746
747
        scheduler_group.add_argument(
            "--disable-chunked-mm-input",
            **scheduler_kwargs["disable_chunked_mm_input"])
        parser.add_argument('--scheduler-cls',
                            **scheduler_kwargs["scheduler_cls"])
748

749
750
751
752
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
753
                            help='torch.compile configuration for the model. '
754
755
756
757
758
759
760
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
761
762
                            'use a JSON string, e.g. ``{"level": 3, '
                            '"cudagraph_capture_sizes": [1, 2, 4, 8]}``\n'
763
                            'Following the convention of traditional '
764
765
                            'compilers, using ``-O`` without space is also '
                            'supported. ``-O3`` is equivalent to ``-O 3``.')
766

767
768
769
770
771
772
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

773
774
775
776
777
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')
778
779
780
781
782
783
784
        parser.add_argument(
            '--worker-extension-cls',
            type=str,
            default="",
            help='The worker extension class on top of the worker cls, '
            'it is useful if you just want to add new functions to the worker '
            'class without changing the existing functions.')
785

786
787
788
789
790
791
792
793
        parser.add_argument(
            "--additional-config",
            type=json.loads,
            default=None,
            help="Additional config for specified platform in JSON format. "
            "Different platforms may support different configs. Make sure the "
            "configs are valid for the platform you are using. The input format"
            " is like '{\"config_key\":\"config_value\"}'")
794
795
796
797
798
799
800
801
802

        parser.add_argument(
            "--enable-reasoning",
            action="store_true",
            default=False,
            help="Whether to enable reasoning_content for the model. "
            "If enabled, the model will be able to generate reasoning content."
        )

803
        return parser
804
805

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of this dataclass from parsed CLI arguments.

        Every constructor-settable dataclass field of ``cls`` is read from
        the attribute of the same name on ``args``. Attributes on ``args``
        that do not correspond to a field are ignored.

        Args:
            args: Parsed argument namespace. It must carry one attribute per
                constructor-settable field of ``cls``.

        Returns:
            A new ``cls`` instance populated from ``args``.

        Raises:
            AttributeError: If ``args`` lacks an attribute for one of the
                constructor-settable fields.
        """
        # Fields declared with ``init=False`` are not parameters of the
        # generated ``__init__``; forwarding them would raise a TypeError.
        init_kwargs = {
            field.name: getattr(args, field.name)
            for field in dataclasses.fields(cls) if field.init
        }
        return cls(**init_kwargs)
812

813
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` from these engine arguments.

        Side effects on ``self``:
            * If ``check_gguf_file(self.model)`` is true, both
              ``self.quantization`` and ``self.load_format`` are overwritten
              with ``"gguf"``.
            * When the CI S3 switch is active (see the condition below),
              ``self.model`` is rewritten to its S3 bucket location and
              ``self.load_format`` is set to ``LoadFormat.RUNAI_STREAMER``.

        Returns:
            The ``ModelConfig`` constructed from the (possibly adjusted)
            arguments.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
                and self.model in MODELS_ON_S3
                and self.load_format == LoadFormat.AUTO):  # noqa: E501
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
            self.load_format = LoadFormat.RUNAI_STREAMER

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            task=self.task,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            # The CLI flag is a "disable" switch; ModelConfig takes the
            # positive form.
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
        )
864

865
866
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` from these engine arguments.

        Side effect: when ``self.quantization`` is ``"bitsandbytes"``,
        ``self.load_format`` is overwritten with ``"bitsandbytes"`` before
        the config is built.

        Returns:
            The ``LoadConfig`` constructed from the load-related arguments.

        Raises:
            ValueError: If a QLoRA adapter is specified while
                ``self.quantization`` is not ``"bitsandbytes"``.
        """
        if (self.qlora_adapter_name_or_path is not None
                and self.quantization != "bitsandbytes"):
            raise ValueError(
                "QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
        )
882

883
884
885
886
887
888
889
890
891
892
893
894
895
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        This function utilizes `speculative_config` to create a
        SpeculativeConfig object. The `speculative_config` can either be
        provided as a JSON string input via CLI arguments or directly as a
        dictionary from the engine.

        Args:
            target_model_config: Model config of the target model.
            target_parallel_config: Parallel config of the target model.
            enable_chunked_prefill: Whether chunked prefill is enabled.
            disable_log_stats: Whether stats logging is disabled.

        Returns:
            The SpeculativeConfig built from `speculative_config` plus the
            arguments above, or None when `speculative_config` is None.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # Merge into a fresh dict rather than updating in place, so the
        # user-supplied `speculative_config` is not polluted with engine
        # config objects.
        speculative_kwargs = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        }
        return SpeculativeConfig.from_dict(speculative_kwargs)

915
916
917
918
919
920
921
922
923
924
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.

        This function sets VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.

        Args:
            usage_context: Forwarded to the V1 default-argument setup when
                the V1 engine is selected.

        Returns:
            The fully assembled VllmConfig.

        Raises:
            ValueError: If multi-step scheduling is combined with
                speculative decoding, or with chunked prefill under
                pipeline parallelism, or if `collect_detailed_traces`
                names a module not in ALLOWED_DETAILED_TRACE_MODULES.
        """
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
            self._set_default_args_v1(usage_context)
        else:
            self._set_default_args_v0(model_config)

        # The default-argument setup above is expected to have resolved this.
        assert self.enable_chunked_prefill is not None

        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            is_attention_free=model_config.is_attention_free,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
            cpu_offload_gb=self.cpu_offload_gb,
            calculate_kv_scales=self.calculate_kv_scales,
        )

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            data_parallel_size=self.data_parallel_size,
            enable_expert_parallel=self.enable_expert_parallel,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            placement_group=placement_group,
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
            worker_extension_cls=self.worker_extension_cls,
        )

        speculative_config = self.create_speculative_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_log_stats=self.disable_log_stats,
        )

        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # if the feature combo becomes valid.
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1

        # Make sure num_lookahead_slots is set to the higher value depending
        # on whether we are using speculative decoding or multi-step.
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

        scheduler_config = SchedulerConfig(
            runner_type=model_config.runner_type,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_chunked_mm_input=self.disable_chunked_mm_input,
            is_multimodal_model=model_config.is_multimodal_model,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
            multi_step_stream_outputs=self.multi_step_stream_outputs,
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
            policy=self.scheduling_policy,
            scheduler_cls=self.scheduler_cls,
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
        )

        # A LoRA config is only built when LoRA is enabled; a missing or
        # non-positive max_cpu_loras is passed on as None.
        lora_config = None
        if self.enable_lora:
            max_cpu_loras = (self.max_cpu_loras if self.max_cpu_loras
                             and self.max_cpu_loras > 0 else None)
            lora_config = LoRAConfig(
                bias_enabled=self.enable_lora_bias,
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                long_lora_scaling_factors=self.long_lora_scaling_factors,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=max_cpu_loras,
            )

        if (self.qlora_adapter_name_or_path is not None
                and self.qlora_adapter_name_or_path != ""):
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

        load_config = self.create_load_config()

        prompt_adapter_config = None
        if self.enable_prompt_adapter:
            prompt_adapter_config = PromptAdapterConfig(
                max_prompt_adapters=self.max_prompt_adapters,
                max_prompt_adapter_token=self.max_prompt_adapter_token,
            )

        decoding_config = DecodingConfig(
            backend=self.guided_decoding_backend,
            disable_fallback=self.guided_decoding_disable_fallback,
            disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
            disable_additional_properties=\
                self.guided_decoding_disable_additional_properties,
            reasoning_backend=self.reasoning_parser
            if self.enable_reasoning else None,
        )

        show_hidden_metrics = False
        if self.show_hidden_metrics_for_version is not None:
            show_hidden_metrics = version._prev_minor_version_was(
                self.show_hidden_metrics_for_version)

        # `collect_detailed_traces` is a comma-separated list of module names.
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
        observability_config = ObservabilityConfig(
            show_hidden_metrics=show_hidden_metrics,
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )

        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
            compilation_config=self.compilation_config,
            kv_transfer_config=self.kv_transfer_config,
            additional_config=self.additional_config,
        )

        return config

1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default."""

        #############################################################
        # Unsupported Feature Flags on V1.

        if (self.load_format == LoadFormat.TENSORIZER.value
                or self.load_format == LoadFormat.SHARDED_STATE.value):
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

1152
        if self.preemption_mode != SchedulerConfig.preemption_mode:
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

1163
        if self.scheduling_policy != SchedulerConfig.policy:
1164
1165
1166
1167
            _raise_or_fallback(feature_name="--scheduling-policy",
                               recommend_to_remove=False)
            return False

1168
        if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
1169
1170
1171
1172
            _raise_or_fallback(feature_name="--num-scheduler-steps",
                               recommend_to_remove=True)
            return False

1173
        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
1174
1175
1176
1177
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

1178
1179
        if self.guided_decoding_backend not in get_args(
                GuidedDecodingBackendV1):
1180
1181
1182
1183
            _raise_or_fallback(
                feature_name=
                f"--guided-decoding-backend={self.guided_decoding_backend}",
                recommend_to_remove=False)
1184
1185
1186
            return False

        # Need at least Ampere for now (FA support required).
1187
1188
1189
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
1190
1191
        from vllm.platforms import current_platform
        if (current_platform.is_cuda()
1192
                and current_platform.get_device_capability()
1193
1194
1195
1196
1197
1198
1199
                and current_platform.get_device_capability().major < 8):
            _raise_or_fallback(feature_name="Compute Capability < 8.0",
                               recommend_to_remove=False)
            return False

        # No Fp8 KV cache so far.
        if self.kv_cache_dtype != "auto":
1200
1201
1202
1203
1204
1205
1206
            fp8_attention = self.kv_cache_dtype.startswith("fp8")
            will_use_fa = (
                current_platform.is_cuda()
                and not envs.is_set("VLLM_ATTENTION_BACKEND")
            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
            supported = False
            if fp8_attention and will_use_fa:
1207
                from vllm.attention.utils.fa_utils import (
1208
1209
1210
1211
1212
1213
                    flash_attn_supports_fp8)
                supported = flash_attn_supports_fp8()
            if not supported:
                _raise_or_fallback(feature_name="--kv-cache-dtype",
                                   recommend_to_remove=False)
                return False
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

        # Only Fp16 and Bf16 dtypes since we only support FA.
        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
        if model_config.dtype not in V1_SUPPORTED_DTYPES:
            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
                               recommend_to_remove=False)
            return False

        # Some quantization is not compatible with torch.compile.
1229
        V1_UNSUPPORTED_QUANT = ["gguf"]
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
        if model_config.quantization in V1_UNSUPPORTED_QUANT:
            _raise_or_fallback(
                feature_name=f"--quantization {model_config.quantization}",
                recommend_to_remove=False)
            return False

        # No Embedding Models so far.
        if model_config.task not in ["generate"]:
            _raise_or_fallback(feature_name=f"--task {model_config.task}",
                               recommend_to_remove=False)
            return False

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            _raise_or_fallback(feature_name=model_config.architectures,
                               recommend_to_remove=False)
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
1250
                != SchedulerConfig.max_num_partial_prefills
1251
                or self.max_long_partial_prefills
1252
                != SchedulerConfig.max_long_partial_prefills):
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # No OTLP observability so far.
        if (self.otlp_traces_endpoint or self.collect_detailed_traces):
            _raise_or_fallback(feature_name="--otlp-traces-endpoint",
                               recommend_to_remove=False)
            return False

        # Only Ngram speculative decoding so far.
1264
        is_ngram_enabled = False
1265
        is_eagle_enabled = False
1266
        if self.speculative_config is not None:
1267
            # This is supported but experimental (handled below).
1268
1269
1270
1271
            speculative_method = self.speculative_config.get("method")
            if speculative_method:
                if speculative_method in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
1272
                elif speculative_method in ("eagle", "eagle3"):
1273
                    is_eagle_enabled = True
1274
            else:
1275
1276
1277
1278
1279
                speculative_model = self.speculative_config.get("model")
                if speculative_model in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
            if not (is_ngram_enabled or is_eagle_enabled):
                # Other speculative decoding methods are not supported yet.
1280
1281
1282
1283
                _raise_or_fallback(feature_name="Speculative Decoding",
                                   recommend_to_remove=False)
                return False

1284
        # No XFormers so far.
1285
        V1_BACKENDS = [
1286
1287
1288
1289
1290
1291
1292
1293
1294
            "FLASH_ATTN_VLLM_V1",
            "FLASH_ATTN",
            "PALLAS",
            "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1",
            "TRITON_MLA",
            "FLASHMLA",
            "FLASHINFER",
            "FLASHINFER_VLLM_V1",
1295
1296
1297
1298
1299
1300
1301
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

1302
1303
        # Platforms must decide if they can support v1 for this model
        if not current_platform.supports_v1(model_config=model_config):
1304
1305
1306
1307
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
            return False
1308
1309
1310
        #############################################################
        # Experimental Features - allow users to opt in.

1311
1312
1313
1314
1315
        # Signal Handlers requires running in main thread.
        if (threading.current_thread() != threading.main_thread()
                and _warn_or_fallback("Engine in background thread")):
            return False

1316
1317
1318
        # PP is supported on V1 with Ray distributed executor,
        # but off for MP distributed executor for now.
        if (self.pipeline_parallel_size > 1
1319
1320
1321
                and self.distributed_executor_backend != "ray"):
            name = "Pipeline Parallelism without Ray distributed executor"
            _raise_or_fallback(feature_name=name, recommend_to_remove=False)
1322
1323
1324
            return False

        # ngram is supported on V1, but off by default for now.
1325
        if is_ngram_enabled and _warn_or_fallback("ngram"):
1326
1327
            return False

1328
1329
1330
1331
        # Eagle is under development, so we don't support it yet.
        if is_eagle_enabled and _warn_or_fallback("Eagle"):
            return False

1332
1333
1334
        # Non-CUDA is supported on V1, but off by default for now.
        not_cuda = not current_platform.is_cuda()
        if not_cuda and _warn_or_fallback(  # noqa: SIM103
1335
                current_platform.device_name):
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
            return False
        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        Resolves ``enable_chunked_prefill`` when it was left as ``None``,
        validates the prefix-caching settings, and fills in
        ``max_num_seqs`` when it was left as ``None``.

        Args:
            model_config: Model configuration consulted for the context
                length, multimodal / MLA flags, sliding window and runner
                type.

        Raises:
            ValueError: If chunked prefill ends up enabled for a pooling
                model, or if prefix caching was requested together with the
                ``sha256`` hash algorithm.
        """

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                from vllm.platforms import current_platform
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_config is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # Neither branch above made a decision: default to disabled.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # if using prefix caching, we must set a hash algo
        if self.enable_prefix_caching:
            # Disable prefix caching for multimodal models for VLLM_V0.
            if model_config.is_multimodal_model:
                logger.warning(
                    "--enable-prefix-caching is not supported for multimodal "
                    "models in V0 and has been disabled.")
                self.enable_prefix_caching = False

            # VLLM_V0 only supports builtin hash algo for prefix caching.
            # NOTE: this check still runs when prefix caching was just
            # disabled above for a multimodal model.
            if self.prefix_caching_hash_algo == "sha256":
                raise ValueError(
                    "sha256 is not supported for prefix caching in V0 engine. "
                    "Please use 'builtin'.")

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

    def _set_default_args_v1(self, usage_context: UsageContext) -> None:
        """Set Default Arguments for V1 Engine.

        Forces chunked prefill on, enables prefix caching unless it was set
        explicitly, swaps in the V1 scheduler when the scheduler class is
        still the ``EngineArgs`` default, and fills in
        ``max_num_batched_tokens`` / ``max_num_seqs`` when they are ``None``.

        Args:
            usage_context: Usage context used to look up the default
                ``max_num_batched_tokens``; contexts without an entry in the
                table leave that field untouched.
        """

        # V1 always uses chunked prefills.
        self.enable_chunked_prefill = True

        # V1 enables prefix caching by default.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = True

        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # When no user override, set the default values based on the usage
        # context.
        # Use different default values for different hardware.

        # Try to query the device memory on the current platform. If it
        # fails, it may be because the platform that imports vLLM is not the
        # same as the platform that vLLM is running on (e.g. the case of
        # scaling vLLM with Ray) and has no GPUs. In this case we use the
        # default values for the smaller-memory devices.
        try:
            from vllm.platforms import current_platform
            device_memory = current_platform.get_device_total_memory()
        except Exception:
            # Best effort: the value only selects between the two default
            # tables below (max_num_batched_tokens and max_num_seqs).
            device_memory = 0

        if device_memory >= 70 * GiB_bytes:
            # For GPUs like H100 and MI300x, use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
            default_max_num_seqs = 1024
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            default_max_num_seqs = 256

        # Only used for the debug messages below.
        use_context_value = usage_context.value if usage_context else None
        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
            self.max_num_batched_tokens = default_max_num_batched_tokens[
                usage_context]
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, use_context_value)

        if self.max_num_seqs is None:
            self.max_num_seqs = default_max_num_seqs

            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, use_context_value)
1465

1466

1467
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # Backs the --disable-log-requests CLI flag registered below.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI arguments on ``parser``.

        Args:
            parser: Parser to extend.
            async_args_only: When True, skip the ``EngineArgs`` arguments
                and register only the async-specific ones.

        Returns:
            The parser with the arguments registered.
        """
        # Initialize plugins to update the parser; for example, a plugin may
        # add a new kind of quantization method to the --quantization
        # argument or a new device to the --device argument.
        load_general_plugins()
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        # Imported locally, matching the other platform lookups in this file.
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update(parser)
        return parser
1487
1488


1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Reject a feature under explicit V1, otherwise warn about V0 fallback.

    Args:
        feature_name: Text identifying the unsupported feature; it is
            interpolated into the error / warning message.
        recommend_to_remove: Whether the warning should also suggest
            removing the feature from the configuration.

    Raises:
        NotImplementedError: If ``VLLM_USE_V1`` is set and truthy.
    """
    v1_requested = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_requested:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")

    # Assemble the warning from its fragments, then emit it once.
    fragments = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        fragments.append(
            f"We recommend to remove {feature_name} from your config ")
        fragments.append("in favor of the V1 Engine.")
    logger.warning("".join(fragments))


def _warn_or_fallback(feature_name: str) -> bool:
    """Report an experimental feature and say whether to fall back.

    Args:
        feature_name: Name of the experimental feature, used in the log
            message.

    Returns:
        False when ``VLLM_USE_V1`` is set and truthy (a warning is logged);
        True otherwise (an info message about falling back is logged).
    """
    if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
        # The user opted in explicitly: warn but do not fall back.
        logger.warning(
            "Detected VLLM_USE_V1=1 with %s. Usage should "
            "be considered experimental. Please report any "
            "issues on Github.", feature_name)
        return False

    logger.info(
        "%s is experimental on VLLM_USE_V1=1. "
        "Falling back to V0 Engine.", feature_name)
    return True


1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.

    Lowercase suffixes (k, m, g, t) are decimal multipliers and may be
    combined with a decimal number. Uppercase suffixes (K, M, G, T) are
    binary multipliers and require a whole number.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '1t' -> 1,000,000,000,000

    Args:
        value: The string to parse; surrounding whitespace is ignored.

    Returns:
        The parsed integer. Fractional results of a decimal multiplier are
        truncated towards zero.

    Raises:
        argparse.ArgumentTypeError: If a decimal number is combined with a
            binary suffix.
        ValueError: If the value is neither a suffixed number nor a plain
            integer.
    """
    # Local import: keeps this helper self-contained.
    from decimal import Decimal

    value = value.strip()
    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
    if match is None:
        # Regular plain number.
        return int(value)

    # Every suffix accepted by the pattern above has an entry in exactly
    # one of these two tables.
    decimal_multiplier = {
        'k': 10**3,
        'm': 10**6,
        'g': 10**9,
        't': 10**12,
    }
    binary_multiplier = {
        'K': 2**10,
        'M': 2**20,
        'G': 2**30,
        'T': 2**40,
    }

    number, suffix = match.groups()
    if suffix in decimal_multiplier:
        # Decimal keeps the product exact; float multiplication can land
        # just below the intended integer and then truncate one too low.
        return int(Decimal(number) * decimal_multiplier[suffix])

    mult = binary_multiplier[suffix]
    # Do not allow decimals with binary multipliers
    try:
        return int(number) * mult
    except ValueError as e:
        raise argparse.ArgumentTypeError("Decimals are not allowed " \
        f"with binary suffixes like {suffix}. Did you mean to use " \
        f"{number}{suffix.lower()} instead?") from e


1557
1558
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with the ``EngineArgs`` CLI options."""
    return EngineArgs.add_cli_args(FlexibleArgumentParser())
1560
1561
1562


def _async_engine_args_parser():
    """Return a fresh parser holding only the async-specific CLI options."""
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                        async_args_only=True)