arg_utils.py 78.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# yapf: disable
5
import argparse
6
import copy
7
import dataclasses
8
import functools
9
import json
10
import sys
11
import threading
12
from dataclasses import MISSING, dataclass, fields, is_dataclass
13
from itertools import permutations
14
15
16
from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List,
                    Literal, Optional, Type, TypeVar, Union, cast, get_args,
                    get_origin)
17

18
import regex as re
19
import torch
20
from pydantic import TypeAdapter, ValidationError
21
from typing_extensions import TypeIs
22

23
import vllm.envs as envs
24
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
25
26
27
28
                         ConfigFormat, ConfigType, DecodingConfig,
                         DetailedTraceModules, Device, DeviceConfig,
                         DistributedExecutorBackend, GuidedDecodingBackend,
                         GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
29
30
31
32
33
34
35
                         KVTransferConfig, LoadConfig, LoadFormat,
                         LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
                         ModelImpl, MultiModalConfig, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
                         SpeculativeConfig, TaskOption, TokenizerMode,
                         VllmConfig, get_attr_docs, get_field)
36
from vllm.logger import init_logger
37
from vllm.platforms import CpuArchEnum, current_platform
38
from vllm.plugins import load_general_plugins
39
from vllm.reasoning import ReasoningParserManager
40
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
41
from vllm.transformers_utils.utils import check_gguf_file
42
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
Rui Qiao's avatar
Rui Qiao committed
43
                        GiB_bytes, get_ip, is_in_ray_actor)
44
45

# yapf: enable
46

47
48
49
50
51
52
53
54
55
if TYPE_CHECKING:
    from vllm.executor.executor_base import ExecutorBase
    from vllm.model_executor.layers.quantization import QuantizationMethods
    from vllm.usage.usage_lib import UsageContext
else:
    ExecutorBase = Any
    QuantizationMethods = Any
    UsageContext = Any

56
57
# Type variable and aliases shared by the argparse helpers below. `object` is
# part of each alias so that special typing forms (e.g. `Literal[...]`,
# `Annotated[...]`), which are not `type` instances, are accepted as hints.
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]

logger = init_logger(__name__)

63

64
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap ``return_type`` so conversion failures become argparse errors.

    Args:
        return_type: Callable that converts the raw command-line string.

    Returns:
        A converter suitable for ``add_argument(type=...)``. Any
        ``ValueError`` raised by ``return_type`` is re-raised as
        ``argparse.ArgumentTypeError`` (chained to the original error).
    """

    def _parse_type(val: str) -> T:
        try:
            converted = return_type(val)
        except ValueError as exc:
            message = f"Value {val} cannot be converted to {return_type}."
            raise argparse.ArgumentTypeError(message) from exc
        return converted

    return _parse_type


def optional_type(
        return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Build a converter that maps ``""`` and ``"None"`` to ``None``.

    Every other string is converted with ``parse_type(return_type)``, so a
    ``ValueError`` from ``return_type`` surfaces as
    ``argparse.ArgumentTypeError``.
    """
    convert = parse_type(return_type)

    def _optional_type(val: str) -> Optional[T]:
        if val in ("", "None"):
            return None
        return convert(val)

    return _optional_type
85
86


87
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
    """Parse ``val`` as JSON when it is brace-delimited, else keep it a str.

    Whitespace around the braces is tolerated and the body between them may
    span several lines (the pattern uses DOTALL).
    """
    is_json_object = re.match(r"(?s)^\s*{.*}\s*$", val)
    if is_json_object:
        return optional_type(json.loads)(val)
    return str(val)
91
92


93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return whether ``type_hint`` is ``type`` itself or a generic of it.

    For example, both ``list`` and ``list[int]`` match ``list``.
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return whether any hint in ``type_hints`` matches ``type``.

    Matching follows ``is_type``, so ``list[int]`` matches ``list``.
    """
    for type_hint in type_hints:
        if is_type(type_hint, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return a hint from ``type_hints`` that matches ``type``, or None.

    Matching follows ``is_type``; if several hints match, the one reached
    first while iterating the set is returned.
    """
    for type_hint in type_hints:
        if is_type(type_hint, type):
            return type_hint
    return None


108
109
110
111
112
113
114
115
116
117
118
119
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Convert Literal type hints to argparse kwargs.

    Returns a dict with ``"type"`` set to the Python type of the literal
    values and ``"choices"`` set to those values as a sorted list.

    Raises:
        ValueError: If the literal values are not all of the same type.
    """
    choices = get_args(get_type(type_hints, Literal))
    choice_type = type(choices[0])
    if any(not isinstance(choice, choice_type) for choice in choices):
        raise ValueError(
            "All choices must be of the same type. "
            f"Got {choices} with types {[type(c) for c in choices]}")
    return {"type": choice_type, "choices": sorted(choices)}


120
121
122
123
124
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return whether ``type_hint`` is defined outside ``builtins``.

    Used to treat project-defined classes like strings on the command line.
    """
    defining_module = type_hint.__module__
    return defining_module != "builtins"


