arg_utils.py 87.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# yapf: disable
5
import argparse
6
import copy
7
import dataclasses
8
import functools
9
import json
10
import sys
11
from dataclasses import MISSING, dataclass, fields, is_dataclass
12
from itertools import permutations
13
14
15
from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List,
                    Literal, Optional, Type, TypeVar, Union, cast, get_args,
                    get_origin)
16

17
import huggingface_hub
18
import regex as re
19
import torch
20
from pydantic import TypeAdapter, ValidationError
21
from typing_extensions import TypeIs, deprecated
22

23
import vllm.envs as envs
24
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
25
26
27
                         ConfigType, ConvertOption, DetailedTraceModules,
                         Device, DeviceConfig, DistributedExecutorBackend,
                         EPLBConfig, HfOverrides, KVEventsConfig,
28
                         KVTransferConfig, LoadConfig, LogprobsMode,
29
                         LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
30
31
32
33
34
                         ModelDType, ObservabilityConfig, ParallelConfig,
                         PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
                         StructuredOutputsConfig, TaskOption, TokenizerMode,
                         VllmConfig, get_attr_docs)
35
from vllm.config.multimodal import MMCacheType, MultiModalConfig
36
from vllm.config.parallel import ExpertPlacementStrategy
37
from vllm.config.utils import get_field
38
from vllm.logger import init_logger
39
from vllm.platforms import CpuArchEnum, current_platform
40
from vllm.plugins import load_general_plugins
41
from vllm.ray.lazy_utils import is_ray_initialized
42
from vllm.reasoning import ReasoningParserManager
43
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
44
from vllm.transformers_utils.config import get_model_path, is_interleaved
45
from vllm.transformers_utils.utils import check_gguf_file
46
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
Rui Qiao's avatar
Rui Qiao committed
47
                        GiB_bytes, get_ip, is_in_ray_actor)
48
from vllm.v1.sample.logits_processor import LogitsProcessor
49
50

# yapf: enable
if TYPE_CHECKING:
    # Imported only for static type checkers; these names are used in type
    # annotations in this module (e.g. on EngineArgs fields).
    from vllm.executor.executor_base import ExecutorBase
    from vllm.model_executor.layers.quantization import QuantizationMethods
    from vllm.model_executor.model_loader import LoadFormats
    from vllm.usage.usage_lib import UsageContext
else:
    # At runtime the names only have to exist so that annotations evaluate;
    # alias them to Any (presumably to avoid importing these modules eagerly
    # or creating import cycles -- not verifiable from this file).
    ExecutorBase = Any
    QuantizationMethods = Any
    LoadFormats = Any
    UsageContext = Any

# Logger shared by every helper in this module.
logger = init_logger(__name__)

# Type-hint aliases for the argparse helpers below. `object` is part of the
# unions so that special typing forms (Literal[...], Union[...], Annotated[...])
# -- which are not `type` instances -- are also accepted.
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]

def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
    """Wrap a string converter so that failures surface as argparse errors.

    The returned callable is meant for argparse's ``type=`` parameter. It
    calls ``return_type(val)`` and converts a ``ValueError`` into an
    ``argparse.ArgumentTypeError`` chained to the original exception, so
    argparse reports a usage error instead of a traceback.
    """

    def _parse_type(val: str) -> T:
        try:
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}.") from e

    return _parse_type


