arg_utils.py 77.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
# yapf: disable
4
import argparse
5
import dataclasses
6
import json
7
import re
8
import threading
9
from dataclasses import MISSING, dataclass, fields
10
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
11
                    TypeVar, Union, cast, get_args, get_origin)
12

13
import torch
14
from typing_extensions import TypeIs, deprecated
15

16
import vllm.envs as envs
17
from vllm import version
18
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
19
                         ConfigFormat, ConfigType, DecodingConfig, Device,
20
21
                         DeviceConfig, DistributedExecutorBackend,
                         GuidedDecodingBackendV1, HfOverrides,
22
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
23
24
                         ModelConfig, ModelImpl, MultiModalConfig,
                         ObservabilityConfig, ParallelConfig, PoolerConfig,
25
                         PrefixCachingHashAlgo, PromptAdapterConfig,
26
27
28
                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
                         TaskOption, TokenizerPoolConfig, VllmConfig,
                         get_attr_docs, get_field)
29
from vllm.executor.executor_base import ExecutorBase
30
from vllm.logger import init_logger
31
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
32
from vllm.plugins import load_general_plugins
33
from vllm.reasoning import ReasoningParserManager
34
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
35
from vllm.transformers_utils.utils import check_gguf_file
36
from vllm.usage.usage_lib import UsageContext
37
from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
38

39
# yapf: enable
40

41
42
logger = init_logger(__name__)

# Module names permitted for detailed tracing (presumably checked against the
# `collect_detailed_traces` option — confirm at the validation site).
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# Type aliases for the argparse helpers. `object` is included in the unions so
# that special typing forms (e.g. `Literal[...]`, parameterised generics),
# which are not classes, are still acceptable as type hints.
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]
49

50