125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints.

    ``Annotated[X, ...]`` is unwrapped to ``X`` and unions are flattened into
    their members, recursively. Both ``typing.Union``/``Optional`` and PEP 604
    unions (``X | Y``, whose origin is ``types.UnionType``) are flattened; any
    other hint is returned as a one-element set.

    Args:
        type_hint: The (possibly nested) type hint to decompose.

    Returns:
        A new set containing the leaf type hints.
    """
    import types

    # ``X | Y`` has origin ``types.UnionType`` rather than ``typing.Union``.
    # Without this, ``int | None`` would be kept whole and miss both the
    # per-type dispatch and the Optional handling in ``_compute_kwargs``.
    # Fall back to ``Union`` on interpreters that lack ``types.UnionType``.
    union_origins = (Union, getattr(types, "UnionType", Union))

    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
    elif any(origin is union_origin for union_origin in union_origins):
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


142
143
144
145
def is_online_quantization(quantization: Any) -> bool:
    """Return whether ``quantization`` names an online quantization method.

    Only ``"inc"`` is currently treated as online.
    """
    online_methods = ("inc", )
    return quantization in online_methods


146
147
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Build ``add_argument`` kwargs for every field of the config ``cls``.

    Each entry of the returned dict is keyed by field name and always holds
    ``default`` and ``help``. Depending on the field's type hints, it may
    also hold ``type``, ``action``, ``nargs`` and/or ``choices``.

    The result is memoised with ``functools.lru_cache`` and shared between
    calls, so it must not be mutated. Use ``get_kwargs``, which returns a
    deep copy.

    Raises:
        ValueError: If a field's type hints are not supported, or (via
            ``literal_to_kwargs``) a Literal mixes value types.
    """
    # Tip appended to the help of JSON-valued fields. It does not depend on
    # the field, so it is built once rather than on every loop iteration.
    json_tip = """Should either be a valid JSON string or JSON keys
passed individually. For example, the following sets of arguments are
equivalent:

- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`

Additionally, list elements can be passed individually using `+`:

- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`"""

    cls_docs = get_attr_docs(cls)
    kwargs: dict[str, Any] = {}
    for field in fields(cls):
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, it is parsed from JSON (or via its
        # own ``from_cli``) below
        dataclass_cls = next((th for th in type_hints if is_dataclass(th)),
                             None)

        # Get the default value of the field
        if field.default is not MISSING:
            default = field.default
        elif field.default_factory is not MISSING:
            default = field.default_factory()
        else:
            # Field without a default. Without this branch ``default`` would
            # be unbound for the first field, or silently carry over the
            # previous field's default.
            default = None

        # Get the help text for the field; escape % because argparse
        # %-formats help strings
        name = field.name
        help_text = cls_docs[name].strip().replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        field_kwargs: dict[str, Any] = {"default": default, "help": help_text}
        kwargs[name] = field_kwargs

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:

            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    if hasattr(cls, "from_cli"):
                        return cls.from_cli(val)
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            field_kwargs["type"] = parse_dataclass
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            field_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            field_kwargs.update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            types = get_args(type_hint)
            tuple_type = types[0]
            assert all(t is tuple_type for t in types if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {types}.")
            field_kwargs["type"] = tuple_type
            field_kwargs["nargs"] = "+" if Ellipsis in types else len(types)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            types = get_args(type_hint)
            assert len(types) == 1, (
                "List type must have exactly one type. Got "
                f"{type_hint} with types {types}")
            field_kwargs["type"] = types[0]
            field_kwargs["nargs"] = "+"
        elif contains_type(type_hints, int):
            field_kwargs["type"] = int
            # Special case for large integers
            if name in {"max_model_len", "max_num_batched_tokens"}:
                field_kwargs["type"] = human_readable_int
        elif contains_type(type_hints, float):
            field_kwargs["type"] = float
        elif (contains_type(type_hints, dict)
              and (contains_type(type_hints, str)
                   or any(is_not_builtin(th) for th in type_hints))):
            field_kwargs["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            field_kwargs["type"] = parse_type(json.loads)
            field_kwargs["help"] += f"\n\n{json_tip}"
        elif (contains_type(type_hints, str)
              or any(is_not_builtin(th) for th in type_hints)):
            field_kwargs["type"] = str
        else:
            raise ValueError(
                f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(field_kwargs.get("type")) is Literal:
            field_kwargs.update(literal_to_kwargs({field_kwargs["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            field_kwargs["type"] = optional_type(field_kwargs["type"])
            if field_kwargs.get("choices"):
                # NOTE(review): argparse appears to check choices against the
                # value returned by ``type`` (None here, not the string
                # "None"); confirm "None" is really accepted on the CLI.
                field_kwargs["choices"].append("None")
    return kwargs
252
253


254
255
256
257
258
259
260
261
262
263
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Return argparse kwargs for the given Config dataclass.

    The kwargs are computed once per class by the LRU-cached
    ``_compute_kwargs``. This wrapper returns a deep copy so callers may
    modify the result without corrupting the cached original.
    """
    cached = _compute_kwargs(cls)
    return copy.deepcopy(cached)


264
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
265
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
266
    """Arguments for vLLM engine."""
267
268
269
270
271
272
273
    model: str = ModelConfig.model
    served_model_name: Optional[Union[
        str, List[str]]] = ModelConfig.served_model_name
    tokenizer: Optional[str] = ModelConfig.tokenizer
    hf_config_path: Optional[str] = ModelConfig.hf_config_path
    task: TaskOption = ModelConfig.task
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
274
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
275
276
277
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
278
279
    download_dir: Optional[str] = LoadConfig.download_dir
    load_format: str = LoadConfig.load_format
280
281
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
282
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
283
284
    seed: Optional[int] = ModelConfig.seed
    max_model_len: Optional[int] = ModelConfig.max_model_len
285
286
    cuda_graph_sizes: list[int] = get_field(SchedulerConfig,
                                            "cuda_graph_sizes")
287
288
289
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
290
    distributed_executor_backend: Optional[Union[
291
292
        DistributedExecutorBackend,
        Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
293
    # number of P/D disaggregation (or other disaggregation) workers
294
295
296
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
    data_parallel_size: int = ParallelConfig.data_parallel_size
297
    data_parallel_rank: Optional[int] = None
298
299
300
    data_parallel_size_local: Optional[int] = None
    data_parallel_address: Optional[str] = None
    data_parallel_rpc_port: Optional[int] = None
Rui Qiao's avatar
Rui Qiao committed
301
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
302
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
303
304
305
306
307
    enable_eplb: bool = ParallelConfig.enable_eplb
    num_redundant_experts: int = ParallelConfig.num_redundant_experts
    eplb_window_size: int = ParallelConfig.eplb_window_size
    eplb_step_interval: int = ParallelConfig.eplb_step_interval
    eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness
308
309
    max_parallel_loading_workers: Optional[
        int] = ParallelConfig.max_parallel_loading_workers
310
311
312
313
    block_size: Optional[BlockSize] = CacheConfig.block_size
    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
    prefix_caching_hash_algo: PrefixCachingHashAlgo = \
        CacheConfig.prefix_caching_hash_algo
314
315
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
316
317
318
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
319
320
321
322
323
324
325
    max_num_batched_tokens: Optional[
        int] = SchedulerConfig.max_num_batched_tokens
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
    long_prefill_token_threshold: int = \
        SchedulerConfig.long_prefill_token_threshold
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
326
    max_logprobs: int = ModelConfig.max_logprobs
327
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
328
    disable_log_stats: bool = False
329
330
331
332
333
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
    rope_theta: Optional[float] = ModelConfig.rope_theta
    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
334
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
335
336
337
338
    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
    enforce_eager: bool = ModelConfig.enforce_eager
    max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
339
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
340
    limit_mm_per_prompt: dict[str, int] = \
341
        get_field(MultiModalConfig, "limit_per_prompt")
342
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
343
344
345
    media_io_kwargs: dict[str, dict[str,
                                    Any]] = get_field(MultiModalConfig,
                                                      "media_io_kwargs")
346
347
348
349
    mm_processor_kwargs: Optional[Dict[str, Any]] = \
        MultiModalConfig.mm_processor_kwargs
    disable_mm_preprocessor_cache: bool = \
        MultiModalConfig.disable_mm_preprocessor_cache
350
    # LoRA fields
351
    enable_lora: bool = False
352
353
354
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
355
356
    default_mm_loras: Optional[Dict[str, str]] = \
        LoRAConfig.default_mm_loras
357
358
359
360
361
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
    # PromptAdapter fields
362
    enable_prompt_adapter: bool = False
363
364
365
366
    max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
    max_prompt_adapter_token: int = \
        PromptAdapterConfig.max_prompt_adapter_token

367
368
    num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
    multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
369
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
370
371
    num_gpu_blocks_override: Optional[
        int] = CacheConfig.num_gpu_blocks_override
372
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
373
374
    model_loader_extra_config: dict = \
        get_field(LoadConfig, "model_loader_extra_config")
375
376
    ignore_patterns: Optional[Union[str,
                                    List[str]]] = LoadConfig.ignore_patterns
377
    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
378

379
380
381
382
    scheduler_delay_factor: float = SchedulerConfig.delay_factor
    enable_chunked_prefill: Optional[
        bool] = SchedulerConfig.enable_chunked_prefill
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
383

384
385
386
    disable_hybrid_kv_cache_manager: bool = (
        SchedulerConfig.disable_hybrid_kv_cache_manager)

387
388
389
390
391
392
    guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
    guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
    guided_decoding_disable_any_whitespace: bool = \
        DecodingConfig.disable_any_whitespace
    guided_decoding_disable_additional_properties: bool = \
        DecodingConfig.disable_additional_properties
393
394
    logits_processor_pattern: Optional[
        str] = ModelConfig.logits_processor_pattern
395

396
    speculative_config: Optional[Dict[str, Any]] = None
397

398
399
400
401
402
403
    show_hidden_metrics_for_version: Optional[str] = \
        ObservabilityConfig.show_hidden_metrics_for_version
    otlp_traces_endpoint: Optional[str] = \
        ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: Optional[list[DetailedTraceModules]] = \
        ObservabilityConfig.collect_detailed_traces
404
    disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
405
406
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
407

408
409
410
411
    override_neuron_config: dict[str, Any] = \
        get_field(ModelConfig, "override_neuron_config")
    override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
        ModelConfig.override_pooler_config
412
413
    compilation_config: CompilationConfig = \
        get_field(VllmConfig, "compilation_config")
414
415
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
416

417
    kv_transfer_config: Optional[KVTransferConfig] = None
418
    kv_events_config: Optional[KVEventsConfig] = None
419

420
421
422
423
424
    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
    override_generation_config: dict[str, Any] = \
        get_field(ModelConfig, "override_generation_config")
    model_impl: str = ModelConfig.model_impl
425
    override_attention_dtype: str = ModelConfig.override_attention_dtype
426

427
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
428

429
430
    additional_config: dict[str, Any] = \
        get_field(VllmConfig, "additional_config")
431
432
    reasoning_parser: str = DecodingConfig.reasoning_backend

433
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
434
    pt_load_map_location: str = LoadConfig.pt_load_map_location
435

436
437
438
    enable_multimodal_encoder_data_parallel: bool = \
        ParallelConfig.enable_multimodal_encoder_data_parallel

439
440
    async_scheduling: bool = SchedulerConfig.async_scheduling

441
    def __post_init__(self):
        """Normalise convenience inputs, then load general plugins.

        ``compilation_config`` may be passed as an ``int`` or a ``dict``
        (e.g. ``EngineArgs(compilation_config={...})``) instead of a
        ``CompilationConfig``; such values are converted through
        ``CompilationConfig.from_cli``.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, (int, dict)):
            if isinstance(self.compilation_config, dict):
                # The CLI form of a dataclass config is a JSON string.
                # str(dict) gives a Python repr (single quotes, True/None),
                # which is not valid JSON, so serialise it properly.
                cli_value = json.dumps(self.compilation_config)
            else:
                cli_value = str(self.compilation_config)
            self.compilation_config = CompilationConfig.from_cli(cli_value)
        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
451
452

    @staticmethod
453
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
454
        """Shared CLI arguments for vLLM engine."""
455

456
        # Model arguments
457
458
459
460
461
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
Reid's avatar
Reid committed
462
        if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]):
463
            model_group.add_argument("--model", **model_kwargs["model"])
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
        model_group.add_argument("--task", **model_kwargs["task"])
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode",
                                 **model_kwargs["tokenizer_mode"])
        model_group.add_argument("--trust-remote-code",
                                 **model_kwargs["trust_remote_code"])
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
        model_group.add_argument("--hf-config-path",
                                 **model_kwargs["hf_config_path"])
        model_group.add_argument("--allowed-local-media-path",
                                 **model_kwargs["allowed_local_media_path"])
        model_group.add_argument("--revision", **model_kwargs["revision"])
        model_group.add_argument("--code-revision",
                                 **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling",
                                 **model_kwargs["rope_scaling"])
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
        model_group.add_argument("--tokenizer-revision",
                                 **model_kwargs["tokenizer_revision"])
        model_group.add_argument("--max-model-len",
                                 **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q",
                                 **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager",
                                 **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-seq-len-to-capture",
                                 **model_kwargs["max_seq_len_to_capture"])
        model_group.add_argument("--max-logprobs",
                                 **model_kwargs["max_logprobs"])
494
495
        model_group.add_argument("--logprobs-mode",
                                 **model_kwargs["logprobs_mode"])
496
497
498
499
500
501
        model_group.add_argument("--disable-sliding-window",
                                 **model_kwargs["disable_sliding_window"])
        model_group.add_argument("--disable-cascade-attn",
                                 **model_kwargs["disable_cascade_attn"])
        model_group.add_argument("--skip-tokenizer-init",
                                 **model_kwargs["skip_tokenizer_init"])
502
503
        model_group.add_argument("--enable-prompt-embeds",
                                 **model_kwargs["enable_prompt_embeds"])
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
        model_group.add_argument("--served-model-name",
                                 **model_kwargs["served_model_name"])
        # This one is a special case because it is the
        # opposite of ModelConfig.use_async_output_proc
        model_group.add_argument(
            "--disable-async-output-proc",
            action="store_true",
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
        model_group.add_argument("--config-format",
                                 choices=[f.value for f in ConfigFormat],
                                 **model_kwargs["config_format"])
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
        model_group.add_argument("--hf-token",
                                 type=str,
                                 nargs="?",
                                 const=True,
                                 default=model_kwargs["hf_token"]["default"],
                                 help=model_kwargs["hf_token"]["help"])
        model_group.add_argument("--hf-overrides",
                                 **model_kwargs["hf_overrides"])
        model_group.add_argument("--override-neuron-config",
                                 **model_kwargs["override_neuron_config"])
        model_group.add_argument("--override-pooler-config",
                                 **model_kwargs["override_pooler_config"])
        model_group.add_argument("--logits-processor-pattern",
                                 **model_kwargs["logits_processor_pattern"])
        model_group.add_argument("--generation-config",
                                 **model_kwargs["generation_config"])
        model_group.add_argument("--override-generation-config",
                                 **model_kwargs["override_generation_config"])
        model_group.add_argument("--enable-sleep-mode",
                                 **model_kwargs["enable_sleep_mode"])
        model_group.add_argument("--model-impl",
                                 choices=[f.value for f in ModelImpl],
                                 **model_kwargs["model_impl"])
542
543
        model_group.add_argument("--override-attention-dtype",
                                 **model_kwargs["override_attention_dtype"])
544

545
546
547
548
549
550
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
551
        load_group.add_argument("--load-format",
552
553
                                choices=[f.value for f in LoadFormat],
                                **load_kwargs["load_format"])
554
        load_group.add_argument("--download-dir",
555
                                **load_kwargs["download_dir"])
556
        load_group.add_argument("--model-loader-extra-config",
557
                                **load_kwargs["model_loader_extra_config"])
558
559
560
        load_group.add_argument("--ignore-patterns",
                                **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load",
561
                                **load_kwargs["use_tqdm_on_load"])
562
563
        load_group.add_argument('--pt-load-map-location',
                                **load_kwargs["pt_load_map_location"])
564

565
566
567
568
569
570
        # Guided decoding arguments
        guided_decoding_kwargs = get_kwargs(DecodingConfig)
        guided_decoding_group = parser.add_argument_group(
            title="DecodingConfig",
            description=DecodingConfig.__doc__,
        )
571
572
        guided_decoding_group.add_argument("--guided-decoding-backend",
                                           **guided_decoding_kwargs["backend"])
573
        guided_decoding_group.add_argument(
574
575
576
577
578
579
580
581
            "--guided-decoding-disable-fallback",
            **guided_decoding_kwargs["disable_fallback"])
        guided_decoding_group.add_argument(
            "--guided-decoding-disable-any-whitespace",
            **guided_decoding_kwargs["disable_any_whitespace"])
        guided_decoding_group.add_argument(
            "--guided-decoding-disable-additional-properties",
            **guided_decoding_kwargs["disable_additional_properties"])
582
583
584
585
586
587
        guided_decoding_group.add_argument(
            "--reasoning-parser",
            # This choices is a special case because it's not static
            choices=list(ReasoningParserManager.reasoning_parsers),
            **guided_decoding_kwargs["reasoning_backend"])

588
        # Parallel arguments
589
590
591
592
593
594
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
595
            "--distributed-executor-backend",
596
597
            **parallel_kwargs["distributed_executor_backend"])
        parallel_group.add_argument(
598
            "--pipeline-parallel-size", "-pp",
599
            **parallel_kwargs["pipeline_parallel_size"])
600
        parallel_group.add_argument("--tensor-parallel-size", "-tp",
601
                                    **parallel_kwargs["tensor_parallel_size"])
602
        parallel_group.add_argument("--data-parallel-size", "-dp",
603
                                    **parallel_kwargs["data_parallel_size"])
604
605
606
607
608
609
        parallel_group.add_argument(
            '--data-parallel-rank',
            '-dpn',
            type=int,
            help='Data parallel rank of this instance. '
            'When set, enables external load balancer mode.')
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
        parallel_group.add_argument('--data-parallel-size-local',
                                    '-dpl',
                                    type=int,
                                    help='Number of data parallel replicas '
                                    'to run on this node.')
        parallel_group.add_argument('--data-parallel-address',
                                    '-dpa',
                                    type=str,
                                    help='Address of data parallel cluster '
                                    'head-node.')
        parallel_group.add_argument('--data-parallel-rpc-port',
                                    '-dpp',
                                    type=int,
                                    help='Port for data parallel RPC '
                                    'communication.')
Rui Qiao's avatar
Rui Qiao committed
625
626
627
628
629
630
        parallel_group.add_argument('--data-parallel-backend',
                                    '-dpb',
                                    type=str,
                                    default='mp',
                                    help='Backend for data parallel, either '
                                    '"mp" or "ray".')
631
        parallel_group.add_argument(
632
            "--enable-expert-parallel",
633
            **parallel_kwargs["enable_expert_parallel"])
634
635
636
637
638
639
640
641
642
643
        parallel_group.add_argument("--enable-eplb",
                                    **parallel_kwargs["enable_eplb"])
        parallel_group.add_argument("--num-redundant-experts",
                                    **parallel_kwargs["num_redundant_experts"])
        parallel_group.add_argument("--eplb-window-size",
                                    **parallel_kwargs["eplb_window_size"])
        parallel_group.add_argument("--eplb-step-interval",
                                    **parallel_kwargs["eplb_step_interval"])
        parallel_group.add_argument("--eplb-log-balancedness",
                                    **parallel_kwargs["eplb_log_balancedness"])
644
        parallel_group.add_argument(
645
            "--max-parallel-loading-workers",
646
647
            **parallel_kwargs["max_parallel_loading_workers"])
        parallel_group.add_argument(
648
            "--ray-workers-use-nsight",
649
650
            **parallel_kwargs["ray_workers_use_nsight"])
        parallel_group.add_argument(
651
            "--disable-custom-all-reduce",
652
            **parallel_kwargs["disable_custom_all_reduce"])
653
654
655
656
        parallel_group.add_argument("--worker-cls",
                                    **parallel_kwargs["worker_cls"])
        parallel_group.add_argument("--worker-extension-cls",
                                    **parallel_kwargs["worker_extension_cls"])
657
658
659
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
            **parallel_kwargs["enable_multimodal_encoder_data_parallel"])
660

661
662
663
664
665
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
666
        )
667
668
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
        cache_group.add_argument("--gpu-memory-utilization",
669
                                 **cache_kwargs["gpu_memory_utilization"])
670
671
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
        cache_group.add_argument("--kv-cache-dtype",
672
                                 **cache_kwargs["cache_dtype"])
673
        cache_group.add_argument("--num-gpu-blocks-override",
674
675
676
677
678
                                 **cache_kwargs["num_gpu_blocks_override"])
        cache_group.add_argument("--enable-prefix-caching",
                                 **cache_kwargs["enable_prefix_caching"])
        cache_group.add_argument("--prefix-caching-hash-algo",
                                 **cache_kwargs["prefix_caching_hash_algo"])
679
        cache_group.add_argument("--cpu-offload-gb",
680
                                 **cache_kwargs["cpu_offload_gb"])
681
        cache_group.add_argument("--calculate-kv-scales",
682
683
                                 **cache_kwargs["calculate_kv_scales"])

684
        # Multimodal related configs
685
686
687
688
689
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
690
        multimodal_group.add_argument("--limit-mm-per-prompt",
691
                                      **multimodal_kwargs["limit_per_prompt"])
692
693
        multimodal_group.add_argument("--media-io-kwargs",
                                      **multimodal_kwargs["media_io_kwargs"])
694
        multimodal_group.add_argument(
695
            "--mm-processor-kwargs",
696
697
            **multimodal_kwargs["mm_processor_kwargs"])
        multimodal_group.add_argument(
698
            "--disable-mm-preprocessor-cache",
699
            **multimodal_kwargs["disable_mm_preprocessor_cache"])
700
701
702
        multimodal_group.add_argument(
            "--interleave-mm-strings",
            **multimodal_kwargs["interleave_mm_strings"])
703

704
        # LoRA related configs
705
706
707
708
709
710
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
711
            "--enable-lora",
712
            action=argparse.BooleanOptionalAction,
713
714
            help="If True, enable handling of LoRA adapters.")
        lora_group.add_argument("--enable-lora-bias",
715
                                **lora_kwargs["bias_enabled"])
716
717
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
        lora_group.add_argument("--max-lora-rank",
718
                                **lora_kwargs["max_lora_rank"])
719
        lora_group.add_argument("--lora-extra-vocab-size",
720
721
                                **lora_kwargs["lora_extra_vocab_size"])
        lora_group.add_argument(
722
            "--lora-dtype",
723
724
            **lora_kwargs["lora_dtype"],
        )
725
        lora_group.add_argument("--max-cpu-loras",
726
                                **lora_kwargs["max_cpu_loras"])
727
        lora_group.add_argument("--fully-sharded-loras",
728
                                **lora_kwargs["fully_sharded_loras"])
729
730
        lora_group.add_argument("--default-mm-loras",
                                **lora_kwargs["default_mm_loras"])
731
732
733
734
735
736
737
738

        # PromptAdapter related configs
        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
        prompt_adapter_group = parser.add_argument_group(
            title="PromptAdapterConfig",
            description=PromptAdapterConfig.__doc__,
        )
        prompt_adapter_group.add_argument(
739
            "--enable-prompt-adapter",
740
            action=argparse.BooleanOptionalAction,
741
            help="If True, enable handling of PromptAdapters.")
742
        prompt_adapter_group.add_argument(
743
            "--max-prompt-adapters",
744
745
            **prompt_adapter_kwargs["max_prompt_adapters"])
        prompt_adapter_group.add_argument(
746
            "--max-prompt-adapter-token",
747
            **prompt_adapter_kwargs["max_prompt_adapter_token"])
748

749
750
751
752
753
754
        # Speculative arguments
        speculative_group = parser.add_argument_group(
            title="SpeculativeConfig",
            description=SpeculativeConfig.__doc__,
        )
        speculative_group.add_argument(
755
            "--speculative-config",
756
757
            type=json.loads,
            default=None,
758
759
            help="The configurations for speculative decoding. Should be a "
            "JSON string.")
760

761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
            **observability_kwargs["show_hidden_metrics_for_version"])
        observability_group.add_argument(
            "--otlp-traces-endpoint",
            **observability_kwargs["otlp_traces_endpoint"])
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
            ",".join(p)
            for p in permutations(get_args(DetailedTraceModules), r=2)
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
            **observability_kwargs["collect_detailed_traces"])