def optional_type(
        return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Like ``parse_type``, but treat ``""`` and ``"None"`` as ``None``.

    Any other string is converted with ``parse_type(return_type)``, so
    conversion failures raise ``argparse.ArgumentTypeError``.
    """

    def _optional_type(val: str) -> Optional[T]:
        if val in ("", "None"):
            return None
        return parse_type(return_type)(val)

    return _optional_type
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
    """Decode ``val`` as JSON if it looks like an object, else keep the string.

    A value wrapped in ``{...}`` (surrounding whitespace ignored, newlines
    allowed inside) is decoded with ``optional_type(json.loads)``, so invalid
    JSON raises ``argparse.ArgumentTypeError``. Any other value is returned
    unchanged as a string.
    """
    if not re.match(r"(?s)^\s*{.*}\s*$", val):
        return str(val)
    return optional_type(json.loads)(val)
98
99


100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Check if the type hint is a specific type."""
    return type_hint is type or get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Check if the type hints contain a specific type."""
    return any(is_type(type_hint, type) for type_hint in type_hints)


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Get the specific type from the type hints."""
    return next((th for th in type_hints if is_type(th, type)), None)


def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
    """Get the `type` and `choices` from a `Literal` type hint in `type_hints`.

    If `type_hints` also contains `str`, `metavar` is used instead of
    `choices`, so argparse does not restrict the value to the literal set.

    Raises:
        ValueError: if no `Literal` options are found, or if the options
            are not all of the same type.
    """
    type_hint = get_type(type_hints, Literal)
    options = get_args(type_hint)
    if not options:
        # Previously surfaced as an opaque IndexError on options[0].
        raise ValueError(f"No Literal options found in {type_hints}.")
    option_type = type(options[0])
    if not all(isinstance(option, option_type) for option in options):
        raise ValueError(
            "All options must be of the same type. "
            f"Got {options} with types {[type(c) for c in options]}")
    kwarg = "metavar" if contains_type(type_hints, str) else "choices"
    return {"type": option_type, kwarg: sorted(options)}
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True unless ``type_hint`` is defined in the ``builtins`` module."""
    defining_module = type_hint.__module__
    return defining_module != "builtins"


136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
    """Extract type hints from Annotated or Union type hints."""
    type_hints: set[TypeHint] = set()
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Annotated:
        type_hints.update(get_type_hints(args[0]))
    elif origin is Union:
        for arg in args:
            type_hints.update(get_type_hints(arg))
    else:
        type_hints.add(type_hint)

    return type_hints


def is_online_quantization(quantization: Any) -> bool:
    """Return True if ``quantization`` is a method treated as online here.

    Currently only ``"inc"`` qualifies.
    """
    online_methods = ("inc",)
    return quantization in online_methods


def _needs_help(argv: list[str]) -> bool:
    """Return True when the command line asks for help text or docs.

    Recognised forms:
      * ``vllm SUBCOMMAND --help``
      * ``mkdocs SUBCOMMAND``
      * ``python -m mkdocs SUBCOMMAND`` (argv[0] ends with
        ``mkdocs/__main__.py``)

    An empty ``argv`` returns False instead of raising IndexError.
    """
    if "--help" in argv:
        return True
    if not argv:
        return False
    argv0 = argv[0]
    return argv0.endswith("mkdocs") or argv0.endswith("mkdocs/__main__.py")


# True when help text / docs are being generated. Used to skip the
# attribute-doc extraction (a time cost) in ordinary runs.
NEEDS_HELP = _needs_help(getattr(sys, "argv", []))


@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Build argparse ``add_argument`` kwargs for every field of ``cls``.

    The result maps each field name to a dict holding ``default`` and
    ``help``, plus type-dependent keys (``type``, ``action``,
    ``choices``/``metavar``, ``nargs``). It is memoised by
    ``functools.lru_cache``, so it must not be mutated in place; use
    ``get_kwargs``, which returns a deep copy.

    Raises:
        ValueError: if a field's type hint cannot be mapped to an argparse
            argument.
    """
    # Save time only getting attr docs if we're generating help text
    cls_docs = get_attr_docs(cls) if NEEDS_HELP else {}

    # Loop-invariant values, built once instead of once per field.
    json_tip = ("Should either be a valid JSON string or JSON keys passed "
                "individually.")
    # Integer fields whose CLI value is parsed by `human_readable_int`.
    human_readable_ints = {
        "max_model_len",
        "max_num_batched_tokens",
        "kv_cache_memory_bytes",
    }

    kwargs = {}
    for field in fields(cls):
        # Get the set of possible types for the field
        type_hints: set[TypeHint] = get_type_hints(field.type)

        # If the field is a dataclass, its CLI value is parsed from JSON
        # with pydantic's TypeAdapter (see parse_dataclass below)
        generator = (th for th in type_hints if is_dataclass(th))
        dataclass_cls = next(generator, None)

        # Get the default value of the field
        if field.default is not MISSING:
            default = field.default
        elif field.default_factory is not MISSING:
            default = field.default_factory()
        else:
            # No default declared. Use argparse's own default (None) rather
            # than silently reusing the previous field's value (or raising
            # UnboundLocalError on the first field).
            default = None

        # Get the help text for the field, escaping % for argparse
        name = field.name
        help_text = cls_docs.get(name, "").strip().replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        kwargs[name] = {"default": default, "help": help_text}

        # Set other kwargs based on the type hints
        if dataclass_cls is not None:

            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
                try:
                    return TypeAdapter(cls).validate_json(val)
                except ValidationError as e:
                    raise argparse.ArgumentTypeError(repr(e)) from e

            kwargs[name]["type"] = parse_dataclass
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            kwargs[name]["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            kwargs[name].update(literal_to_kwargs(type_hints))
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            types = get_args(type_hint)
            tuple_type = types[0]
            assert all(t is tuple_type for t in types if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {types}.")
            kwargs[name]["type"] = tuple_type
            kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            types = get_args(type_hint)
            list_type = types[0]
            if get_origin(list_type) is Union:
                msg = "List type must contain str if it is a Union."
                assert str in get_args(list_type), msg
                list_type = str
            kwargs[name]["type"] = list_type
            kwargs[name]["nargs"] = "+"
        elif contains_type(type_hints, int):
            kwargs[name]["type"] = int
            # Special case for large integers
            if name in human_readable_ints:
                kwargs[name]["type"] = human_readable_int
                kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
        elif (contains_type(type_hints, dict)
              and (contains_type(type_hints, str)
                   or any(is_not_builtin(th) for th in type_hints))):
            kwargs[name]["type"] = union_dict_and_str
        elif contains_type(type_hints, dict):
            kwargs[name]["type"] = parse_type(json.loads)
            kwargs[name]["help"] += f"\n\n{json_tip}"
        elif (contains_type(type_hints, str)
              or any(is_not_builtin(th) for th in type_hints)):
            kwargs[name]["type"] = str
        else:
            raise ValueError(
                f"Unsupported type {type_hints} for argument {name}.")

        # If the type hint was a sequence of literals, use the helper function
        # to update the type and choices
        if get_origin(kwargs[name].get("type")) is Literal:
            kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            kwargs[name]["type"] = optional_type(kwargs[name]["type"])
            if kwargs[name].get("choices"):
                # NOTE(review): argparse checks `choices` after `type`
                # conversion, and optional_type turns "None" into None, so
                # this string entry may never match. Confirm whether
                # `--<arg> None` is accepted for Optional[Literal[...]].
                kwargs[name]["choices"].append("None")
    return kwargs
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Return argparse kwargs for the given Config dataclass.

    Attribute documentation is only collected for the help text when the
    process was launched with ``--help`` or through ``mkdocs``. Otherwise
    that documentation is left out of each argument's help.

    The expensive work is done by the ``lru_cache``-decorated
    ``_compute_kwargs``. This wrapper returns an independent deep copy, so
    callers can mutate the result without corrupting the cached entry.
    """
    cached_kwargs = _compute_kwargs(cls)
    return copy.deepcopy(cached_kwargs)


283
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
284
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
285
    """Arguments for vLLM engine.

    Most defaults mirror the matching attribute of a config class (e.g.
    ``ModelConfig.model``, ``CacheConfig.block_size``) or are taken from a
    config field via ``get_field``.
    """
    # NOTE: field order defines the positional-argument order of the
    # dataclass-generated __init__; append new fields instead of reordering.
    model: str = ModelConfig.model
    served_model_name: Optional[Union[
        str, List[str]]] = ModelConfig.served_model_name
    tokenizer: Optional[str] = ModelConfig.tokenizer
    hf_config_path: Optional[str] = ModelConfig.hf_config_path
    runner: RunnerOption = ModelConfig.runner
    convert: ConvertOption = ModelConfig.convert
    task: Optional[TaskOption] = ModelConfig.task
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
    download_dir: Optional[str] = LoadConfig.download_dir
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
    load_format: Union[str, LoadFormats] = LoadConfig.load_format
    config_format: str = ModelConfig.config_format
    dtype: ModelDType = ModelConfig.dtype
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
    seed: Optional[int] = ModelConfig.seed
    max_model_len: Optional[int] = ModelConfig.max_model_len
    cuda_graph_sizes: list[int] = get_field(SchedulerConfig,
                                            "cuda_graph_sizes")
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[
        str, DistributedExecutorBackend,
        Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
    # number of P/D disaggregation (or other disaggregation) workers
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
    decode_context_parallel_size: int = \
        ParallelConfig.decode_context_parallel_size
    data_parallel_size: int = ParallelConfig.data_parallel_size
    data_parallel_rank: Optional[int] = None
    data_parallel_start_rank: Optional[int] = None
    data_parallel_size_local: Optional[int] = None
    data_parallel_address: Optional[str] = None
    data_parallel_rpc_port: Optional[int] = None
    data_parallel_hybrid_lb: bool = False
    data_parallel_backend: str = ParallelConfig.data_parallel_backend
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
    enable_dbo: bool = ParallelConfig.enable_dbo
    dbo_decode_token_threshold: int = \
        ParallelConfig.dbo_decode_token_threshold
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
    enable_eplb: bool = ParallelConfig.enable_eplb
    expert_placement_strategy: ExpertPlacementStrategy = \
        ParallelConfig.expert_placement_strategy
    _api_process_count: int = ParallelConfig._api_process_count
    _api_process_rank: int = ParallelConfig._api_process_rank
    num_redundant_experts: int = EPLBConfig.num_redundant_experts
    eplb_window_size: int = EPLBConfig.window_size
    eplb_step_interval: int = EPLBConfig.step_interval
    eplb_log_balancedness: bool = EPLBConfig.log_balancedness
    max_parallel_loading_workers: Optional[
        int] = ParallelConfig.max_parallel_loading_workers
    block_size: Optional[BlockSize] = CacheConfig.block_size
    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
    prefix_caching_hash_algo: PrefixCachingHashAlgo = \
        CacheConfig.prefix_caching_hash_algo
    disable_sliding_window: bool = ModelConfig.disable_sliding_window
    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
    kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes
    max_num_batched_tokens: Optional[
        int] = SchedulerConfig.max_num_batched_tokens
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
    long_prefill_token_threshold: int = \
        SchedulerConfig.long_prefill_token_threshold
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
    max_logprobs: int = ModelConfig.max_logprobs
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
    disable_log_stats: bool = False
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
    rope_theta: Optional[float] = ModelConfig.rope_theta
    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
    enforce_eager: bool = ModelConfig.enforce_eager
    max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
    limit_mm_per_prompt: dict[str, int] = \
        get_field(MultiModalConfig, "limit_per_prompt")
    interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
    media_io_kwargs: dict[str, dict[str,
                                    Any]] = get_field(MultiModalConfig,
                                                      "media_io_kwargs")
    mm_processor_kwargs: Optional[Dict[str, Any]] = \
        MultiModalConfig.mm_processor_kwargs
    disable_mm_preprocessor_cache: bool = False  # DEPRECATED
    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
    mm_processor_cache_type: Optional[MMCacheType] = \
        MultiModalConfig.mm_processor_cache_type
    mm_shm_cache_max_object_size_mb: int = \
        MultiModalConfig.mm_shm_cache_max_object_size_mb
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
    io_processor_plugin: Optional[str] = None
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
    # LoRA fields
    enable_lora: bool = False
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
    default_mm_loras: Optional[Dict[str, str]] = \
        LoRAConfig.default_mm_loras
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size

    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
    num_gpu_blocks_override: Optional[
        int] = CacheConfig.num_gpu_blocks_override
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
    model_loader_extra_config: dict = \
        get_field(LoadConfig, "model_loader_extra_config")
    ignore_patterns: Optional[Union[str,
                                    List[str]]] = LoadConfig.ignore_patterns
    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode

    scheduler_delay_factor: float = SchedulerConfig.delay_factor
    enable_chunked_prefill: Optional[
        bool] = SchedulerConfig.enable_chunked_prefill
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input

    disable_hybrid_kv_cache_manager: bool = (
        SchedulerConfig.disable_hybrid_kv_cache_manager)

    structured_outputs_config: StructuredOutputsConfig = get_field(
        VllmConfig, "structured_outputs_config")
    reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
    # Deprecated guided decoding fields
    guided_decoding_backend: Optional[str] = None
    guided_decoding_disable_fallback: Optional[bool] = None
    guided_decoding_disable_any_whitespace: Optional[bool] = None
    guided_decoding_disable_additional_properties: Optional[bool] = None

    logits_processor_pattern: Optional[
        str] = ModelConfig.logits_processor_pattern

    speculative_config: Optional[Dict[str, Any]] = None

    show_hidden_metrics_for_version: Optional[str] = \
        ObservabilityConfig.show_hidden_metrics_for_version
    otlp_traces_endpoint: Optional[str] = \
        ObservabilityConfig.otlp_traces_endpoint
    collect_detailed_traces: Optional[list[DetailedTraceModules]] = \
        ObservabilityConfig.collect_detailed_traces
    # Inverse of ModelConfig.use_async_output_proc.
    disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls

    pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config
    override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
        ModelConfig.override_pooler_config
    compilation_config: CompilationConfig = \
        get_field(VllmConfig, "compilation_config")
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls

    kv_transfer_config: Optional[KVTransferConfig] = None
    kv_events_config: Optional[KVEventsConfig] = None

    generation_config: str = ModelConfig.generation_config
    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
    override_generation_config: dict[str, Any] = \
        get_field(ModelConfig, "override_generation_config")
    model_impl: str = ModelConfig.model_impl
    override_attention_dtype: str = ModelConfig.override_attention_dtype

    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype

    additional_config: dict[str, Any] = \
        get_field(VllmConfig, "additional_config")

    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
    pt_load_map_location: str = LoadConfig.pt_load_map_location

    # DEPRECATED
    enable_multimodal_encoder_data_parallel: bool = False

    logits_processors: Optional[list[Union[
        str, type[LogitsProcessor]]]] = ModelConfig.logits_processors
    """Custom logitproc types"""

    async_scheduling: bool = SchedulerConfig.async_scheduling

    kv_sharing_fast_prefill: bool = \
        CacheConfig.kv_sharing_fast_prefill
    def __post_init__(self):
        """Normalise dict-valued config fields and run environment setup.

        - Converts a plain-dict ``compilation_config`` / ``eplb_config`` into
          ``CompilationConfig`` / ``EPLBConfig`` respectively.
        - Loads general plugins via ``load_general_plugins()``.
        - When ``huggingface_hub.constants.HF_HUB_OFFLINE`` is set, replaces
          ``self.model`` with ``get_model_path(self.model, self.revision)``
          and logs the substitution.
        """
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig(
                **self.compilation_config)
        if isinstance(self.eplb_config, dict):
            self.eplb_config = EPLBConfig(**self.eplb_config)
        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
        # When running with HF offline mode, replace the model id with the
        # local model path.
        if huggingface_hub.constants.HF_HUB_OFFLINE:
            model_id = self.model
            self.model = get_model_path(self.model, self.revision)
            logger.info(
                "HF_HUB_OFFLINE is True, replace model_id [%s] "
                "to model_path [%s]", model_id, self.model)

    @staticmethod
507
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
508
        """Shared CLI arguments for vLLM engine."""
509

510
        # Model arguments
511
512
513
514
515
        model_kwargs = get_kwargs(ModelConfig)
        model_group = parser.add_argument_group(
            title="ModelConfig",
            description=ModelConfig.__doc__,
        )
Reid's avatar
Reid committed
516
        if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]):
517
            model_group.add_argument("--model", **model_kwargs["model"])
518
519
520
521
522
        model_group.add_argument("--runner", **model_kwargs["runner"])
        model_group.add_argument("--convert", **model_kwargs["convert"])
        model_group.add_argument("--task",
                                 **model_kwargs["task"],
                                 deprecated=True)
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
        model_group.add_argument("--tokenizer-mode",
                                 **model_kwargs["tokenizer_mode"])
        model_group.add_argument("--trust-remote-code",
                                 **model_kwargs["trust_remote_code"])
        model_group.add_argument("--dtype", **model_kwargs["dtype"])
        model_group.add_argument("--seed", **model_kwargs["seed"])
        model_group.add_argument("--hf-config-path",
                                 **model_kwargs["hf_config_path"])
        model_group.add_argument("--allowed-local-media-path",
                                 **model_kwargs["allowed_local_media_path"])
        model_group.add_argument("--revision", **model_kwargs["revision"])
        model_group.add_argument("--code-revision",
                                 **model_kwargs["code_revision"])
        model_group.add_argument("--rope-scaling",
                                 **model_kwargs["rope_scaling"])
        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
        model_group.add_argument("--tokenizer-revision",
                                 **model_kwargs["tokenizer_revision"])
        model_group.add_argument("--max-model-len",
                                 **model_kwargs["max_model_len"])
        model_group.add_argument("--quantization", "-q",
                                 **model_kwargs["quantization"])
        model_group.add_argument("--enforce-eager",
                                 **model_kwargs["enforce_eager"])
        model_group.add_argument("--max-seq-len-to-capture",
                                 **model_kwargs["max_seq_len_to_capture"])
        model_group.add_argument("--max-logprobs",
                                 **model_kwargs["max_logprobs"])
552
553
        model_group.add_argument("--logprobs-mode",
                                 **model_kwargs["logprobs_mode"])
554
555
556
557
558
559
        model_group.add_argument("--disable-sliding-window",
                                 **model_kwargs["disable_sliding_window"])
        model_group.add_argument("--disable-cascade-attn",
                                 **model_kwargs["disable_cascade_attn"])
        model_group.add_argument("--skip-tokenizer-init",
                                 **model_kwargs["skip_tokenizer_init"])
560
561
        model_group.add_argument("--enable-prompt-embeds",
                                 **model_kwargs["enable_prompt_embeds"])
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
        model_group.add_argument("--served-model-name",
                                 **model_kwargs["served_model_name"])
        # This one is a special case because it is the
        # opposite of ModelConfig.use_async_output_proc
        model_group.add_argument(
            "--disable-async-output-proc",
            action="store_true",
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
        model_group.add_argument("--config-format",
                                 **model_kwargs["config_format"])
        # This one is a special case because it can bool
        # or str. TODO: Handle this in get_kwargs
        model_group.add_argument("--hf-token",
                                 type=str,
                                 nargs="?",
                                 const=True,
                                 default=model_kwargs["hf_token"]["default"],
                                 help=model_kwargs["hf_token"]["help"])
        model_group.add_argument("--hf-overrides",
                                 **model_kwargs["hf_overrides"])
584
585
        model_group.add_argument("--pooler-config",
                                 **model_kwargs["pooler_config"])
586
        model_group.add_argument("--override-pooler-config",
587
588
                                 **model_kwargs["override_pooler_config"],
                                 deprecated=True)
589
590
591
592
593
594
595
596
        model_group.add_argument("--logits-processor-pattern",
                                 **model_kwargs["logits_processor_pattern"])
        model_group.add_argument("--generation-config",
                                 **model_kwargs["generation_config"])
        model_group.add_argument("--override-generation-config",
                                 **model_kwargs["override_generation_config"])
        model_group.add_argument("--enable-sleep-mode",
                                 **model_kwargs["enable_sleep_mode"])
597
        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
598
599
        model_group.add_argument("--override-attention-dtype",
                                 **model_kwargs["override_attention_dtype"])
600
601
        model_group.add_argument("--logits-processors",
                                 **model_kwargs["logits_processors"])
602
603
        model_group.add_argument("--io-processor-plugin",
                                 **model_kwargs["io_processor_plugin"])
604

605
606
607
608
609
610
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
611
        load_group.add_argument("--load-format", **load_kwargs["load_format"])
612
        load_group.add_argument("--download-dir",
613
                                **load_kwargs["download_dir"])
614
615
        load_group.add_argument("--safetensors-load-strategy",
                                **load_kwargs["safetensors_load_strategy"])
616
        load_group.add_argument("--model-loader-extra-config",
617
                                **load_kwargs["model_loader_extra_config"])
618
619
620
        load_group.add_argument("--ignore-patterns",
                                **load_kwargs["ignore_patterns"])
        load_group.add_argument("--use-tqdm-on-load",
621
                                **load_kwargs["use_tqdm_on_load"])
622
623
        load_group.add_argument('--pt-load-map-location',
                                **load_kwargs["pt_load_map_location"])
624

625
626
627
628
629
        # Structured outputs arguments
        structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
        structured_outputs_group = parser.add_argument_group(
            title="StructuredOutputsConfig",
            description=StructuredOutputsConfig.__doc__,
630
        )
631
        structured_outputs_group.add_argument(
632
            "--reasoning-parser",
633
            # This choice is a special case because it's not static
634
            choices=list(ReasoningParserManager.reasoning_parsers),
635
636
637
638
639
640
641
642
643
644
645
646
647
            **structured_outputs_kwargs["reasoning_parser"])
        # Deprecated guided decoding arguments
        for arg, type in [
            ("--guided-decoding-backend", str),
            ("--guided-decoding-disable-fallback", bool),
            ("--guided-decoding-disable-any-whitespace", bool),
            ("--guided-decoding-disable-additional-properties", bool),
        ]:
            structured_outputs_group.add_argument(
                arg,
                type=type,
                help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
                deprecated=True)
648

649
        # Parallel arguments
650
651
652
653
654
655
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
656
            "--distributed-executor-backend",
657
658
            **parallel_kwargs["distributed_executor_backend"])
        parallel_group.add_argument(
659
            "--pipeline-parallel-size", "-pp",
660
            **parallel_kwargs["pipeline_parallel_size"])
661
        parallel_group.add_argument("--tensor-parallel-size", "-tp",
662
                                    **parallel_kwargs["tensor_parallel_size"])
663
664
665
        parallel_group.add_argument(
            "--decode-context-parallel-size", "-dcp",
            **parallel_kwargs["decode_context_parallel_size"])
666
        parallel_group.add_argument("--data-parallel-size", "-dp",
667
                                    **parallel_kwargs["data_parallel_size"])
668
669
670
671
672
673
        parallel_group.add_argument(
            '--data-parallel-rank',
            '-dpn',
            type=int,
            help='Data parallel rank of this instance. '
            'When set, enables external load balancer mode.')
674
675
676
677
678
        parallel_group.add_argument('--data-parallel-start-rank',
                                    '-dpr',
                                    type=int,
                                    help='Starting data parallel rank '
                                    'for secondary nodes.')
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
        parallel_group.add_argument('--data-parallel-size-local',
                                    '-dpl',
                                    type=int,
                                    help='Number of data parallel replicas '
                                    'to run on this node.')
        parallel_group.add_argument('--data-parallel-address',
                                    '-dpa',
                                    type=str,
                                    help='Address of data parallel cluster '
                                    'head-node.')
        parallel_group.add_argument('--data-parallel-rpc-port',
                                    '-dpp',
                                    type=int,
                                    help='Port for data parallel RPC '
                                    'communication.')
Rui Qiao's avatar
Rui Qiao committed
694
695
696
697
698
699
        parallel_group.add_argument('--data-parallel-backend',
                                    '-dpb',
                                    type=str,
                                    default='mp',
                                    help='Backend for data parallel, either '
                                    '"mp" or "ray".')
700
701
702
        parallel_group.add_argument(
            "--data-parallel-hybrid-lb",
            **parallel_kwargs["data_parallel_hybrid_lb"])
703
        parallel_group.add_argument(
704
            "--enable-expert-parallel",
705
            **parallel_kwargs["enable_expert_parallel"])
706
707
708
709
710
        parallel_group.add_argument("--enable-dbo",
                                    **parallel_kwargs["enable_dbo"])
        parallel_group.add_argument(
            "--dbo-decode-token-threshold",
            **parallel_kwargs["dbo_decode_token_threshold"])
711
712
        parallel_group.add_argument("--enable-eplb",
                                    **parallel_kwargs["enable_eplb"])
713
714
        parallel_group.add_argument("--eplb-config",
                                    **parallel_kwargs["eplb_config"])
715
716
717
        parallel_group.add_argument(
            "--expert-placement-strategy",
            **parallel_kwargs["expert_placement_strategy"])
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
        parallel_group.add_argument(
            "--num-redundant-experts",
            type=int,
            help=
            "[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
            deprecated=True)
        parallel_group.add_argument(
            "--eplb-window-size",
            type=int,
            help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
            deprecated=True)
        parallel_group.add_argument(
            "--eplb-step-interval",
            type=int,
            help=
            "[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
            deprecated=True)
        parallel_group.add_argument(
            "--eplb-log-balancedness",
            action=argparse.BooleanOptionalAction,
            help=
            "[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
            deprecated=True)

742
        parallel_group.add_argument(
743
            "--max-parallel-loading-workers",
744
745
            **parallel_kwargs["max_parallel_loading_workers"])
        parallel_group.add_argument(
746
            "--ray-workers-use-nsight",
747
748
            **parallel_kwargs["ray_workers_use_nsight"])
        parallel_group.add_argument(
749
            "--disable-custom-all-reduce",
750
            **parallel_kwargs["disable_custom_all_reduce"])
751
752
753
754
        parallel_group.add_argument("--worker-cls",
                                    **parallel_kwargs["worker_cls"])
        parallel_group.add_argument("--worker-extension-cls",
                                    **parallel_kwargs["worker_extension_cls"])
755
756
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
757
758
            action="store_true",
            deprecated=True)
759

760
761
762
763
764
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
765
        )
766
767
        cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
        cache_group.add_argument("--gpu-memory-utilization",
768
                                 **cache_kwargs["gpu_memory_utilization"])
769
770
        cache_group.add_argument("--kv-cache-memory-bytes",
                                 **cache_kwargs["kv_cache_memory_bytes"])
771
772
        cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
        cache_group.add_argument("--kv-cache-dtype",
773
                                 **cache_kwargs["cache_dtype"])
774
        cache_group.add_argument("--num-gpu-blocks-override",
775
776
777
778
779
                                 **cache_kwargs["num_gpu_blocks_override"])
        cache_group.add_argument("--enable-prefix-caching",
                                 **cache_kwargs["enable_prefix_caching"])
        cache_group.add_argument("--prefix-caching-hash-algo",
                                 **cache_kwargs["prefix_caching_hash_algo"])
780
        cache_group.add_argument("--cpu-offload-gb",
781
                                 **cache_kwargs["cpu_offload_gb"])
782
        cache_group.add_argument("--calculate-kv-scales",
783
                                 **cache_kwargs["calculate_kv_scales"])
784
785
        cache_group.add_argument("--kv-sharing-fast-prefill",
                                 **cache_kwargs["kv_sharing_fast_prefill"])
786
787
788
789
        cache_group.add_argument("--mamba-cache-dtype",
                                 **cache_kwargs["mamba_cache_dtype"])
        cache_group.add_argument("--mamba-ssm-cache-dtype",
                                 **cache_kwargs["mamba_ssm_cache_dtype"])
790

791
        # Multimodal related configs
792
793
794
795
796
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
797
        multimodal_group.add_argument("--limit-mm-per-prompt",
798
                                      **multimodal_kwargs["limit_per_prompt"])
799
800
        multimodal_group.add_argument("--media-io-kwargs",
                                      **multimodal_kwargs["media_io_kwargs"])
801
        multimodal_group.add_argument(
802
            "--mm-processor-kwargs",
803
804
            **multimodal_kwargs["mm_processor_kwargs"])
        multimodal_group.add_argument(
805
806
807
            "--mm-processor-cache-gb",
            **multimodal_kwargs["mm_processor_cache_gb"])
        multimodal_group.add_argument("--disable-mm-preprocessor-cache",
808
                                      action="store_true",
809
                                      deprecated=True)
810
811
812
813
814
815
        multimodal_group.add_argument(
            "--mm-processor-cache-type",
            **multimodal_kwargs["mm_processor_cache_type"])
        multimodal_group.add_argument(
            "--mm-shm-cache-max-object-size-mb",
            **multimodal_kwargs["mm_shm_cache_max_object_size_mb"])
816
817
        multimodal_group.add_argument(
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"])
818
819
820
        multimodal_group.add_argument(
            "--interleave-mm-strings",
            **multimodal_kwargs["interleave_mm_strings"])
821
822
        multimodal_group.add_argument("--skip-mm-profiling",
                                      **multimodal_kwargs["skip_mm_profiling"])
823

824
        # LoRA related configs
825
826
827
828
829
830
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
831
            "--enable-lora",
832
            action=argparse.BooleanOptionalAction,
833
834
            help="If True, enable handling of LoRA adapters.")
        lora_group.add_argument("--enable-lora-bias",
835
                                **lora_kwargs["bias_enabled"])
836
837
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
        lora_group.add_argument("--max-lora-rank",
838
                                **lora_kwargs["max_lora_rank"])
839
        lora_group.add_argument("--lora-extra-vocab-size",
840
841
                                **lora_kwargs["lora_extra_vocab_size"])
        lora_group.add_argument(
842
            "--lora-dtype",
843
844
            **lora_kwargs["lora_dtype"],
        )
845
        lora_group.add_argument("--max-cpu-loras",
846
                                **lora_kwargs["max_cpu_loras"])
847
        lora_group.add_argument("--fully-sharded-loras",
848
                                **lora_kwargs["fully_sharded_loras"])
849
850
        lora_group.add_argument("--default-mm-loras",
                                **lora_kwargs["default_mm_loras"])
851

852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
            title="ObservabilityConfig",
            description=ObservabilityConfig.__doc__,
        )
        observability_group.add_argument(
            "--show-hidden-metrics-for-version",
            **observability_kwargs["show_hidden_metrics_for_version"])
        observability_group.add_argument(
            "--otlp-traces-endpoint",
            **observability_kwargs["otlp_traces_endpoint"])
        # TODO: generalise this special case
        choices = observability_kwargs["collect_detailed_traces"]["choices"]
        metavar = f"{{{','.join(choices)}}}"
        observability_kwargs["collect_detailed_traces"]["metavar"] = metavar
        observability_kwargs["collect_detailed_traces"]["choices"] += [
            ",".join(p)
            for p in permutations(get_args(DetailedTraceModules), r=2)
        ]
        observability_group.add_argument(
            "--collect-detailed-traces",
            **observability_kwargs["collect_detailed_traces"])
875

876
877
878
879
880
881
882
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
883
            "--max-num-batched-tokens",
884
            **scheduler_kwargs["max_num_batched_tokens"])
885
        scheduler_group.add_argument("--max-num-seqs",
886
887
888
889
890
891
892
                                     **scheduler_kwargs["max_num_seqs"])
        scheduler_group.add_argument(
            "--max-num-partial-prefills",
            **scheduler_kwargs["max_num_partial_prefills"])
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
            **scheduler_kwargs["max_long_partial_prefills"])
893
894
        scheduler_group.add_argument('--cuda-graph-sizes',
                                     **scheduler_kwargs["cuda_graph_sizes"])
895
896
897
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
            **scheduler_kwargs["long_prefill_token_threshold"])
898
        scheduler_group.add_argument("--num-lookahead-slots",
899
                                     **scheduler_kwargs["num_lookahead_slots"])
900
        scheduler_group.add_argument("--scheduler-delay-factor",
901
                                     **scheduler_kwargs["delay_factor"])
902
        scheduler_group.add_argument("--preemption-mode",
903
                                     **scheduler_kwargs["preemption_mode"])
904
905
        # multi-step scheduling has been removed; corresponding arguments
        # are no longer supported.
906
        scheduler_group.add_argument("--scheduling-policy",
907
                                     **scheduler_kwargs["policy"])
908
        scheduler_group.add_argument(
909
            "--enable-chunked-prefill",
910
            **scheduler_kwargs["enable_chunked_prefill"])
911
912
913
        scheduler_group.add_argument(
            "--disable-chunked-mm-input",
            **scheduler_kwargs["disable_chunked_mm_input"])
914
915
        scheduler_group.add_argument("--scheduler-cls",
                                     **scheduler_kwargs["scheduler_cls"])
916
917
918
        scheduler_group.add_argument(
            "--disable-hybrid-kv-cache-manager",
            **scheduler_kwargs["disable_hybrid_kv_cache_manager"])
919
920
        scheduler_group.add_argument("--async-scheduling",
                                     **scheduler_kwargs["async_scheduling"])
921
922

        # vLLM arguments
923
        vllm_kwargs = get_kwargs(VllmConfig)
924
925
926
927
        vllm_group = parser.add_argument_group(
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )
928
929
930
931
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.
        vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads)
932
933
        vllm_group.add_argument("--speculative-config",
                                **vllm_kwargs["speculative_config"])
934
935
936
937
938
939
940
941
        vllm_group.add_argument("--kv-transfer-config",
                                **vllm_kwargs["kv_transfer_config"])
        vllm_group.add_argument('--kv-events-config',
                                **vllm_kwargs["kv_events_config"])
        vllm_group.add_argument("--compilation-config", "-O",
                                **vllm_kwargs["compilation_config"])
        vllm_group.add_argument("--additional-config",
                                **vllm_kwargs["additional_config"])
942
943
        vllm_group.add_argument('--structured-outputs-config',
                                **vllm_kwargs["structured_outputs_config"])
944

945
946
947
948
        # Other arguments
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')
949

950
        return parser
951
952

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Instantiate ``cls`` from an ``argparse.Namespace``.

        Every dataclass field of ``cls`` that also exists as an attribute on
        ``args`` is forwarded to the constructor under the same name.
        Namespace attributes that are not fields are ignored, and fields that
        are absent from the namespace are left to the constructor's defaults.
        """
        init_kwargs = {}
        for field in dataclasses.fields(cls):
            name = field.name
            if not hasattr(args, name):
                continue
            init_kwargs[name] = getattr(args, name)
        return cls(**init_kwargs)
962

963
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` for these engine arguments.

        Before constructing the config, a few fields of ``self`` are
        normalized in place:

        * A GGUF checkpoint (per ``check_gguf_file``) forces both
          ``quantization`` and ``load_format`` to ``"gguf"``.
        * In CI (``envs.VLLM_CI_USE_S3``), non-async engine args whose model
          is listed in ``MODELS_ON_S3`` and whose ``load_format`` is
          ``"auto"`` are redirected to ``MODEL_WEIGHTS_S3_BUCKET``.
        * Deprecated multimodal options are translated into their
          replacements, each with a deprecation warning:
          ``--disable-mm-preprocessor-cache`` and
          ``VLLM_MM_INPUT_CACHE_GIB`` map onto ``mm_processor_cache_gb``;
          ``--enable-multimodal-encoder-data-parallel`` maps onto
          ``mm_encoder_tp_mode = "data"``.

        Returns:
            A new ``ModelConfig`` built from the (possibly updated) fields
            of ``self``.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
                and self.model in MODELS_ON_S3 and self.load_format == "auto"):
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"

        if self.disable_mm_preprocessor_cache:
            logger.warning(
                "`--disable-mm-preprocessor-cache` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb 0` instead.", )

            self.mm_processor_cache_gb = 0
        # NOTE(review): 4 is presumably the env var's default, so any other
        # value is treated as an explicit (deprecated) user override.
        elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
            logger.warning(
                "`VLLM_MM_INPUT_CACHE_GIB` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-processor-cache-gb %d` instead.",
                envs.VLLM_MM_INPUT_CACHE_GIB,
            )

            self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB

        if self.enable_multimodal_encoder_data_parallel:
            logger.warning(
                "`--enable-multimodal-encoder-data-parallel` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-encoder-tp-mode data` instead.")

            self.mm_encoder_tp_mode = "data"

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            runner=self.runner,
            convert=self.convert,
            task=self.task,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            interleave_mm_strings=self.interleave_mm_strings,
            media_io_kwargs=self.media_io_kwargs,
            skip_mm_profiling=self.skip_mm_profiling,
            # The CLI flag is a "disable" switch; ModelConfig takes "use".
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_processor_cache_type=self.mm_processor_cache_type,
            mm_shm_cache_max_object_size_mb=self.
            mm_shm_cache_max_object_size_mb,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            pooler_config=self.pooler_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
            io_processor_plugin=self.io_processor_plugin,
        )
1051

1052
1053
1054
1055
1056
1057
1058
    def validate_tensorizer_args(self):
        """Mirror top-level tensorizer options into the nested config.

        Each top-level key of ``model_loader_extra_config`` that names a
        ``TensorizerConfig`` field is copied into
        ``model_loader_extra_config["tensorizer_config"]``. That nested
        mapping is expected to exist already (``create_load_config`` sets it
        up before calling this). The top-level entries are left in place.
        """
        from vllm.model_executor.model_loader.tensorizer import (
            TensorizerConfig)
        extra_config = self.model_loader_extra_config
        known_fields = TensorizerConfig._fields
        for option in extra_config:
            if option not in known_fields:
                continue
            extra_config["tensorizer_config"][option] = extra_config[option]
1059

1060
1061
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` describing how model weights are loaded.

        Side effects on ``self``:

        * ``quantization == "bitsandbytes"`` forces
          ``load_format = "bitsandbytes"``.
        * For ``load_format == "tensorizer"``, ``model_loader_extra_config``
          is replaced by the result of its ``to_serializable()`` method
          (when it has one). It then receives a fresh ``"tensorizer_config"``
          entry whose ``"tensorizer_dir"`` is ``self.model``, and
          ``validate_tensorizer_args`` copies any top-level tensorizer keys
          into that entry.

        Returns:
            A new ``LoadConfig``. Its ``device`` is ``"cpu"`` only when
            ``is_online_quantization(self.quantization)`` holds, and
            ``None`` otherwise.
        """
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        if self.load_format == "tensorizer":
            extra_config = self.model_loader_extra_config
            if hasattr(extra_config, "to_serializable"):
                extra_config = extra_config.to_serializable()
                self.model_loader_extra_config = extra_config
            # Any pre-existing nested entry is discarded and rebuilt.
            extra_config["tensorizer_config"] = {"tensorizer_dir": self.model}
            self.validate_tensorizer_args()

        online_quant = is_online_quantization(self.quantization)
        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            safetensors_load_strategy=self.safetensors_load_strategy,
            device="cpu" if online_quant else None,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )
1085

1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Initialize and return a ``SpeculativeConfig``, or ``None``.

        The user-facing speculative options come from
        ``self.speculative_config``. It is a dict that is either passed in
        directly by the engine's caller or parsed from the JSON string given
        to ``--speculative-config``.

        When ``self.speculative_config`` is ``None``, the target model's HF
        config is loaded:

        * If it is a ``SpeculatorsConfig``, the number of lookahead tokens
          and the method are read from it, and the resulting dict is stored
          on ``self.speculative_config``.
        * Otherwise speculative decoding is disabled and ``None`` is
          returned.

        Engine-level objects are merged into a *copy* of the options rather
        than into ``self.speculative_config`` itself. The stored dict may be
        owned by the caller, and should not end up holding config objects.

        Args:
            target_model_config: Model config of the target (verifier) model.
            target_parallel_config: Parallel config of the target model.
            enable_chunked_prefill: Whether chunked prefill is enabled.
            disable_log_stats: Whether statistics logging is disabled.

        Returns:
            The constructed ``SpeculativeConfig``, or ``None`` when
            speculative decoding is not requested.
        """

        from vllm.transformers_utils.config import get_config
        from vllm.transformers_utils.configs.speculators.base import (
            SpeculatorsConfig)

        if self.speculative_config is None:
            hf_config = get_config(
                self.hf_config_path or target_model_config.model,
                self.trust_remote_code, self.revision, self.code_revision,
                self.config_format)

            # A speculators-format checkpoint carries its own speculative
            # decoding settings; no user input is required or expected.
            if not isinstance(hf_config, SpeculatorsConfig):
                return None
            self.speculative_config = {
                "num_speculative_tokens": hf_config.num_lookahead_tokens,
                "model": target_model_config.model,
                "method": hf_config.method,
            }

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config. They are merged into a copy so the user's dict stays clean.
        speculative_kwargs = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        }
        return SpeculativeConfig(**speculative_kwargs)
1135

1136
1137
1138
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
1139
        headless: bool = False,
1140
1141
1142
1143
1144
1145
1146
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1147

1148
1149
1150
1151
1152
1153
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1154
        current_platform.pre_register_and_update()
1155

1156
1157
        device_config = DeviceConfig(
            device=cast(Device, current_platform.device_type))
1158
1159
        model_config = self.create_model_config()

1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
1179
            self._set_default_args_v1(usage_context, model_config)
1180
            # Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1
1181
1182
            if current_platform.is_cpu(
            ) and current_platform.get_cpu_architecture() in (
1183
                    CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM):
1184
                logger.info(
1185
1186
                    "Chunked prefill is not supported for ARM and POWER "
                    "and S390X CPUs; "
1187
1188
                    "disabling it for V1 backend.")
                self.enable_chunked_prefill = False
1189
1190
        else:
            self._set_default_args_v0(model_config)
1191
1192
        assert self.enable_chunked_prefill is not None

1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
        if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
            assert self.enforce_eager, (
                "Cuda graph is not supported with DualChunkFlashAttention. "
                "To run the model in eager mode, set 'enforce_eager=True' "
                "or use '--enforce-eager' in the CLI.")
            assert current_platform.is_cuda(), (
                "DualChunkFlashAttention is only supported on CUDA platform.")
            assert not use_v1, (
                "DualChunkFlashAttention is not supported on V1 engine. "
                "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")

1204
1205
1206
1207
1208
1209
1210
        sliding_window: Optional[int] = None
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding
            # window. Otherwise CacheConfig.sliding_window will override the
            # global layers in interleaved sliding window models.
            sliding_window = model_config.get_sliding_window()

1211
1212
1213
        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
1214
        # reuses the GPUs of TP group, and split one TP group into
1215
1216
1217
1218
1219
1220
1221
        # tp_size//dcp_size DCP groups.
        assert self.tensor_parallel_size % self.decode_context_parallel_size \
            == 0, (
            f"tp_size={self.tensor_parallel_size} must be divisible by"
            f"dcp_size={self.decode_context_parallel_size}."
        )

1222
        cache_config = CacheConfig(
1223
            block_size=self.block_size,
1224
            gpu_memory_utilization=self.gpu_memory_utilization,
1225
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
1226
1227
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1228
            is_attention_free=model_config.is_attention_free,
1229
            num_gpu_blocks_override=self.num_gpu_blocks_override,
1230
            sliding_window=sliding_window,
1231
            enable_prefix_caching=self.enable_prefix_caching,
1232
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1233
            cpu_offload_gb=self.cpu_offload_gb,
1234
            calculate_kv_scales=self.calculate_kv_scales,
1235
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
1236
1237
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
1238
        )
1239

1240
1241
1242
1243
1244
1245
1246
1247
1248
        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
            ray_runtime_env = ray.get_runtime_context().runtime_env
            logger.info("Using ray runtime env: %s", ray_runtime_env)

1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1260
1261
1262
1263
        assert not headless or not self.data_parallel_hybrid_lb, (
            "data_parallel_hybrid_lb is not applicable in "
            "headless mode")

1264
        data_parallel_external_lb = self.data_parallel_rank is not None
1265
        # Local DP rank = 1, use pure-external LB.
1266
1267
1268
1269
1270
        if data_parallel_external_lb:
            assert self.data_parallel_size_local in (1, None), (
                "data_parallel_size_local must be 1 when data_parallel_rank "
                "is set")
            data_parallel_size_local = 1
1271
1272
            # Use full external lb if we have local_size of 1.
            self.data_parallel_hybrid_lb = False
1273
1274
        elif self.data_parallel_size_local is not None:
            data_parallel_size_local = self.data_parallel_size_local
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289

            if self.data_parallel_start_rank and not headless:
                # Infer hybrid LB mode.
                self.data_parallel_hybrid_lb = True

            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
                # Use full external lb if we have local_size of 1.
                data_parallel_external_lb = True
                self.data_parallel_hybrid_lb = False

            if data_parallel_size_local == self.data_parallel_size:
                # Disable hybrid LB mode if set for a single node
                self.data_parallel_hybrid_lb = False

            self.data_parallel_rank = self.data_parallel_start_rank or 0
1290
        else:
1291
1292
1293
1294
            assert not self.data_parallel_hybrid_lb, (
                "data_parallel_size_local must be set to use "
                "data_parallel_hybrid_lb.")

1295
1296
            # Local DP size defaults to global DP size if not set.
            data_parallel_size_local = self.data_parallel_size
1297
1298
1299

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
Rui Qiao's avatar
Rui Qiao committed
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
        if self.data_parallel_address is None:
            if self.data_parallel_backend == "ray":
                host_ip = get_ip()
                logger.info(
                    "Using host IP %s as ray-based data parallel address",
                    host_ip)
                data_parallel_address = host_ip
            else:
                assert self.data_parallel_backend == "mp", (
                    "data_parallel_backend can only be ray or mp, got %s",
                    self.data_parallel_backend)
                data_parallel_address = ParallelConfig.data_parallel_master_ip
        else:
            data_parallel_address = self.data_parallel_address
1314
1315
1316
1317
1318
1319
1320

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
        data_parallel_rpc_port = self.data_parallel_rpc_port if (
            self.data_parallel_rpc_port
            is not None) else ParallelConfig.data_parallel_rpc_port

1321
1322
1323
1324
        if self.async_scheduling:
            # Async scheduling does not work with the uniprocess backend.
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "mp"
1325
1326
                logger.info("Defaulting to mp-based distributed executor "
                            "backend for async scheduling.")
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
            if self.pipeline_parallel_size > 1:
                raise ValueError("Async scheduling is not supported with "
                                 "pipeline-parallel-size > 1.")

            # Currently, async scheduling does not support speculative decoding.
            # TODO(woosuk): Support it.
            if self.speculative_config is not None:
                raise ValueError(
                    "Currently, speculative decoding is not supported with "
                    "async scheduling.")

1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
        # Forward the deprecated CLI args to the EPLB config.
        if self.num_redundant_experts is not None:
            self.eplb_config.num_redundant_experts = self.num_redundant_experts
        if self.eplb_window_size is not None:
            self.eplb_config.window_size = self.eplb_window_size
        if self.eplb_step_interval is not None:
            self.eplb_config.step_interval = self.eplb_step_interval
        if self.eplb_log_balancedness is not None:
            self.eplb_config.log_balancedness = self.eplb_log_balancedness

1348
        parallel_config = ParallelConfig(
1349
1350
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1351
            data_parallel_size=self.data_parallel_size,
1352
1353
            data_parallel_rank=self.data_parallel_rank or 0,
            data_parallel_external_lb=data_parallel_external_lb,
1354
1355
1356
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
1357
            data_parallel_backend=self.data_parallel_backend,
1358
            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
1359
            enable_expert_parallel=self.enable_expert_parallel,
1360
1361
            enable_dbo=self.enable_dbo,
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
1362
            enable_eplb=self.enable_eplb,
1363
            eplb_config=self.eplb_config,
1364
            expert_placement_strategy=self.expert_placement_strategy,
1365
1366
1367
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1368
            ray_runtime_env=ray_runtime_env,
1369
            placement_group=placement_group,
1370
1371
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1372
            worker_extension_cls=self.worker_extension_cls,
1373
            decode_context_parallel_size=self.decode_context_parallel_size,
1374
1375
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
1376
        )
1377

1378
        speculative_config = self.create_speculative_config(
1379
1380
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1381
            enable_chunked_prefill=self.enable_chunked_prefill,
1382
            disable_log_stats=self.disable_log_stats,
1383
1384
        )

1385
1386
1387
1388
1389
        # make sure num_lookahead_slots is set appropriately depending on
        # whether speculative decoding is enabled
        num_lookahead_slots = self.num_lookahead_slots
        if speculative_config is not None:
            num_lookahead_slots = speculative_config.num_lookahead_slots
1390

1391
        scheduler_config = SchedulerConfig(
1392
            runner_type=model_config.runner_type,
1393
1394
1395
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1396
            cuda_graph_sizes=self.cuda_graph_sizes,
1397
            num_lookahead_slots=num_lookahead_slots,
1398
1399
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1400
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1401
            is_multimodal_model=model_config.is_multimodal_model,
1402
            preemption_mode=self.preemption_mode,
1403
1404
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1405
            policy=self.scheduling_policy,
1406
            scheduler_cls=self.scheduler_cls,
1407
1408
1409
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
1410
1411
            disable_hybrid_kv_cache_manager=self.
            disable_hybrid_kv_cache_manager,
1412
            async_scheduling=self.async_scheduling,
1413
        )
1414

1415
1416
1417
1418
1419
        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
                "non multimodal model")

1420
        lora_config = LoRAConfig(
1421
            bias_enabled=self.enable_lora_bias,
1422
1423
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1424
            default_mm_loras=self.default_mm_loras,
1425
            fully_sharded_loras=self.fully_sharded_loras,
1426
1427
1428
1429
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
1430

1431
1432
1433
1434
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1435
        load_config = self.create_load_config()
1436

1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
        # Pass reasoning_parser into StructuredOutputsConfig
        if self.reasoning_parser:
            self.structured_outputs_config.reasoning_parser = \
                self.reasoning_parser

        # Forward the deprecated CLI args to the StructuredOutputsConfig
        so_config = self.structured_outputs_config
        if self.guided_decoding_backend is not None:
            so_config.guided_decoding_backend = \
            self.guided_decoding_backend
        if self.guided_decoding_disable_fallback is not None:
            so_config.guided_decoding_disable_fallback = \
            self.guided_decoding_disable_fallback
        if self.guided_decoding_disable_any_whitespace is not None:
            so_config.guided_decoding_disable_any_whitespace = \
            self.guided_decoding_disable_any_whitespace
        if self.guided_decoding_disable_additional_properties is not None:
            so_config.guided_decoding_disable_additional_properties = \
            self.guided_decoding_disable_additional_properties
1456

1457
        observability_config = ObservabilityConfig(
1458
1459
            show_hidden_metrics_for_version=(
                self.show_hidden_metrics_for_version),
1460
            otlp_traces_endpoint=self.otlp_traces_endpoint,
1461
            collect_detailed_traces=self.collect_detailed_traces,
1462
        )
1463

1464
        config = VllmConfig(
1465
1466
1467
1468
1469
1470
1471
1472
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
1473
            structured_outputs_config=self.structured_outputs_config,
1474
            observability_config=observability_config,
1475
            compilation_config=self.compilation_config,
1476
            kv_transfer_config=self.kv_transfer_config,
1477
            kv_events_config=self.kv_events_config,
1478
            additional_config=self.additional_config,
1479
        )
1480

1481
1482
        return config

1483
1484
1485
1486
1487
1488
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        Each check that finds a V1-incompatible setting calls
        ``_raise_or_fallback``, which raises ``NotImplementedError`` when
        ``VLLM_USE_V1`` is explicitly set to a true value and otherwise logs
        a fallback warning, and then returns False. The checks run in a
        fixed order, so the first incompatible feature determines the
        message the user sees.

        Args:
            model_config: The model configuration being evaluated.

        Returns:
            True if the V1 engine can be used; False to fall back to V0.

        Raises:
            NotImplementedError: If a ``draft_model`` speculative config is
                requested, or (via ``_raise_or_fallback``) if V1 was forced
                on together with an unsupported feature.
        """

        #############################################################
        # Unsupported Feature Flags on V1.

        if self.load_format == "sharded_state":
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

        if self.preemption_mode != SchedulerConfig.preemption_mode:
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

        if self.kv_cache_dtype != "auto":
            supported = current_platform.is_kv_cache_dtype_supported(
                self.kv_cache_dtype, model_config)
            if not supported:
                _raise_or_fallback(feature_name="--kv-cache-dtype",
                                   recommend_to_remove=False)
                return False

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            # `architectures` is presumably a list of names, not a str;
            # stringify it so it matches `_raise_or_fallback`'s annotation.
            # The rendered message text is the same as before.
            _raise_or_fallback(feature_name=str(model_config.architectures),
                               recommend_to_remove=False)
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
                != SchedulerConfig.max_num_partial_prefills
                or self.max_long_partial_prefills
                != SchedulerConfig.max_long_partial_prefills):
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # V1 supports N-gram, Medusa, and Eagle speculative decoding.
        # A draft-model config is a hard error here, not a V0 fallback.
        if (self.speculative_config is not None
                and self.speculative_config.get("method") == "draft_model"):
            raise NotImplementedError(
                "Speculative decoding with draft model is not supported yet. "
                "Please consider using other speculative decoding methods "
                "such as ngram, medusa, eagle, or deepseek_mtp.")

        # Attention backends that may be selected explicitly via
        # VLLM_ATTENTION_BACKEND while still running on V1.
        V1_BACKENDS = [
            "FLASH_ATTN_VLLM_V1",
            "FLASH_ATTN",
            "PALLAS",
            "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1",
            "TRITON_MLA",
            "CUTLASS_MLA",
            "FLASHMLA",
            "FLASHMLA_VLLM_V1",
            "FLASH_ATTN_MLA",
            "FLASHINFER",
            "FLASHINFER_VLLM_V1",
            "FLASHINFER_MLA",
            "ROCM_AITER_MLA",
            "TORCH_SDPA_VLLM_V1",
            "FLEX_ATTENTION",
            "TREE_ATTN",
            "XFORMERS_VLLM_V1",
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        # Platforms must decide if they can support v1 for this model
        if not current_platform.supports_v1(model_config=model_config):
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

        if self.pipeline_parallel_size > 1:
            # The backend may be an object exposing a `supports_pp` flag
            # (presumably a custom executor class). Otherwise only the
            # built-in backends listed below are accepted for PP.
            supports_pp = getattr(self.distributed_executor_backend,
                                  'supports_pp', False)
            if not supports_pp and self.distributed_executor_backend not in (
                    ParallelConfig.distributed_executor_backend, "ray", "mp",
                    "external_launcher"):
                name = "Pipeline Parallelism without Ray distributed " \
                        "executor or multiprocessing executor or external " \
                        "launcher"
                _raise_or_fallback(feature_name=name,
                                   recommend_to_remove=False)
                return False

        # The platform may be supported on V1, but off by default for now.
        # `_warn_or_fallback` returns True (fall back) unless the user forced
        # VLLM_USE_V1 on.
        if not current_platform.default_v1(  # noqa: SIM103
                model_config=model_config) and _warn_or_fallback(
                    current_platform.device_name):
            return False

        if (current_platform.is_cpu()
                and model_config.get_sliding_window() is not None):
            _raise_or_fallback(feature_name="sliding window (CPU backend)",
                               recommend_to_remove=False)
            return False

        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        Mutates ``self`` in place:

        * Resolves an unset (``None``) ``enable_chunked_prefill``. It is
          disabled for multimodal/MLA models. It is enabled for long-context
          (> 32K) CUDA models without sliding window, spec decode or LoRA.
          It is otherwise False.
        * Turns off prefix caching for multimodal models and whenever prompt
          embeds are enabled, since V0 supports neither combination.
        * Defaults ``max_num_seqs`` to 256 when unset.

        Args:
            model_config: The model configuration the defaults depend on.
        """

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_config is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # None of the branches above chose a value: default to off.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            # Note the trailing space after "cause": the adjacent literals
            # are concatenated, so it is needed to separate the words.
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)

        # Disable prefix caching for multimodal models for VLLM_V0.
        if self.enable_prefix_caching and model_config.is_multimodal_model:
            logger.warning(
                "--enable-prefix-caching is not supported for multimodal "
                "models in V0 and has been disabled.")
            self.enable_prefix_caching = False

        # Prompt embeds and prefix caching cannot be combined in V0. This
        # must apply to every model, not just multimodal ones (whose prefix
        # caching was already switched off above), and should only warn when
        # prefix caching is actually on.
        if self.enable_prompt_embeds and self.enable_prefix_caching:
            logger.warning(
                "--enable-prompt-embeds and --enable-prefix-caching "
                "are not supported together in V0. Prefix caching has "
                "been disabled.")
            self.enable_prefix_caching = False

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

1668
1669
    def _set_default_args_v1(self, usage_context: UsageContext,
                             model_config: ModelConfig) -> None:
        """Set Default Arguments for V1 Engine.

        Mutates ``self`` in place:

        * Non-pooling runners: chunked prefill is forced on. Prefix caching
          defaults to on, but is disabled when prompt embeds are enabled.
        * Pooling runners: when unset, both chunked prefill and prefix
          caching follow whether incremental prefill is supported. That
          requires "last" pooling and ``hf_config.is_causal`` (default True).
        * A ``scheduler_cls`` still at the class default is swapped for the
          V1 scheduler.
        * Unset ``max_num_batched_tokens`` / ``max_num_seqs`` are filled from
          per-usage-context tables chosen by device memory/name and platform
          (TPU, CPU).

        Args:
            usage_context: Where the engine is being created from; selects
                the row of the default tables (may be None).
            model_config: The model configuration the defaults depend on.
        """

        # V1 always uses chunked prefills and prefix caching
        # for non-pooling tasks.
        # For pooling tasks the default is False
        if model_config.runner_type != "pooling":
            self.enable_chunked_prefill = True

            # TODO: When prefix caching supports prompt embeds inputs, this
            # check can be removed.
            if (self.enable_prompt_embeds
                    and self.enable_prefix_caching is not False):
                logger.warning(
                    "--enable-prompt-embeds and --enable-prefix-caching "
                    "are not supported together in V1. Prefix caching has "
                    "been disabled.")
                self.enable_prefix_caching = False

            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = True
        else:

            pooling_type = model_config.pooler_config.pooling_type
            is_causal = getattr(model_config.hf_config, "is_causal", True)
            incremental_prefill_supported = (pooling_type is not None
                                             and pooling_type.lower() == "last"
                                             and is_causal)

            action = "Enabling" if \
                incremental_prefill_supported else "Disabling"

            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = incremental_prefill_supported
                logger.info("(%s) chunked prefill by default", action)
            if self.enable_prefix_caching is None:
                self.enable_prefix_caching = incremental_prefill_supported
                logger.info("(%s) prefix caching by default", action)

        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # When no user override, set the default values based on the usage
        # context.
        # Use different default values for different hardware.

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # This is only used to set default_max_num_batched_tokens
            device_memory = 0
            # Keep `device_name` bound so the size check below does not rely
            # on short-circuit evaluation to avoid a NameError.
            device_name = ""

        # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
        # throughput, see PR #17885 for more details.
        # So here we do an extra device name check to prevent such regression.
        from vllm.usage.usage_lib import UsageContext
        if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
            # For GPUs like H100 and MI300x, use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 1024,
                UsageContext.OPENAI_API_SERVER: 1024,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256,
                UsageContext.OPENAI_API_SERVER: 256,
            }

        # tpu specific default values.
        # Only defined (and only read below) when running on TPU.
        if current_platform.is_tpu():
            default_max_num_batched_tokens_tpu = {
                UsageContext.LLM_CLASS: {
                    'V6E': 2048,
                    'V5E': 1024,
                    'V5P': 512,
                },
                UsageContext.OPENAI_API_SERVER: {
                    'V6E': 1024,
                    'V5E': 512,
                    'V5P': 256,
                }
            }

        # cpu specific default values.
        # CPU defaults scale linearly with the PP * TP world size.
        if current_platform.is_cpu():
            world_size = self.pipeline_parallel_size * self.tensor_parallel_size
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 4096 * world_size,
                UsageContext.OPENAI_API_SERVER: 2048 * world_size,
            }
            default_max_num_seqs = {
                UsageContext.LLM_CLASS: 256 * world_size,
                UsageContext.OPENAI_API_SERVER: 128 * world_size,
            }

        use_context_value = usage_context.value if usage_context else None
        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
            if current_platform.is_tpu():
                # Prefer the per-chip TPU table; unknown chips fall back to
                # the generic table chosen above.
                chip_name = current_platform.get_device_name()
                if chip_name in default_max_num_batched_tokens_tpu[
                        usage_context]:
                    self.max_num_batched_tokens = \
                        default_max_num_batched_tokens_tpu[
                            usage_context][chip_name]
                else:
                    self.max_num_batched_tokens = \
                        default_max_num_batched_tokens[usage_context]
            else:
                # Without chunked prefill the batch budget is the full
                # max_model_len (presumably so a whole prompt fits in one
                # step).
                if not self.enable_chunked_prefill:
                    self.max_num_batched_tokens = model_config.max_model_len
                else:
                    self.max_num_batched_tokens = \
                        default_max_num_batched_tokens[usage_context]
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, use_context_value)

        if (self.max_num_seqs is None
                and usage_context in default_max_num_seqs):
            # Never allow more sequences than batched tokens.
            self.max_num_seqs = min(default_max_num_seqs[usage_context],
                                    self.max_num_batched_tokens or sys.maxsize)

            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, use_context_value)
1811

1812

1813
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends ``EngineArgs`` with request-logging control. The
    ``disable_log_requests`` property is a deprecated inverted alias of
    ``enable_log_requests``.
    """

    # Whether incoming requests are logged.
    enable_log_requests: bool = False

    @property
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
        "`enable_log_requests` instead.")
    def disable_log_requests(self) -> bool:
        """Deprecated: the logical negation of ``enable_log_requests``."""
        return not self.enable_log_requests

    @disable_log_requests.setter
    @deprecated(
        "`disable_log_requests` is deprecated and has been replaced with "
        "`enable_log_requests`. This will be removed in v0.12.0. Please use "
        "`enable_log_requests` instead.")
    def disable_log_requests(self, value: bool):
        """Deprecated: stores ``not value`` into ``enable_log_requests``."""
        self.enable_log_requests = not value

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register async-engine CLI arguments on ``parser``.

        Args:
            parser: The argument parser to extend.
            async_args_only: If True, skip ``EngineArgs.add_cli_args`` and
                add only the async-engine-specific flags.

        Returns:
            The parser with the arguments registered (the object returned by
            ``EngineArgs.add_cli_args`` when that is called).
        """
        # Initialize plugin to update the parser, for example, the plugin may
        # add a new kind of quantization method to --quantization argument or
        # a new device to --device argument.
        load_general_plugins()
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--enable-log-requests',
                            action=argparse.BooleanOptionalAction,
                            default=AsyncEngineArgs.enable_log_requests,
                            help='Enable logging requests.')
        # NOTE(review): `deprecated=` is only a native `add_argument` kwarg
        # on Python 3.13+; presumably FlexibleArgumentParser handles it on
        # older interpreters. Confirm before changing the parser class.
        parser.add_argument('--disable-log-requests',
                            action=argparse.BooleanOptionalAction,
                            default=not AsyncEngineArgs.enable_log_requests,
                            help='[DEPRECATED] Disable logging requests.',
                            deprecated=True)
        current_platform.pre_register_and_update(parser)
        return parser
1854
1855


1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Handle a feature that the V1 engine does not support.

    If ``VLLM_USE_V1`` is explicitly set to a truthy value, raise
    ``NotImplementedError``. Otherwise, log a warning that the engine falls
    back to V0. When ``recommend_to_remove`` is true, the warning also
    suggests removing the feature from the config.
    """
    v1_explicitly_enabled = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_explicitly_enabled:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")

    parts = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        parts.append(f"We recommend to remove {feature_name} from your config "
                     "in favor of the V1 Engine.")
    logger.warning("".join(parts))


def _warn_or_fallback(feature_name: str) -> bool:
    """Decide between experimental V1 usage and a V0 fallback for a feature.

    Returns:
        False when ``VLLM_USE_V1`` is explicitly set to a truthy value. In
        that case a warning marks the usage as experimental. True otherwise,
        after logging at info level that the engine falls back to V0.
    """
    if not (envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1):
        logger.info(
            "%s is experimental on VLLM_USE_V1=1. "
            "Falling back to V0 Engine.", feature_name)
        return True

    logger.warning(
        "Detected VLLM_USE_V1=1 with %s. Usage should "
        "be considered experimental. Please report any "
        "issues on Github.", feature_name)
    return False


def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.

    Lowercase (decimal) suffixes ``k``, ``m``, ``g`` and ``t`` multiply by
    powers of 1000 and accept a fractional number. Uppercase (binary)
    suffixes ``K``, ``M``, ``G`` and ``T`` multiply by powers of 1024 and
    require a whole number. A value without a suffix is parsed with
    ``int``. Surrounding whitespace is ignored. Fractional results are
    truncated toward zero.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '1.001k' -> 1,001
    - '2T' -> 2,199,023,255,552

    Raises:
        argparse.ArgumentTypeError: if a fractional number is combined with
            a binary suffix.
        ValueError: if the value is neither a plain integer nor a supported
            suffixed number.
    """
    # Local import keeps this helper free of new module-level dependencies.
    from fractions import Fraction

    value = value.strip()
    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
    if match:
        # Every suffix accepted by the regex must appear in one of these
        # tables. Otherwise it silently falls through to int(value) below
        # and fails there.
        decimal_multiplier = {
            'k': 10**3,
            'm': 10**6,
            'g': 10**9,
            't': 10**12,
        }
        binary_multiplier = {
            'K': 2**10,
            'M': 2**20,
            'G': 2**30,
            'T': 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            # Use exact rational arithmetic. With floats, float(number) * mult
            # can land just below the intended integer (float('1.001') * 1000
            # is slightly less than 1001), which int() then truncates.
            return int(Fraction(number) * decimal_multiplier[suffix])
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers.
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Decimals are not allowed "
                    f"with binary suffixes like {suffix}. Did you mean to use "
                    f"{number}{suffix.lower()} instead?") from e

    # Regular plain number.
    return int(value)