51
52
def optional_type(
        return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Wrap an argparse ``type`` converter so it also accepts "no value".

    Args:
        return_type: Converter applied to the raw CLI string.

    Returns:
        A converter that maps ``""`` and ``"None"`` to ``None`` and otherwise
        delegates to ``return_type``. When ``return_type`` is ``json.loads``,
        a string that does not look like a JSON object literal is instead
        parsed with the deprecated ``nullable_kvs`` ``KEY=VALUE`` syntax.

    The returned converter raises ``argparse.ArgumentTypeError`` when
    ``return_type`` raises ``ValueError`` (which includes
    ``json.JSONDecodeError``).
    """

    def _optional_type(val: str) -> Optional[T]:
        if val == "" or val == "None":
            return None
        try:
            # DOTALL lets `.` span newlines, so pretty-printed (multi-line)
            # JSON objects and whitespace-padded objects are recognised as
            # JSON rather than being sent to the legacy KEY=VALUE parser.
            if (return_type is json.loads
                    and not re.fullmatch(r"\s*\{.*\}\s*", val, re.DOTALL)):
                return cast(T, nullable_kvs(val))
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}.") from e

    return _optional_type
66
67


68
69
70
71
72
@deprecated(
    "Passing a JSON argument as a string containing comma separated key=value "
    "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
    "string instead.")
def nullable_kvs(val: str) -> dict[str, int]:
    """Parse comma-separated ``KEY=VALUE`` pairs into a ``{str: int}`` dict.

    Each key and value is lower-cased and stripped of surrounding whitespace
    before use. A key may repeat only if every occurrence has the same value.

    Args:
        val: Text such as ``"image=2,video=1"``.

    Returns:
        Mapping from each key to its integer value.

    Raises:
        argparse.ArgumentTypeError: If an item does not split into exactly
            one key and one value on ``=``, if a value is not an integer, or
            if a key is given two different values.
    """
    parsed: dict[str, int] = {}
    for entry in val.split(","):
        pieces = [piece.lower().strip() for piece in entry.split("=")]
        if len(pieces) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        name, raw = pieces

        try:
            number = int(raw)
        except ValueError as exc:
            raise argparse.ArgumentTypeError(
                f"Failed to parse value of item {name}={raw}") from exc

        # setdefault keeps the first value stored for a key and returns it on
        # later occurrences, so any mismatch means conflicting values.
        if parsed.setdefault(name, number) != number:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {name}")

    return parsed


104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Return True if ``type_hint`` is ``type`` or a parameterisation of it.

    A hint matches when it is the very same object as ``type``, or when its
    ``typing.get_origin`` is ``type`` (e.g. ``list[int]`` for ``list``).
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type


def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Return True if any hint in ``type_hints`` matches ``type`` per
    ``is_type``."""
    for candidate in type_hints:
        if is_type(candidate, type):
            return True
    return False


def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return the first hint in ``type_hints`` matching ``type`` per
    ``is_type``, or ``None`` when no hint matches."""
    for candidate in type_hints:
        if is_type(candidate, type):
            return candidate
    return None


def is_not_builtin(type_hint: TypeHint) -> bool:
    """Return True when ``type_hint`` is defined outside the ``builtins``
    module, judged by its ``__module__`` attribute."""
    defining_module = type_hint.__module__
    return defining_module != "builtins"


def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Derive ``argparse.add_argument`` keyword arguments from a config class.

    For each dataclass field of ``cls`` this builds a dict holding ``default``
    and ``help`` (the attribute docstring from ``get_attr_docs``). It adds
    ``type``/``action``/``choices``/``nargs`` inferred from the field's type
    annotation.

    Args:
        cls: A dataclass config class.

    Returns:
        Mapping from field name to the kwargs for that field's CLI argument.

    Raises:
        KeyError: If a field has no entry in ``get_attr_docs(cls)``.
        ValueError: If a field's type annotation is not supported.
    """
    cls_docs = get_attr_docs(cls)
    kwargs: dict[str, dict[str, Any]] = {}
    for field in fields(cls):
        name = field.name

        # Get the default value of the field
        default = field.default
        if field.default_factory is not MISSING:
            default = field.default_factory()

        # argparse %-formats help strings, so a literal % must be escaped
        help_text = cls_docs[name].replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        field_kwargs: dict[str, Any] = {"default": default, "help": help_text}
        kwargs[name] = field_kwargs

        # Get the set of possible types for the field
        type_hints: set[TypeHint] = set()
        if get_origin(field.type) is Union:
            type_hints.update(get_args(field.type))
        else:
            type_hints.add(field.type)

        # Set when the chosen converter already maps ""/"None" to None, so it
        # does not need wrapping in optional_type a second time below.
        already_optional = False

        # Set other kwargs based on the type hints
        if contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            field_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            # Creates choices from Literal arguments
            type_hint = get_type(type_hints, Literal)
            choices = sorted(get_args(type_hint))
            field_kwargs["choices"] = choices
            choice_type = type(choices[0])
            assert all(type(c) is choice_type for c in choices), (
                "All choices must be of the same type. "
                f"Got {choices} with types {[type(c) for c in choices]}")
            field_kwargs["type"] = choice_type
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            types = get_args(type_hint)
            tuple_type = types[0]
            assert all(t is tuple_type for t in types if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {types}.")
            field_kwargs["type"] = tuple_type
            field_kwargs["nargs"] = "+" if Ellipsis in types else len(types)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            types = get_args(type_hint)
            assert len(types) == 1, (
                "List type must have exactly one type. Got "
                f"{type_hint} with types {types}")
            field_kwargs["type"] = types[0]
            field_kwargs["nargs"] = "+"
        elif contains_type(type_hints, int):
            field_kwargs["type"] = int
        elif contains_type(type_hints, float):
            field_kwargs["type"] = float
        elif contains_type(type_hints, dict):
            # Dict arguments will always be optional
            field_kwargs["type"] = optional_type(json.loads)
            already_optional = True
        elif (contains_type(type_hints, str)
              or any(is_not_builtin(th) for th in type_hints)):
            field_kwargs["type"] = str
        else:
            raise ValueError(
                f"Unsupported type {type_hints} for argument {name}.")

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            if not already_optional:
                field_kwargs["type"] = optional_type(field_kwargs["type"])
            if field_kwargs.get("choices"):
                # argparse validates the *converted* value against choices,
                # and optional_type converts "None" to None. The accepted
                # choice must therefore be None itself (shown as "None" in
                # help); the string "None" could never match.
                field_kwargs["choices"].append(None)
    return kwargs
201
202


203
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
204
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
205
    """Arguments for vLLM engine."""
206
    model: str = 'facebook/opt-125m'
207
    served_model_name: Optional[Union[str, List[str]]] = None
208
    tokenizer: Optional[str] = None
209
    hf_config_path: Optional[str] = None
210
    task: TaskOption = "auto"
211
    skip_tokenizer_init: bool = False
212
    tokenizer_mode: str = 'auto'
213
    trust_remote_code: bool = False
214
    allowed_local_media_path: str = ""
215
216
    download_dir: Optional[str] = LoadConfig.download_dir
    load_format: str = LoadConfig.load_format
217
    config_format: ConfigFormat = ConfigFormat.AUTO
218
    dtype: str = 'auto'
219
    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
220
    seed: Optional[int] = None
221
    max_model_len: Optional[int] = None
222
223
224
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
225
    distributed_executor_backend: Optional[Union[
226
227
        DistributedExecutorBackend,
        Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
228
    # number of P/D disaggregation (or other disaggregation) workers
229
230
231
232
233
234
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
    data_parallel_size: int = ParallelConfig.data_parallel_size
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
    max_parallel_loading_workers: Optional[
        int] = ParallelConfig.max_parallel_loading_workers
235
236
237
238
    block_size: Optional[BlockSize] = CacheConfig.block_size
    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
    prefix_caching_hash_algo: PrefixCachingHashAlgo = \
        CacheConfig.prefix_caching_hash_algo
239
    disable_sliding_window: bool = False
240
    disable_cascade_attn: bool = False
241
    use_v2_block_manager: bool = True
242
243
244
    swap_space: float = CacheConfig.swap_space
    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
245
246
247
248
249
250
251
    max_num_batched_tokens: Optional[
        int] = SchedulerConfig.max_num_batched_tokens
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
    long_prefill_token_threshold: int = \
        SchedulerConfig.long_prefill_token_threshold
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
252
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
253
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
254
    revision: Optional[str] = None
255
    code_revision: Optional[str] = None
256
    rope_scaling: Optional[Dict[str, Any]] = None
257
    rope_theta: Optional[float] = None
258
    hf_token: Optional[Union[bool, str]] = None
259
    hf_overrides: Optional[HfOverrides] = None
260
    tokenizer_revision: Optional[str] = None
261
    quantization: Optional[str] = None
262
    enforce_eager: Optional[bool] = None
263
    max_seq_len_to_capture: int = 8192
264
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
265
266
267
    # The following three fields are deprecated and will be removed in a future
    # release. Setting them will have no effect. Please remove them from your
    # configurations.
268
    tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
269
270
    tokenizer_pool_type: str = TokenizerPoolConfig.pool_type
    tokenizer_pool_extra_config: dict = \
271
        get_field(TokenizerPoolConfig, "extra_config")
272
    limit_mm_per_prompt: dict[str, int] = \
273
        get_field(MultiModalConfig, "limit_per_prompt")
274
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
275
    disable_mm_preprocessor_cache: bool = False
276
    # LoRA fields
277
    enable_lora: bool = False
278
279
280
281
282
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
zhuwenwen's avatar
zhuwenwen committed
283
    merge_lora: bool = LoRAConfig.merge_lora
zhuwenwen's avatar
zhuwenwen committed
284
    lora_target_modules: Optional[List[str]] = LoRAConfig.lora_target_modules
285
286
287
288
289
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
        LoRAConfig.long_lora_scaling_factors
    # PromptAdapter fields
290
    enable_prompt_adapter: bool = False
291
292
293
    max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
    max_prompt_adapter_token: int = \
        PromptAdapterConfig.max_prompt_adapter_token
294
    device: Device = DeviceConfig.device
295
296
    num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
    multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
297
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
298
299
    num_gpu_blocks_override: Optional[
        int] = CacheConfig.num_gpu_blocks_override
300
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
301
302
    model_loader_extra_config: dict = \
        get_field(LoadConfig, "model_loader_extra_config")
303
304
    ignore_patterns: Optional[Union[str,
                                    List[str]]] = LoadConfig.ignore_patterns
305
    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
306

307
308
309
310
    scheduler_delay_factor: float = SchedulerConfig.delay_factor
    enable_chunked_prefill: Optional[
        bool] = SchedulerConfig.enable_chunked_prefill
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
311

312
    guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
313
    logits_processor_pattern: Optional[str] = None
314

315
    speculative_config: Optional[Dict[str, Any]] = None
zhuwenwen's avatar
zhuwenwen committed
316
    num_speculative_heads: Optional[int] = None
317

318
    qlora_adapter_name_or_path: Optional[str] = None
319
    show_hidden_metrics_for_version: Optional[str] = None
320
    otlp_traces_endpoint: Optional[str] = None
321
    collect_detailed_traces: Optional[str] = None
322
    disable_async_output_proc: bool = False
323
324
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
325

326
    override_neuron_config: Optional[Dict[str, Any]] = None
327
    override_pooler_config: Optional[PoolerConfig] = None
328
    compilation_config: Optional[CompilationConfig] = None
329
330
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
331

332
    kv_transfer_config: Optional[KVTransferConfig] = None
333

334
    generation_config: Optional[str] = "auto"
335
    override_generation_config: Optional[Dict[str, Any]] = None
336
    enable_sleep_mode: bool = False
337
    model_impl: str = "auto"
338

339
    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
340

341
    additional_config: Optional[Dict[str, Any]] = None
342
    enable_reasoning: Optional[bool] = None
343
    reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend
344
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
王敏's avatar
王敏 committed
345

王敏's avatar
王敏 committed
346

347
    def __post_init__(self):
        """Normalise a few fields after dataclass initialisation.

        * An unset or empty ``tokenizer`` falls back to ``model``.
        * An ``int`` or ``dict`` ``compilation_config`` is converted to a
          ``CompilationConfig`` by passing its ``str()`` form to
          ``CompilationConfig.from_cli``.
        * General plugins are loaded via ``load_general_plugins()``.
        """
        if not self.tokenizer:
            # No explicit tokenizer given: reuse the model name/path.
            self.tokenizer = self.model

        # Allow `EngineArgs(compilation_config={...})` (or an int) without
        # requiring callers to construct a CompilationConfig themselves.
        # NOTE(review): str() of a dict is a Python repr, not JSON; this
        # relies on from_cli accepting that form — confirm in vllm.config.
        given_compilation = self.compilation_config
        if isinstance(given_compilation, (int, dict)):
            as_text = str(given_compilation)
            self.compilation_config = CompilationConfig.from_cli(as_text)

        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
361
362

    @staticmethod
363
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
364
        """Shared CLI arguments for vLLM engine."""
365

366
        # Model arguments
367
368
369
        parser.add_argument(
            '--model',
            type=str,
370
            default=EngineArgs.model,
371
            help='Name or path of the huggingface model to use.')
372
373
374
375
376
377
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
378
            'multiple tasks. When the model only supports one task, ``"auto"`` '
379
380
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
381
382
        parser.add_argument(
            '--tokenizer',
383
            type=optional_type(str),
384
            default=EngineArgs.tokenizer,
385
386
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
387
388
        parser.add_argument(
            "--hf-config-path",
389
            type=optional_type(str),
390
391
392
            default=EngineArgs.hf_config_path,
            help='Name or path of the huggingface config to use. '
            'If unspecified, model name or path will be used.')
393
394
395
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
396
397
398
            help='Skip initialization of tokenizer and detokenizer. '
            'Expects valid prompt_token_ids and None for prompt from '
            'the input. The generated output will contain token ids.')
Jasmond L's avatar
Jasmond L committed
399
400
        parser.add_argument(
            '--revision',
401
            type=optional_type(str),
Jasmond L's avatar
Jasmond L committed
402
            default=None,
403
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
404
405
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
406
407
        parser.add_argument(
            '--code-revision',
408
            type=optional_type(str),
409
            default=None,
410
            help='The specific revision to use for the model code on '
411
412
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
413
414
        parser.add_argument(
            '--tokenizer-revision',
415
            type=optional_type(str),
416
            default=None,
417
418
419
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
420
421
422
423
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
424
            choices=['auto', 'cpm', 'slow', 'mistral', 'custom'],
425
426
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
427
            'always use the slow tokenizer. \n* '
428
429
430
            '"mistral" will always use the `mistral_common` tokenizer. \n* '
            '"custom" will use --tokenizer to select the '
            'preregistered tokenizer.')
431
432
        parser.add_argument('--trust-remote-code',
                            action='store_true',
433
                            help='Trust remote code from huggingface.')
434
435
436
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
437
438
439
440
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
        load_group.add_argument('--load-format',
                                choices=[f.value for f in LoadFormat],
                                **load_kwargs["load_format"])
        load_group.add_argument('--download-dir',
                                **load_kwargs["download_dir"])
        load_group.add_argument('--model-loader-extra-config',
                                **load_kwargs["model_loader_extra_config"])
        load_group.add_argument('--use-tqdm-on-load',
                                **load_kwargs["use_tqdm_on_load"])

457
458
459
460
461
462
463
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
464
465
466
467
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
468
469
470
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
471
472
473
474
475
476
477
478
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
479
        parser.add_argument('--max-model-len',
480
                            type=human_readable_int,
481
                            default=EngineArgs.max_model_len,
482
                            help='Model context length. If unspecified, will '
483
484
485
486
487
                            'be automatically derived from the model config. '
                            'Supports k/m/g/K/M/G in human-readable format.\n'
                            'Examples:\n'
                            '- 1k → 1000\n'
                            '- 1K → 1024\n')
488
489
490
491
492
493
494
495

        # Guided decoding arguments
        guided_decoding_kwargs = get_kwargs(DecodingConfig)
        guided_decoding_group = parser.add_argument_group(
            title="DecodingConfig",
            description=DecodingConfig.__doc__,
        )
        guided_decoding_group.add_argument(
496
            '--guided-decoding-backend',
497
498
499
500
501
502
503
            **guided_decoding_kwargs["guided_decoding_backend"])
        guided_decoding_group.add_argument(
            "--reasoning-parser",
            # This choices is a special case because it's not static
            choices=list(ReasoningParserManager.reasoning_parsers),
            **guided_decoding_kwargs["reasoning_backend"])

504
505
        parser.add_argument(
            '--logits-processor-pattern',
506
            type=optional_type(str),
507
508
509
510
511
            default=None,
            help='Optional regex pattern specifying valid logits processor '
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
512
513
514
515
516
517
518
519
520
521
522
523
        parser.add_argument(
            '--model-impl',
            type=str,
            default=EngineArgs.model_impl,
            choices=[f.value for f in ModelImpl],
            help='Which implementation of the model to use.\n\n'
            '* "auto" will try to use the vLLM implementation if it exists '
            'and fall back to the Transformers implementation if no vLLM '
            'implementation is available.\n'
            '* "vllm" will use the vLLM model implementation.\n'
            '* "transformers" will use the Transformers model '
            'implementation.\n')
524
        # Parallel arguments
525
526
527
528
529
530
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
531
            '--distributed-executor-backend',
532
533
534
535
536
537
538
539
540
            **parallel_kwargs["distributed_executor_backend"])
        parallel_group.add_argument(
            '--pipeline-parallel-size', '-pp',
            **parallel_kwargs["pipeline_parallel_size"])
        parallel_group.add_argument('--tensor-parallel-size', '-tp',
                                    **parallel_kwargs["tensor_parallel_size"])
        parallel_group.add_argument('--data-parallel-size', '-dp',
                                    **parallel_kwargs["data_parallel_size"])
        parallel_group.add_argument(
541
            '--enable-expert-parallel',
542
543
            **parallel_kwargs["enable_expert_parallel"])
        parallel_group.add_argument(
544
            '--max-parallel-loading-workers',
545
546
            **parallel_kwargs["max_parallel_loading_workers"])
        parallel_group.add_argument(
547
            '--ray-workers-use-nsight',
548
549
550
551
            **parallel_kwargs["ray_workers_use_nsight"])
        parallel_group.add_argument(
            '--disable-custom-all-reduce',
            **parallel_kwargs["disable_custom_all_reduce"])
552

553
554
555
556
557
        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)
        cache_group = parser.add_argument_group(
            title="CacheConfig",
            description=CacheConfig.__doc__,
558
        )
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
        cache_group.add_argument('--block-size', **cache_kwargs["block_size"])
        cache_group.add_argument('--gpu-memory-utilization',
                                 **cache_kwargs["gpu_memory_utilization"])
        cache_group.add_argument('--swap-space', **cache_kwargs["swap_space"])
        cache_group.add_argument('--kv-cache-dtype',
                                 **cache_kwargs["cache_dtype"])
        cache_group.add_argument('--num-gpu-blocks-override',
                                 **cache_kwargs["num_gpu_blocks_override"])
        cache_group.add_argument("--enable-prefix-caching",
                                 **cache_kwargs["enable_prefix_caching"])
        cache_group.add_argument("--prefix-caching-hash-algo",
                                 **cache_kwargs["prefix_caching_hash_algo"])
        cache_group.add_argument('--cpu-offload-gb',
                                 **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument('--calculate-kv-scales',
                                 **cache_kwargs["calculate_kv_scales"])

576
577
578
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
579
                            'capping to sliding window size.')
580
581
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
582
                            default=True,
583
584
585
586
587
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
588

589
590
591
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
592
                            help='Random seed for operations.')
593
594
595
596
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
597
598
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
599
600
        parser.add_argument('--disable-log-stats',
                            action='store_true',
601
                            help='Disable logging statistics.')
602
603
604
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
605
                            type=optional_type(str),
606
                            choices=[*QUANTIZATION_METHODS, None],
607
                            default=EngineArgs.quantization,
608
609
610
611
612
613
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
614
615
616
617
618
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
619
            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
620
621
622
623
624
625
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
626
627
628
629
630
631
632
633
634
635
        parser.add_argument(
            '--hf-token',
            type=str,
            nargs='?',
            const=True,
            default=None,
            help='The token to use as HTTP bearer authorization'
            ' for remote files. If `True`, will use the token '
            'generated when running `huggingface-cli login` '
            '(stored in `~/.huggingface`).')
636
637
638
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
639
                            help='Extra arguments for the HuggingFace config. '
640
641
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
642
        parser.add_argument('--enforce-eager',
zhuwenwen's avatar
zhuwenwen committed
643
                            action='store_true',
644
645
646
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
647
        parser.add_argument('--max-seq-len-to-capture',
648
649
650
651
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
652
653
654
655
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
656
657
658
659
660
661
662
663
664
665
666
667
668

        # Tokenizer arguments
        tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
        tokenizer_group = parser.add_argument_group(
            title="TokenizerPoolConfig",
            description=TokenizerPoolConfig.__doc__,
        )
        tokenizer_group.add_argument('--tokenizer-pool-size',
                                     **tokenizer_kwargs["pool_size"])
        tokenizer_group.add_argument('--tokenizer-pool-type',
                                     **tokenizer_kwargs["pool_type"])
        tokenizer_group.add_argument('--tokenizer-pool-extra-config',
                                     **tokenizer_kwargs["extra_config"])
669
670

        # Multimodal related configs
671
672
673
674
675
676
677
678
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
        multimodal_group.add_argument('--limit-mm-per-prompt',
                                      **multimodal_kwargs["limit_per_prompt"])

679
680
681
682
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
683
684
685
686
            help=('Overrides for the multi-modal processor obtained from '
                  '``AutoProcessor.from_pretrained``. The available overrides '
                  'depend on the model that is being run.'
                  'For example, for Phi-3-Vision: ``{"num_crops": 4}``.'))
687
        parser.add_argument(
688
            '--disable-mm-preprocessor-cache',
689
            action='store_true',
690
691
            help='If True, disable caching of the processed multi-modal '
            'inputs.')
692

693
        # LoRA related configs
694
695
696
697
698
699
700
701
702
703
704
705
706
707
        lora_kwargs = get_kwargs(LoRAConfig)
        lora_group = parser.add_argument_group(
            title="LoRAConfig",
            description=LoRAConfig.__doc__,
        )
        lora_group.add_argument(
            '--enable-lora',
            action=argparse.BooleanOptionalAction,
            help='If True, enable handling of LoRA adapters.')
        lora_group.add_argument('--enable-lora-bias',
                                **lora_kwargs["bias_enabled"])
        lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
        lora_group.add_argument('--max-lora-rank',
                                **lora_kwargs["max_lora_rank"])
zhuwenwen's avatar
zhuwenwen committed
708
        lora_group.add_argument('--merge-lora',
zhuwenwen's avatar
zhuwenwen committed
709
710
711
                                **lora_kwargs["merge-lora"])
                            # action=argparse.BooleanOptionalAction,
                            # help='If set to True, the weights of the base layer will be merged with the weights of Lora.')
zhuwenwen's avatar
zhuwenwen committed
712
713
        lora_group.add_argument('--lora-target-modules',
                            **lora_kwargs["lora_target_modules"])
714
715
716
        lora_group.add_argument('--lora-extra-vocab-size',
                                **lora_kwargs["lora_extra_vocab_size"])
        lora_group.add_argument(
717
            '--lora-dtype',
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
            **lora_kwargs["lora_dtype"],
        )
        lora_group.add_argument('--long-lora-scaling-factors',
                                **lora_kwargs["long_lora_scaling_factors"])
        lora_group.add_argument('--max-cpu-loras',
                                **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument('--fully-sharded-loras',
                                **lora_kwargs["fully_sharded_loras"])

        # PromptAdapter related configs
        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
        prompt_adapter_group = parser.add_argument_group(
            title="PromptAdapterConfig",
            description=PromptAdapterConfig.__doc__,
        )
        prompt_adapter_group.add_argument(
            '--enable-prompt-adapter',
            action=argparse.BooleanOptionalAction,
            help='If True, enable handling of PromptAdapters.')
        prompt_adapter_group.add_argument(
            '--max-prompt-adapters',
            **prompt_adapter_kwargs["max_prompt_adapters"])
        prompt_adapter_group.add_argument(
            '--max-prompt-adapter-token',
            **prompt_adapter_kwargs["max_prompt_adapter_token"])
743
744
745
746
747
748
749
750
751

        # Device arguments
        device_kwargs = get_kwargs(DeviceConfig)
        device_group = parser.add_argument_group(
            title="DeviceConfig",
            description=DeviceConfig.__doc__,
        )
        device_group.add_argument("--device", **device_kwargs["device"])

752
753
754
755
756
757
758
759
760
761
762
        # Speculative arguments
        speculative_group = parser.add_argument_group(
            title="SpeculativeConfig",
            description=SpeculativeConfig.__doc__,
        )
        speculative_group.add_argument(
            '--speculative-config',
            type=json.loads,
            default=None,
            help='The configurations for speculative decoding.'
            ' Should be a JSON string.')
763

zhuwenwen's avatar
zhuwenwen committed
764
765
766
767
768
769
        parser.add_argument(
            '--num-speculative-heads',
            type=int,
            default=EngineArgs.num_speculative_heads,
            help='The number of speculative heads to sample from '
                 'the draft model in speculative decoding.')
770
771
772
773
774
775
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
776
            "Default to `original/**/*` to avoid repeated loading of llama's "
777
            "checkpoints.")
778

779
780
781
782
783
784
785
786
787
788
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
789
            "same as the ``--model`` argument. Noted that this name(s) "
790
            "will also be used in `model_name` tag content of "
791
            "prometheus metrics, if multiple names provided, metrics "
792
            "tag will take the first one.")
793
794
795
796
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
797

798
799
800
801
802
803
804
805
806
807
808
809
        parser.add_argument('--show-hidden-metrics-for-version',
                            type=str,
                            default=None,
                            help='Enable deprecated Prometheus metrics that '
                            'have been hidden since the specified version. '
                            'For example, if a previously deprecated metric '
                            'has been hidden since the v0.7.0 release, you '
                            'use --show-hidden-metrics-for-version=0.7 as a '
                            'temporary escape hatch while you migrate to new '
                            'metrics. The metric is likely to be removed '
                            'completely in an upcoming release.')

810
811
812
813
814
        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
815
816
817
818
819
820
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
821
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
822
823
824
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
825

826
827
828
829
830
831
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
832

833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
            '--max-num-batched-tokens',
            **scheduler_kwargs["max_num_batched_tokens"])
        scheduler_group.add_argument('--max-num-seqs',
                                     **scheduler_kwargs["max_num_seqs"])
        scheduler_group.add_argument(
            "--max-num-partial-prefills",
            **scheduler_kwargs["max_num_partial_prefills"])
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
            **scheduler_kwargs["max_long_partial_prefills"])
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
            **scheduler_kwargs["long_prefill_token_threshold"])
        scheduler_group.add_argument('--num-lookahead-slots',
                                     **scheduler_kwargs["num_lookahead_slots"])
        scheduler_group.add_argument('--scheduler-delay-factor',
                                     **scheduler_kwargs["delay_factor"])
857
858
859
860
        scheduler_group.add_argument('--preemption-mode',
                                     **scheduler_kwargs["preemption_mode"])
        scheduler_group.add_argument('--num-scheduler-steps',
                                     **scheduler_kwargs["num_scheduler_steps"])
861
862
863
864
865
        scheduler_group.add_argument(
            '--multi-step-stream-outputs',
            **scheduler_kwargs["multi_step_stream_outputs"])
        scheduler_group.add_argument('--scheduling-policy',
                                     **scheduler_kwargs["policy"])
866
867
868
        scheduler_group.add_argument(
            '--enable-chunked-prefill',
            **scheduler_kwargs["enable_chunked_prefill"])
869
870
871
872
873
        scheduler_group.add_argument(
            "--disable-chunked-mm-input",
            **scheduler_kwargs["disable_chunked_mm_input"])
        parser.add_argument('--scheduler-cls',
                            **scheduler_kwargs["scheduler_cls"])
874

875
876
        parser.add_argument(
            '--override-neuron-config',
877
            type=json.loads,
878
            default=None,
879
            help="Override or set neuron device configuration. "
880
            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
881
        parser.add_argument(
882
883
            '--override-pooler-config',
            type=PoolerConfig.from_json,
884
            default=None,
885
            help="Override or set the pooling method for pooling models. "
886
            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
887

888
889
890
891
892
893
894
895
896
897
898
899
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
900
901
                            'use a JSON string, e.g. ``{"level": 3, '
                            '"cudagraph_capture_sizes": [1, 2, 4, 8]}``\n'
902
                            'Following the convention of traditional '
903
904
                            'compilers, using ``-O`` without space is also '
                            'supported. ``-O3`` is equivalent to ``-O 3``.')
905

906
907
908
909
910
911
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

912
913
914
915
916
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')
917
918
919
920
921
922
923
        parser.add_argument(
            '--worker-extension-cls',
            type=str,
            default="",
            help='The worker extension class on top of the worker cls, '
            'it is useful if you just want to add new functions to the worker '
            'class without changing the existing functions.')
924
925
        parser.add_argument(
            "--generation-config",
926
            type=optional_type(str),
927
            default="auto",
928
            help="The folder path to the generation config. "
929
930
931
932
933
            "Defaults to 'auto', the generation config will be loaded from "
            "model path. If set to 'vllm', no generation config is loaded, "
            "vLLM defaults will be used. If set to a folder path, the "
            "generation config will be loaded from the specified folder path. "
            "If `max_new_tokens` is specified in generation config, then "
934
935
936
937
938
939
940
941
942
943
944
945
            "it sets a server-wide limit on the number of output tokens "
            "for all requests.")

        parser.add_argument(
            "--override-generation-config",
            type=json.loads,
            default=None,
            help="Overrides or sets generation config in JSON format. "
            "e.g. ``{\"temperature\": 0.5}``. If used with "
            "--generation-config=auto, the override parameters will be merged "
            "with the default config from the model. If generation-config is "
            "None, only the override parameters are used.")
946

947
948
949
950
951
952
        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
                            default=False,
                            help="Enable sleep mode for the engine. "
                            "(only cuda platform is supported)")

953
954
955
956
957
958
959
960
        parser.add_argument(
            "--additional-config",
            type=json.loads,
            default=None,
            help="Additional config for specified platform in JSON format. "
            "Different platforms may support different configs. Make sure the "
            "configs are valid for the platform you are using. The input format"
            " is like '{\"config_key\":\"config_value\"}'")
961
962
963
964
965
966
967
968
969

        parser.add_argument(
            "--enable-reasoning",
            action="store_true",
            default=False,
            help="Whether to enable reasoning_content for the model. "
            "If enabled, the model will be able to generate reasoning content."
        )

970
971
972
973
974
975
976
977
978
979
        parser.add_argument(
            "--disable-cascade-attn",
            action="store_true",
            default=False,
            help="Disable cascade attention for V1. While cascade attention "
            "does not change the mathematical correctness, disabling it "
            "could be useful for preventing potential numerical issues. "
            "Note that even if this is set to False, cascade attention will be "
            "only used when the heuristic tells that it's beneficial.")

980
        return parser
981
982

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from parsed command-line args.

        Each dataclass field of ``cls`` is read from ``args`` under the same
        attribute name. Attributes on ``args`` that are not fields of ``cls``
        are ignored; a field missing from ``args`` raises ``AttributeError``.
        """
        init_kwargs = {}
        for field_name in [f.name for f in dataclasses.fields(cls)]:
            init_kwargs[field_name] = getattr(args, field_name)
        return cls(**init_kwargs)
989

990
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` described by these engine arguments.

        Before building, two rewrites may be applied to ``self``:

        * If ``self.model`` is a GGUF file, ``quantization`` and
          ``load_format`` are both set to ``"gguf"`` (GGUF needs a dedicated
          loader and does not use the HF repo).
        * For the synchronous ``EngineArgs`` only (not ``AsyncEngineArgs``),
          when ``envs.VLLM_CI_USE_S3`` is set, the model is listed in
          ``MODELS_ON_S3``, and ``load_format`` is ``LoadFormat.AUTO``, then
          ``model`` is prefixed with ``MODEL_WEIGHTS_S3_BUCKET`` and
          ``load_format`` becomes ``LoadFormat.RUNAI_STREAMER``. This lets
          CI load model weights from S3.
        """
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: the short-circuit order below mirrors the CI gate exactly.
        redirect_to_s3 = (not isinstance(self, AsyncEngineArgs)
                          and envs.VLLM_CI_USE_S3
                          and self.model in MODELS_ON_S3
                          and self.load_format == LoadFormat.AUTO)
        if redirect_to_s3:
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
            self.load_format = LoadFormat.RUNAI_STREAMER

        return ModelConfig(
            # Model / tokenizer identity.
            model=self.model,
            hf_config_path=self.hf_config_path,
            task=self.task,
            # __post_init__ guarantees the tokenizer has been filled in.
            tokenizer=cast(str, self.tokenizer),
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            # Numerics, seeding and revisions.
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            # Execution behaviour.
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            # The CLI flag is a "disable" switch; ModelConfig wants "use".
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            # Overrides and generation settings.
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
        )
1042

1043
1044
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` used to load the model weights.

        A QLoRA adapter can only be combined with ``bitsandbytes``
        quantization. Both ``None`` and the empty string mean "no QLoRA
        adapter". This matches ``create_engine_config``, which only forwards
        a non-empty adapter path to the model loader.

        When quantization is ``"bitsandbytes"``, ``self.load_format`` is
        overwritten with ``"bitsandbytes"`` before the config is built.

        Raises:
            ValueError: if a non-empty QLoRA adapter path is set while
                ``quantization`` is not ``"bitsandbytes"``.
        """
        # Truthiness check: an empty adapter path is treated as unset rather
        # than as an adapter that conflicts with the quantization method.
        if (self.qlora_adapter_name_or_path
                and self.quantization != "bitsandbytes"):
            raise ValueError(
                "QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
        )

1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        `speculative_config` may come from the CLI (parsed from a JSON
        string) or directly from the engine as a dictionary. It is combined
        with the target model/parallel configs and the chunked-prefill and
        log-stats flags, which cannot be supplied through the CLI.

        Returns:
            The resulting ``SpeculativeConfig``, or ``None`` when
            speculative decoding is not configured.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # Merge into a fresh dict rather than calling ``.update()`` on
        # ``self.speculative_config``: that dict may be owned by the caller,
        # and mutating it would leak engine-internal config objects into
        # user data (and into any later copy/serialization of these args).
        spec_config_dict = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        }
        return SpeculativeConfig.from_dict(spec_config_dict)

1093
1094
1095
1096
1097
1098
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
    ) -> VllmConfig:
        """
        Create the VllmConfig.
1099

1100
1101
1102
        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
1103

1104
1105
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.
1106

1107
1108
1109
        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1110
1111
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()
1112
1113
1114
1115

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
            self._set_default_args_v1(usage_context)
        else:
            self._set_default_args_v0(model_config)
1138

1139
        assert self.enable_chunked_prefill is not None
1140

1141
        cache_config = CacheConfig(
1142
            block_size=self.block_size,
1143
1144
1145
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1146
            is_attention_free=model_config.is_attention_free,
1147
1148
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1149
            enable_prefix_caching=self.enable_prefix_caching,
1150
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1151
            cpu_offload_gb=self.cpu_offload_gb,
1152
            calculate_kv_scales=self.calculate_kv_scales,
1153
        )
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1166
        parallel_config = ParallelConfig(
1167
1168
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1169
            data_parallel_size=self.data_parallel_size,
1170
            enable_expert_parallel=self.enable_expert_parallel,
1171
1172
1173
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1174
            placement_group=placement_group,
1175
1176
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1177
            worker_extension_cls=self.worker_extension_cls,
1178
        )
1179

1180
        speculative_config = self.create_speculative_config(
1181
1182
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1183
            enable_chunked_prefill=self.enable_chunked_prefill,
王敏's avatar
王敏 committed
1184
            disable_log_stats=self.disable_log_stats,
1185
1186
        )

1187
        # Reminder: Please update docs/source/features/compatibility_matrix.md
1188
        # If the feature combo become valid
1189
1190
1191
1192
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1193
1194
1195
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1196
1197
1198
1199
1200
1201
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1202
1203
1204
1205
1206
1207
1208
1209
1210

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1211
        scheduler_config = SchedulerConfig(
1212
            runner_type=model_config.runner_type,
1213
1214
1215
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1216
            num_lookahead_slots=num_lookahead_slots,
1217
1218
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1219
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1220
            is_multimodal_model=model_config.is_multimodal_model,
1221
            preemption_mode=self.preemption_mode,
1222
            num_scheduler_steps=self.num_scheduler_steps,
1223
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1224
1225
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1226
            policy=self.scheduling_policy,
1227
            scheduler_cls=self.scheduler_cls,
1228
1229
1230
1231
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
        )
1232

1233
        lora_config = LoRAConfig(
1234
            bias_enabled=self.enable_lora_bias,
1235
1236
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1237
            fully_sharded_loras=self.fully_sharded_loras,
1238
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1239
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1240
1241
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
1242
1243
1244
            and self.max_cpu_loras > 0 else None,
            merge_lora=self.merge_lora,
            lora_target_modules=self.lora_target_modules) if self.enable_lora else None
1245

1246
1247
1248
1249
1250
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

1251
1252
1253
1254
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1255
        load_config = self.create_load_config()
1256

1257
1258
1259
1260
1261
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1262
        decoding_config = DecodingConfig(
1263
1264
1265
1266
            guided_decoding_backend=self.guided_decoding_backend,
            reasoning_backend=self.reasoning_parser
            if self.enable_reasoning else None,
        )
1267

1268
1269
1270
1271
        show_hidden_metrics = False
        if self.show_hidden_metrics_for_version is not None:
            show_hidden_metrics = version._prev_minor_version_was(
                self.show_hidden_metrics_for_version)
1272

1273
1274
1275
1276
1277
1278
1279
1280
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
1281
        observability_config = ObservabilityConfig(
1282
            show_hidden_metrics=show_hidden_metrics,
1283
1284
1285
1286
1287
1288
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )
1289

1290
        config = VllmConfig(
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1301
            prompt_adapter_config=prompt_adapter_config,
1302
            compilation_config=self.compilation_config,
1303
            kv_transfer_config=self.kv_transfer_config,
1304
            additional_config=self.additional_config,
1305
        )
1306

1307
1308
        return config

    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        Walks through the engine arguments and ``model_config`` looking for
        features the V1 engine does not support. Unsupported features are
        reported through ``_raise_or_fallback`` (which raises when
        VLLM_USE_V1=1 is set explicitly, and otherwise logs a warning);
        experimental features go through ``_warn_or_fallback``.

        Args:
            model_config: The resolved model configuration.

        Returns:
            True if the V1 engine should be used, False to fall back to V0.

        Raises:
            NotImplementedError: If VLLM_USE_V1=1 is set explicitly and an
                unsupported feature is requested.
        """

        #############################################################
        # Unsupported Feature Flags on V1.

        if (self.load_format == LoadFormat.TENSORIZER.value
                or self.load_format == LoadFormat.SHARDED_STATE.value):
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

        if self.preemption_mode != SchedulerConfig.preemption_mode:
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

        if self.scheduling_policy != SchedulerConfig.policy:
            _raise_or_fallback(feature_name="--scheduling-policy",
                               recommend_to_remove=False)
            return False

        if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
            _raise_or_fallback(feature_name="--num-scheduler-steps",
                               recommend_to_remove=True)
            return False

        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

        # Strip backend options (everything after ':') before the check.
        if (self.guided_decoding_backend.split(':')[0]
                not in get_args(GuidedDecodingBackendV1)):
            _raise_or_fallback(
                feature_name=
                f"--guided-decoding-backend={self.guided_decoding_backend}",
                recommend_to_remove=False)
            return False

        # Need at least Ampere for now (FA support required).
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
        from vllm.platforms import current_platform
        if (current_platform.is_cuda()
                and current_platform.get_device_capability()
                and current_platform.get_device_capability().major < 8):
            _raise_or_fallback(feature_name="Compute Capability < 8.0",
                               recommend_to_remove=False)
            return False

        # Non-default KV cache dtypes: FP8 is accepted only when
        # FlashAttention will be used and reports FP8 support; INT8 is
        # always accepted. Anything else falls back.
        if self.kv_cache_dtype != "auto":
            fp8_attention = self.kv_cache_dtype.startswith("fp8")
            will_use_fa = (
                current_platform.is_cuda()
                and not envs.is_set("VLLM_ATTENTION_BACKEND")
            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
            supported = False
            if fp8_attention and will_use_fa:
                from vllm.attention.utils.fa_utils import (
                    flash_attn_supports_fp8)
                supported = flash_attn_supports_fp8()

            int8_attention = self.kv_cache_dtype.startswith("int8")
            if int8_attention:
                supported = True

            if not supported:
                _raise_or_fallback(feature_name="--kv-cache-dtype",
                                   recommend_to_remove=False)
                return False

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

        # Only Fp16 and Bf16 dtypes since we only support FA.
        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
        if model_config.dtype not in V1_SUPPORTED_DTYPES:
            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
                               recommend_to_remove=False)
            return False

        # Some quantization is not compatible with torch.compile.
        V1_UNSUPPORTED_QUANT = ["gguf"]
        if model_config.quantization in V1_UNSUPPORTED_QUANT:
            _raise_or_fallback(
                feature_name=f"--quantization {model_config.quantization}",
                recommend_to_remove=False)
            return False

        # No Embedding Models so far.
        if model_config.task not in ["generate"]:
            _raise_or_fallback(feature_name=f"--task {model_config.task}",
                               recommend_to_remove=False)
            return False

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            _raise_or_fallback(feature_name=model_config.architectures,
                               recommend_to_remove=False)
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
                != SchedulerConfig.max_num_partial_prefills
                or self.max_long_partial_prefills
                != SchedulerConfig.max_long_partial_prefills):
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # No OTLP observability so far.
        if (self.otlp_traces_endpoint or self.collect_detailed_traces):
            _raise_or_fallback(feature_name="--otlp-traces-endpoint",
                               recommend_to_remove=False)
            return False

        # Only ngram and EAGLE speculative decoding so far.
        is_ngram_enabled = False
        is_eagle_enabled = False
        if self.speculative_config is not None:
            # This is supported but experimental (handled below).
            speculative_method = self.speculative_config.get("method")
            if speculative_method:
                if speculative_method in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
                elif speculative_method in ("eagle", "eagle3"):
                    is_eagle_enabled = True
            else:
                speculative_model = self.speculative_config.get("model")
                if speculative_model in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
            if not (is_ngram_enabled or is_eagle_enabled):
                # Other speculative decoding methods are not supported yet.
                _raise_or_fallback(feature_name="Speculative Decoding",
                                   recommend_to_remove=False)
                return False

        # No XFormers so far.
        V1_BACKENDS = [
            "FLASH_ATTN_VLLM_V1",
            "FLASH_ATTN",
            "PALLAS",
            "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1",
            "TRITON_MLA",
            "FLASHMLA",
            "FLASHINFER",
            "FLASHINFER_VLLM_V1",
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        # Platforms must decide if they can support v1 for this model
        if not current_platform.supports_v1(model_config=model_config):
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

        # Signal Handlers requires running in main thread.
        if (threading.current_thread() != threading.main_thread()
                and _warn_or_fallback("Engine in background thread")):
            return False

        # PP is supported on V1 with Ray distributed executor,
        # but off for MP distributed executor for now.
        if (self.pipeline_parallel_size > 1
                and self.distributed_executor_backend != "ray"):
            name = "Pipeline Parallelism without Ray distributed executor"
            _raise_or_fallback(feature_name=name, recommend_to_remove=False)
            return False

        # ngram is supported on V1, but off by default for now.
        if is_ngram_enabled and _warn_or_fallback("ngram"):
            return False

        # Eagle is under development, so we don't support it yet.
        if is_eagle_enabled and _warn_or_fallback("Eagle"):
            return False

        # Non-CUDA is supported on V1, but off by default for now.
        not_cuda = not current_platform.is_cuda()
        if not_cuda and _warn_or_fallback(  # noqa: SIM103
                current_platform.device_name):
            return False
        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        Resolves ``enable_chunked_prefill`` when it was left unset, validates
        the prefix-caching settings and defaults ``max_num_seqs`` to 256.

        Args:
            model_config: The resolved model configuration.

        Raises:
            ValueError: If chunked prefill ends up enabled for a pooling
                model, or if prefix caching is requested together with the
                ``sha256`` hash algorithm.
        """

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                from vllm.platforms import current_platform
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_config is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # None of the branches above enabled it: default to off.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            # Note the trailing space after "cause": the literal pieces are
            # concatenated, so it is needed to keep the words apart.
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # Validate prefix-caching settings for V0.
        if self.enable_prefix_caching:
            # Disable prefix caching for multimodal models for VLLM_V0.
            if model_config.is_multimodal_model:
                logger.warning(
                    "--enable-prefix-caching is not supported for multimodal "
                    "models in V0 and has been disabled.")
                self.enable_prefix_caching = False

            # VLLM_V0 only supports builtin hash algo for prefix caching.
            if self.prefix_caching_hash_algo == "sha256":
                raise ValueError(
                    "sha256 is not supported for prefix caching in V0 engine. "
                    "Please use 'builtin'.")

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

    def _set_default_args_v1(self, usage_context: UsageContext) -> None:
        """Set Default Arguments for V1 Engine.

        Forces chunked prefill on, enables prefix caching unless the user
        chose otherwise, and swaps in the V1 scheduler when the V0 default
        scheduler class is still configured. If the user did not set
        ``max_num_batched_tokens`` or ``max_num_seqs``, they get defaults
        based on the usage context and on the device's total memory.

        Args:
            usage_context: The usage context; may be falsy, in which case
                the context name logged is ``None``.
        """

        # V1 always uses chunked prefills.
        self.enable_chunked_prefill = True

        # V1 enables prefix caching by default.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = True

        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # When no user override, set the default values based on the usage
        # context.
        # Use different default values for different hardware.

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
            from vllm.platforms import current_platform
            device_memory = current_platform.get_device_total_memory()
        except Exception:
            # Only used below to pick the defaults for
            # max_num_batched_tokens and max_num_seqs.
            device_memory = 0

        if device_memory >= 70 * GiB_bytes:
            # For GPUs like H100 and MI300x, use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
            default_max_num_seqs = 1024
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            default_max_num_seqs = 256

        use_context_value = usage_context.value if usage_context else None
        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
            self.max_num_batched_tokens = default_max_num_batched_tokens[
                usage_context]
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, use_context_value)

        if self.max_num_seqs is None:
            self.max_num_seqs = default_max_num_seqs

            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, use_context_value)


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI arguments on ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: If True, only the async-specific arguments are
                added; otherwise ``EngineArgs.add_cli_args`` is applied
                first.

        Returns:
            The parser with the arguments registered.
        """
        # Initialize plugins so they can update the parser. For example, a
        # plugin may add a new quantization method to the --quantization
        # argument or a new device to the --device argument.
        load_general_plugins()
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        # Imported lazily, at call time rather than module import time.
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update(parser)
        return parser


def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Report a feature the V1 engine cannot handle.

    When VLLM_USE_V1 is explicitly set to a truthy value the user has
    forced V1, so a ``NotImplementedError`` is raised. Otherwise a warning
    announces the fallback to V0, optionally advising the user to drop the
    offending feature.

    Args:
        feature_name: Human-readable name of the unsupported feature.
        recommend_to_remove: Whether to append a recommendation to remove
            the feature in favor of the V1 engine.

    Raises:
        NotImplementedError: If VLLM_USE_V1 is explicitly enabled.
    """
    v1_forced = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_forced:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")

    message = (f"{feature_name} is not supported by the V1 Engine. "
               "Falling back to V0. ")
    if recommend_to_remove:
        message += (f"We recommend to remove {feature_name} "
                    "from your config in favor of the V1 Engine.")
    logger.warning(message)


def _warn_or_fallback(feature_name: str) -> bool:
    """Decide whether an experimental feature should trigger a V0 fallback.

    Args:
        feature_name: Human-readable name of the experimental feature.

    Returns:
        True if the caller should fall back to V0. In that case
        VLLM_USE_V1 was not explicitly enabled, and an info message is
        logged. False if VLLM_USE_V1 is explicitly enabled; the feature
        stays on V1 and a warning is logged.
    """
    if not (envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1):
        logger.info(
            "%s is experimental on VLLM_USE_V1=1. "
            "Falling back to V0 Engine.", feature_name)
        return True

    logger.warning(
        "Detected VLLM_USE_V1=1 with %s. Usage should "
        "be considered experimental. Please report any "
        "issues on Github.", feature_name)
    return False


def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.

    Lower-case suffixes are decimal multipliers (powers of 1000) and accept
    a fractional number. Upper-case suffixes are binary multipliers (powers
    of 1024) and accept whole numbers only. Fractional results are
    truncated toward zero.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '1.001k' -> 1,001
    - '1t' -> 1,000,000,000,000
    - '1T' -> 1,099,511,627,776

    Raises:
        argparse.ArgumentTypeError: If a fractional number is combined with
            a binary suffix (e.g. '1.5K').
        ValueError: If ``value`` is neither a plain integer nor a number
            with a supported suffix.
    """
    from fractions import Fraction

    value = value.strip()
    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
    if match:
        decimal_multiplier = {
            'k': 10**3,
            'm': 10**6,
            'g': 10**9,
            't': 10**12,
        }
        binary_multiplier = {
            'K': 2**10,
            'M': 2**20,
            'G': 2**30,
            'T': 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            # Exact rational arithmetic. With floats, float('1.001') * 1000
            # evaluates to 1000.9999999999999, which int() then truncates
            # to the wrong integer (1000).
            return int(Fraction(number) * decimal_multiplier[suffix])

        # Every suffix the regex accepts is in one of the two tables.
        mult = binary_multiplier[suffix]
        # Do not allow decimals with binary multipliers
        try:
            return int(number) * mult
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                "Decimals are not allowed "
                f"with binary suffixes like {suffix}. Did you mean to use "
                f"{number}{suffix.lower()} instead?") from e

    # Regular plain number.
    return int(value)


# These functions are used by Sphinx to build the documentation.
def _engine_args_parser():
    """Return a parser populated with all ``EngineArgs`` CLI arguments."""
    return EngineArgs.add_cli_args(FlexibleArgumentParser())


def _async_engine_args_parser():
    """Return a parser populated with only the async-engine CLI arguments."""
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                        async_args_only=True)