784

785
786
787
788
789
790
791
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
792
            "--max-num-batched-tokens",
793
            **scheduler_kwargs["max_num_batched_tokens"])
794
        scheduler_group.add_argument("--max-num-seqs",
795
796
797
798
799
800
801
                                     **scheduler_kwargs["max_num_seqs"])
        scheduler_group.add_argument(
            "--max-num-partial-prefills",
            **scheduler_kwargs["max_num_partial_prefills"])
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
            **scheduler_kwargs["max_long_partial_prefills"])
802
803
        scheduler_group.add_argument('--cuda-graph-sizes',
                                     **scheduler_kwargs["cuda_graph_sizes"])
804
805
806
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
            **scheduler_kwargs["long_prefill_token_threshold"])
807
        scheduler_group.add_argument("--num-lookahead-slots",
808
                                     **scheduler_kwargs["num_lookahead_slots"])
809
        scheduler_group.add_argument("--scheduler-delay-factor",
810
                                     **scheduler_kwargs["delay_factor"])
811
        scheduler_group.add_argument("--preemption-mode",
812
                                     **scheduler_kwargs["preemption_mode"])
813
        scheduler_group.add_argument("--num-scheduler-steps",
814
                                     **scheduler_kwargs["num_scheduler_steps"])
815
        scheduler_group.add_argument(
816
            "--multi-step-stream-outputs",
817
            **scheduler_kwargs["multi_step_stream_outputs"])
818
        scheduler_group.add_argument("--scheduling-policy",
819
                                     **scheduler_kwargs["policy"])
820
        scheduler_group.add_argument(
821
            "--enable-chunked-prefill",
822
            **scheduler_kwargs["enable_chunked_prefill"])
823
824
825
        scheduler_group.add_argument(
            "--disable-chunked-mm-input",
            **scheduler_kwargs["disable_chunked_mm_input"])
826
827
        scheduler_group.add_argument("--scheduler-cls",
                                     **scheduler_kwargs["scheduler_cls"])
828
829
830
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"])
831
832
        scheduler_group.add_argument("--async-scheduling",
                                     **scheduler_kwargs["async_scheduling"])
833
834

        # vLLM arguments
835
        vllm_kwargs = get_kwargs(VllmConfig)
836
837
838
839
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
840
841
842
843
844
845
846
847
        vllm_group.add_argument("--kv-transfer-config",
                                **vllm_kwargs["kv_transfer_config"])
        vllm_group.add_argument('--kv-events-config',
                                **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument("--compilation-config", "-O",
                                **vllm_kwargs["compilation_config"])
        vllm_group.add_argument("--additional-config",
                                **vllm_kwargs["additional_config"])
848

849
850
851
852
        # Other arguments
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')
853

854
        return parser
855
856

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from a parsed argparse namespace.

        Every dataclass field of ``cls`` is read from the attribute with the
        same name on ``args``; attributes of ``args`` that do not correspond
        to a field are ignored.

        Args:
            args: The namespace produced by parsing the CLI arguments.

        Returns:
            A new instance of ``cls`` built from those attribute values.

        Raises:
            AttributeError: If ``args`` lacks an attribute for one of the
                dataclass fields of ``cls``.
        """
        # Map each dataclass field name to the parsed value of the same name.
        field_values = {
            dc_field.name: getattr(args, dc_field.name)
            for dc_field in dataclasses.fields(cls)
        }
        return cls(**field_values)
863

864
    def create_model_config(self) -> ModelConfig:
        """Build the ModelConfig from these engine arguments.

        Before building the config, this may mutate ``self``:

        - If ``self.model`` is a GGUF file, both ``quantization`` and
          ``load_format`` are set to ``"gguf"``.
        - If ``self`` is not an ``AsyncEngineArgs``, ``envs.VLLM_CI_USE_S3``
          is set, the model is listed in ``MODELS_ON_S3`` and ``load_format``
          is ``LoadFormat.AUTO``, then ``model`` is redirected into
          ``MODEL_WEIGHTS_S3_BUCKET`` and ``load_format`` becomes
          ``LoadFormat.RUNAI_STREAMER``.

        Returns:
            A ModelConfig populated from the (possibly updated) attributes.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
                and self.model in MODELS_ON_S3
                and self.load_format == LoadFormat.AUTO):
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
            self.load_format = LoadFormat.RUNAI_STREAMER

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            task=self.task,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            # The engine arg is phrased as a "disable" flag; the config
            # stores the positive form.
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
        )
920

921
922
923
924
925
926
927
    def validate_tensorizer_args(self):
        """Copy tensorizer-specific keys into the nested tensorizer config.

        Any top-level key of ``self.model_loader_extra_config`` that is listed
        in ``TensorizerConfig._fields`` has its value copied into
        ``self.model_loader_extra_config["tensorizer_config"]``. That nested
        dict must already exist whenever at least one key matches.
        """
        from vllm.model_executor.model_loader.tensorizer import (
            TensorizerConfig)

        extra_config = self.model_loader_extra_config
        tensorizer_fields = TensorizerConfig._fields
        # Only the nested dict is written below, so the outer mapping is
        # not resized while its keys are being scanned.
        matching_keys = [
            name for name in extra_config if name in tensorizer_fields
        ]
        for name in matching_keys:
            extra_config["tensorizer_config"][name] = extra_config[name]
928

929
930
    def create_load_config(self) -> LoadConfig:
        """Build the LoadConfig from these engine arguments.

        Before building the config, this may mutate ``self``:

        - ``quantization == "bitsandbytes"`` forces ``load_format`` to
          ``"bitsandbytes"``.
        - When ``load_format == "tensorizer"``:
          - ``model_loader_extra_config`` is first converted via
            ``to_serializable()`` if it has that method.
          - Its ``"tensorizer_config"`` entry is then replaced by a fresh
            dict containing ``"tensorizer_dir": self.model``.
          - ``validate_tensorizer_args()`` copies matching top-level keys
            into that dict.

        Returns:
            A LoadConfig whose ``device`` is ``"cpu"`` when
            ``is_online_quantization(self.quantization)`` is true and
            ``None`` otherwise.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            # NOTE(review): presumably a TensorizerConfig object may be
            # passed here; it is flattened to a plain mapping so that it can
            # be indexed below — confirm against callers.
            if hasattr(self.model_loader_extra_config, "to_serializable"):
                self.model_loader_extra_config = (
                    self.model_loader_extra_config.to_serializable())
            self.model_loader_extra_config["tensorizer_config"] = {}
            self.model_loader_extra_config["tensorizer_config"][
                "tensorizer_dir"] = self.model
            self.validate_tensorizer_args()

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            device="cpu"
            if is_online_quantization(self.quantization) else None,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )
953

954
955
956
957
958
959
960
961
962
963
964
965
966
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Build a SpeculativeConfig from ``self.speculative_config``.

        ``self.speculative_config`` is a dict: either parsed from the
        ``--speculative-config`` JSON string on the CLI, or passed directly
        when the engine is constructed. The engine-side arguments below are
        merged into a copy of that dict, which is then handed to
        ``SpeculativeConfig.from_dict``. The user's dict is left unchanged.

        Args:
            target_model_config: Model config of the target model.
            target_parallel_config: Parallel config of the target model.
            enable_chunked_prefill: Whether chunked prefill is enabled.
            disable_log_stats: Whether stats logging is disabled.

        Returns:
            The resulting SpeculativeConfig, or ``None`` when
            ``self.speculative_config`` is ``None``.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # Merge into a copy rather than calling ``update`` on the
        # user-supplied dict, so the caller's dict is not polluted with
        # config objects (which also keeps it reusable/serializable).
        speculative_args = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        }
        return SpeculativeConfig.from_dict(speculative_args)

986
987
988
989
990
991
992
993
994
995
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
996

997
998
999
1000
1001
1002
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1003
        current_platform.pre_register_and_update()
1004

1005
1006
        device_config = DeviceConfig(
            device=cast(Device, current_platform.device_type))
1007
1008
        model_config = self.create_model_config()

1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
1028
            self._set_default_args_v1(usage_context, model_config)
1029
1030
1031
1032
1033
1034
1035
1036
            # Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1
            if current_platform.is_cpu(
            ) and current_platform.get_cpu_architecture() in (
                    CpuArchEnum.POWERPC, CpuArchEnum.ARM):
                logger.info(
                    "Chunked prefill is not supported for ARM and POWER CPUs; "
                    "disabling it for V1 backend.")
                self.enable_chunked_prefill = False
1037
1038
        else:
            self._set_default_args_v0(model_config)
1039
1040
        assert self.enable_chunked_prefill is not None

1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
        if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
            assert self.enforce_eager, (
                "Cuda graph is not supported with DualChunkFlashAttention. "
                "To run the model in eager mode, set 'enforce_eager=True' "
                "or use '--enforce-eager' in the CLI.")
            assert current_platform.is_cuda(), (
                "DualChunkFlashAttention is only supported on CUDA platform.")
            assert not use_v1, (
                "DualChunkFlashAttention is not supported on V1 engine. "
                "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")

1052
        cache_config = CacheConfig(
1053
            block_size=self.block_size,
1054
1055
1056
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1057
            is_attention_free=model_config.is_attention_free,
1058
1059
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1060
            enable_prefix_caching=self.enable_prefix_caching,
1061
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1062
            cpu_offload_gb=self.cpu_offload_gb,
1063
            calculate_kv_scales=self.calculate_kv_scales,
1064
        )
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
        data_parallel_external_lb = self.data_parallel_rank is not None
        if data_parallel_external_lb:
            assert self.data_parallel_size_local in (1, None), (
                "data_parallel_size_local must be 1 when data_parallel_rank "
                "is set")
            data_parallel_size_local = 1
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
        else:
            # Local DP size defaults to global DP size if not set.
            data_parallel_size_local = self.data_parallel_size
1088
1089
1090

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
                    "Using host IP %s as ray-based data parallel address",
                    host_ip)
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
                    self.data_parallel_backend)
                data_parallel_address = ParallelConfig.data_parallel_master_ip
        else:
            data_parallel_address = self.data_parallel_address
1105
1106
1107
1108
1109
1110
1111

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
        data_parallel_rpc_port = self.data_parallel_rpc_port if (
            self.data_parallel_rpc_port
            is not None) else ParallelConfig.data_parallel_rpc_port

1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
        if self.async_scheduling:
            # Async scheduling does not work with the uniprocess backend.
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "mp"
                logger.info("Using mp-based distributed executor backend "
                            "for async scheduling.")
            if self.distributed_executor_backend == "uni":
                raise ValueError("Async scheduling is not supported with "
                                 "uni-process backend.")
            if self.pipeline_parallel_size > 1:
                raise ValueError("Async scheduling is not supported with "
                                 "pipeline-parallel-size > 1.")

            # Currently, async scheduling does not support speculative decoding.
            # TODO(woosuk): Support it.
            if self.speculative_config is not None:
                raise ValueError(
                    "Currently, speculative decoding is not supported with "
                    "async scheduling.")

1132
        parallel_config = ParallelConfig(
1133
1134
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1135
            data_parallel_size=self.data_parallel_size,
1136
1137
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1138
1139
1140
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1141
            data_parallel_backend=self.data_parallel_backend,
1142
            enable_expert_parallel=self.enable_expert_parallel,
1143
1144
1145
1146
1147
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.num_redundant_experts,
            eplb_window_size=self.eplb_window_size,
            eplb_step_interval=self.eplb_step_interval,
            eplb_log_balancedness=self.eplb_log_balancedness,
1148
1149
1150
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1151
            placement_group=placement_group,
1152
1153
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1154
            worker_extension_cls=self.worker_extension_cls,
1155
1156
            enable_multimodal_encoder_data_parallel=self.
            enable_multimodal_encoder_data_parallel,
1157
        )
1158

1159
        speculative_config = self.create_speculative_config(
1160
1161
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1162
            enable_chunked_prefill=self.enable_chunked_prefill,
1163
            disable_log_stats=self.disable_log_stats,
1164
1165
        )

1166
        # Reminder: Please update docs/features/compatibility_matrix.md
1167
        # If the feature combo become valid
1168
1169
1170
1171
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1172
1173
1174
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1175
1176
1177
1178
1179
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1180
1181
1182
1183
1184
1185
1186
1187
1188

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1189
        scheduler_config = SchedulerConfig(
1190
            runner_type=model_config.runner_type,
1191
1192
1193
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1194
            cuda_graph_sizes=self.cuda_graph_sizes,
1195
            num_lookahead_slots=num_lookahead_slots,
1196
1197
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1198
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1199
            is_multimodal_model=model_config.is_multimodal_model,
1200
            preemption_mode=self.preemption_mode,
1201
            num_scheduler_steps=self.num_scheduler_steps,
1202
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1203
1204
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1205
            policy=self.scheduling_policy,
1206
            scheduler_cls=self.scheduler_cls,
1207
1208
1209
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1210
1211
            disable_hybrid_kv_cache_manager=self.
            disable_hybrid_kv_cache_manager,
1212
            async_scheduling=self.async_scheduling,
1213
        )
1214

1215
1216
1217
1218
1219
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
                "non multimodal model")

1220
        lora_config = LoRAConfig(
1221
            bias_enabled=self.enable_lora_bias,
1222
1223
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1224
            default_mm_loras=self.default_mm_loras,
1225
            fully_sharded_loras=self.fully_sharded_loras,
1226
1227
1228
1229
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
1230

1231
1232
1233
1234
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1235
        load_config = self.create_load_config()
1236

1237
1238
1239
1240
1241
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1242
        decoding_config = DecodingConfig(
1243
1244
1245
1246
1247
            backend=self.guided_decoding_backend,
            disable_fallback=self.guided_decoding_disable_fallback,
            disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
            disable_additional_properties=\
                self.guided_decoding_disable_additional_properties,
1248
1249
            reasoning_backend=self.reasoning_parser
        )
1250

1251
        observability_config = ObservabilityConfig(
1252
1253
            show_hidden_metrics_for_version=(
                self.show_hidden_metrics_for_version),
1254
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1255
            collect_detailed_traces=self.collect_detailed_traces,
1256
        )
1257

1258
        config = VllmConfig(
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1269
            prompt_adapter_config=prompt_adapter_config,
1270
            compilation_config=self.compilation_config,
1271
            kv_transfer_config=self.kv_transfer_config,
1272
            kv_events_config=self.kv_events_config,
1273
            additional_config=self.additional_config,
1274
        )
1275

1276
1277
        return config

1278
1279
1280
1281
1282
1283
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        Runs a fixed sequence of checks against these engine arguments and
        ``model_config``. When an unsupported feature is found,
        ``_raise_or_fallback`` is called (it raises if ``VLLM_USE_V1=1`` was
        explicitly set, otherwise it logs a warning) and ``False`` is
        returned. Experimental features go through ``_warn_or_fallback``,
        which only forces a fallback when V1 was not explicitly requested.

        Args:
            model_config: The resolved model configuration.

        Returns:
            True if the V1 engine can be used, False to fall back to V0.

        Raises:
            NotImplementedError: If ``VLLM_USE_V1=1`` is set together with an
                unsupported feature, or if draft-model speculative decoding
                is requested (that case never falls back to V0).
        """

        #############################################################
        # Unsupported Feature Flags on V1.

        if self.load_format == LoadFormat.SHARDED_STATE.value:
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

        if self.preemption_mode != SchedulerConfig.preemption_mode:
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

        if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
            _raise_or_fallback(feature_name="--num-scheduler-steps",
                               recommend_to_remove=True)
            return False

        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

        # Only the guided decoding backends listed in the V1 literal type
        # are accepted.
        if self.guided_decoding_backend not in get_args(
                GuidedDecodingBackendV1):
            _raise_or_fallback(
                feature_name=
                f"--guided-decoding-backend={self.guided_decoding_backend}",
                recommend_to_remove=False)
            return False

        # Need at least Ampere for now (FA support required).
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
        if (current_platform.is_cuda()
                and current_platform.get_device_capability()
                and current_platform.get_device_capability().major < 8):
            _raise_or_fallback(feature_name="Compute Capability < 8.0",
                               recommend_to_remove=False)
            return False

        # A non-default KV cache dtype must be supported by the platform.
        if self.kv_cache_dtype != "auto":
            supported = current_platform.is_kv_cache_dtype_supported(
                self.kv_cache_dtype)
            if not supported:
                _raise_or_fallback(feature_name="--kv-cache-dtype",
                                   recommend_to_remove=False)
                return False

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

        # No text embedding inputs so far.
        if self.enable_prompt_embeds:
            _raise_or_fallback(feature_name="--enable-prompt-embeds",
                               recommend_to_remove=False)
            return False

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            _raise_or_fallback(feature_name=model_config.architectures,
                               recommend_to_remove=False)
            return False

        # V1 mamba models are unoptimized.
        if model_config.has_inner_state and _warn_or_fallback(
                feature_name="Mamba"):
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
                != SchedulerConfig.max_num_partial_prefills
                or self.max_long_partial_prefills
                != SchedulerConfig.max_long_partial_prefills):
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # No OTLP observability so far.
        if (self.otlp_traces_endpoint or self.collect_detailed_traces):
            _raise_or_fallback(feature_name="--otlp-traces-endpoint",
                               recommend_to_remove=False)
            return False

        # V1 supports N-gram, Medusa, and Eagle speculative decoding.
        # A separate draft model is rejected outright (no V0 fallback).
        if (self.speculative_config is not None
                and self.speculative_config.get("method") == "draft_model"):
            raise NotImplementedError(
                "Speculative decoding with draft model is not supported yet. "
                "Please consider using other speculative decoding methods "
                "such as ngram, medusa, eagle, or deepseek_mtp.")

        # No XFormers so far.
        V1_BACKENDS = [
            "FLASH_ATTN_VLLM_V1",
            "FLASH_ATTN",
            "PALLAS",
            "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1",
            "TRITON_MLA",
            "CUTLASS_MLA_VLLM_V1",
            "FLASHMLA",
            "FLASHINFER",
            "FLASHINFER_VLLM_V1",
            "ROCM_AITER_MLA",
            "TORCH_SDPA_VLLM_V1",
            "FLEX_ATTENTION",
        ]
        # Only an explicitly set backend is checked; the default is left
        # for the platform to choose.
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        # Platforms must decide if they can support v1 for this model
        if not current_platform.supports_v1(model_config=model_config):
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

        # Signal Handlers requires running in main thread.
        if (threading.current_thread() != threading.main_thread()
                and _warn_or_fallback("Engine in background thread")):
            return False

        if (self.pipeline_parallel_size > 1
                and self.distributed_executor_backend
                not in (ParallelConfig.distributed_executor_backend, "ray",
                        "mp", "external_launcher")):
            name = ("Pipeline Parallelism without Ray distributed executor "
                    "or multiprocessing executor or external launcher")
            _raise_or_fallback(feature_name=name, recommend_to_remove=False)
            return False

        # The platform may be supported on V1, but off by default for now.
        if not current_platform.default_v1(  # noqa: SIM103
                model_config=model_config) and _warn_or_fallback(
                    current_platform.device_name):
            return False

        if (current_platform.is_cpu()
                and model_config.get_sliding_window() is not None):
            _raise_or_fallback(feature_name="sliding window (CPU backend)",
                               recommend_to_remove=False)
            return False

        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        Resolves ``enable_chunked_prefill`` when it was left unset, warns
        about long contexts that run without chunked prefill, validates the
        prefix-caching options supported by V0, and defaults
        ``max_num_seqs`` to 256.

        Args:
            model_config: The resolved model configuration.

        Raises:
            ValueError: If chunked prefill is enabled for a pooling model, or
                prefix caching is requested with the ``sha256`` hash algo.
        """

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_config is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # Any case not decided above defaults to disabled.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # if using prefix caching, we must set a hash algo
        if self.enable_prefix_caching:
            # Disable prefix caching for multimodal models for VLLM_V0.
            if model_config.is_multimodal_model:
                logger.warning(
                    "--enable-prefix-caching is not supported for multimodal "
                    "models in V0 and has been disabled.")
                self.enable_prefix_caching = False

            # VLLM_V0 only supports builtin hash algo for prefix caching.
            if self.prefix_caching_hash_algo == "sha256":
                raise ValueError(
                    "sha256 is not supported for prefix caching in V0 engine. "
                    "Please use 'builtin'.")

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

1515
1516
    def _set_default_args_v1(self, usage_context: UsageContext,
                             model_config: ModelConfig) -> None:
        """Set Default Arguments for V1 Engine.

        Fills in chunked-prefill and prefix-caching defaults, switches the
        scheduler class to the V1 scheduler when it still holds the V0
        default, and, when the user did not set them, picks
        ``max_num_batched_tokens`` / ``max_num_seqs`` from per-hardware
        tables keyed by ``usage_context``.

        Args:
            usage_context: Where the engine is being created (e.g. the LLM
                class or the OpenAI API server); selects the default table
                entry. May be falsy, in which case no table entry applies.
            model_config: The resolved model configuration.
        """

        # V1 always uses chunked prefill for non-pooling tasks, and enables
        # prefix caching unless the user set it. For pooling tasks both
        # default to on only when the pooling type is "last" (incremental
        # prefill is possible); otherwise they default to off.
        if model_config.runner_type != "pooling":
            self.enable_chunked_prefill = True
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = True
        else:
            pooling_type = model_config.pooler_config.pooling_type

            # TODO: when encoder models are supported we'll have to
            # check for causal attention here.
            incremental_prefill_supported = (pooling_type is not None and
                                             pooling_type.lower() == "last")

            action = ("Enabling"
                      if incremental_prefill_supported else "Disabling")

            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = incremental_prefill_supported
                logger.info("(%s) chunked prefill by default", action)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = incremental_prefill_supported
                logger.info("(%s) prefix caching by default", action)

        # Without chunked prefill the token budget is pinned to
        # max_model_len (presumably so a full-length prompt fits in a step).
        if not self.enable_chunked_prefill:
            self.max_num_batched_tokens = model_config.max_model_len

        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # When no user override, set the default values based on the usage
        # context.
        # Use different default values for different hardware.

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # These are only used to pick the default tables below. Reset
            # both so neither is left unbound or half-initialized.
            device_memory = 0
            device_name = ""

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        from vllm.usage.usage_lib import UsageContext
        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
            # For GPUs like H100 and MI300x, use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 1024,
                UsageContext.OPENAI_API_SERVER: 1024,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256,
                UsageContext.OPENAI_API_SERVER: 256,
            }

        # tpu specific default values.
        if current_platform.is_tpu():
            default_max_num_batched_tokens_tpu = {
                UsageContext.LLM_CLASS: {
                    'V6E': 2048,
                    'V5E': 1024,
                    'V5P': 512,
                },
                UsageContext.OPENAI_API_SERVER: {
                    'V6E': 1024,
                    'V5E': 512,
                    'V5P': 256,
                }
            }

        # cpu specific default values, scaled by the number of workers.
        if current_platform.is_cpu():
            world_size = self.pipeline_parallel_size * self.tensor_parallel_size
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 4096 * world_size,
                UsageContext.OPENAI_API_SERVER: 2048 * world_size,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256 * world_size,
                UsageContext.OPENAI_API_SERVER: 128 * world_size,
            }

        use_context_value = usage_context.value if usage_context else None
        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
            generic_default = default_max_num_batched_tokens[usage_context]
            if current_platform.is_tpu():
                # Prefer the chip-specific value; unknown chips fall back
                # to the generic table.
                chip_name = current_platform.get_device_name()
                self.max_num_batched_tokens = (
                    default_max_num_batched_tokens_tpu[usage_context].get(
                        chip_name, generic_default))
            else:
                self.max_num_batched_tokens = generic_default
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, use_context_value)

        if (self.max_num_seqs is None
                and usage_context in default_max_num_seqs):
            self.max_num_seqs = default_max_num_seqs[usage_context]

            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, use_context_value)
1647

1648

1649
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # Exposed on the CLI as --disable-log-requests.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async engine's CLI arguments on ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: If True, skip the shared ``EngineArgs``
                arguments and register only the async-specific ones.

        Returns:
            The parser with the arguments registered.
        """
        # Initialize plugins so they can update the parser, for example by
        # adding a new quantization method to the --quantization argument or
        # a new device to the --device argument.
        load_general_plugins()
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        # Give the current platform a chance to update the parser.
        current_platform.pre_register_and_update(parser)
        return parser
1668
1669


1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Reject a feature the V1 engine cannot handle.

    If ``VLLM_USE_V1`` is explicitly set to a truthy value, the user forced
    V1, so ``NotImplementedError`` is raised. Otherwise a warning announces
    the fallback to V0 and, when ``recommend_to_remove`` is true, suggests
    dropping the feature to stay on V1.
    """
    v1_forced = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_forced:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")

    pieces = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        pieces.append(
            f"We recommend to remove {feature_name} from your config ")
        pieces.append("in favor of the V1 Engine.")
    logger.warning("".join(pieces))


def _warn_or_fallback(feature_name: str) -> bool:
    """Decide whether an experimental V1 feature forces a V0 fallback.

    Returns:
        True (fall back to V0, logged at info level) unless ``VLLM_USE_V1``
        is explicitly set to a truthy value; in that case a warning is
        logged and False is returned so the engine stays on V1.
    """
    v1_forced = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if not v1_forced:
        logger.info(
            "%s is experimental on VLLM_USE_V1=1. "
            "Falling back to V0 Engine.", feature_name)
        return True

    logger.warning(
        "Detected VLLM_USE_V1=1 with %s. Usage should "
        "be considered experimental. Please report any "
        "issues on Github.", feature_name)
    return False


1697
1698
1699
def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.
    Including decimal values with decimal multipliers.

    Lowercase suffixes (k, m, g, t) are decimal multipliers (powers of
    1000) and accept a fractional number. Uppercase suffixes (K, M, G, T)
    are binary multipliers (powers of 1024) and require an integer number.
    Fractional results are truncated toward zero. A value without a suffix
    is parsed as a plain integer. Surrounding whitespace is ignored.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '2t' -> 2,000,000,000,000
    - '1T' -> 1,099,511,627,776

    Raises:
        argparse.ArgumentTypeError: If a decimal number is combined with a
            binary suffix (e.g. '1.5K').
        ValueError: If the value is neither a plain integer nor a
            suffixed number.
    """
    from decimal import Decimal

    value = value.strip()
    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
    if match:
        decimal_multiplier = {
            'k': 10**3,
            'm': 10**6,
            'g': 10**9,
            't': 10**12,
        }
        binary_multiplier = {
            'K': 2**10,
            'M': 2**20,
            'G': 2**30,
            'T': 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            # Decimal instead of float keeps the result exact: with float,
            # '1.005k' would round down to 1004.
            return int(Decimal(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError("Decimals are not allowed " \
                f"with binary suffixes like {suffix}. Did you mean to use " \
                f"{number}{suffix.lower()} instead?") from e

    # Regular plain number.
    return int(value)


1738
1739
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a parser populated with every ``EngineArgs`` CLI argument.

    Used when building the documentation.
    """
    return EngineArgs.add_cli_args(FlexibleArgumentParser())
1741
1742
1743


def _async_engine_args_parser():
    """Return a parser holding only the async-engine-specific CLI arguments.

    Used when building the documentation; the shared ``EngineArgs``
    arguments are omitted via ``async_args_only=True``.
    """
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                        async_args_only=True)