arg_utils.py 81.4 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import argparse
4
import dataclasses
5
import json
6
import re
7
import threading
8
from dataclasses import MISSING, dataclass, fields
9
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
10
                    Tuple, Type, Union, cast, get_args, get_origin)
11

12
13
import torch

14
import vllm.envs as envs
15
from vllm import version
16
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
17
18
                         DecodingConfig, DeviceConfig, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
19
20
21
                         ModelConfig, ModelImpl, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
                         SchedulerConfig, SpeculativeConfig, TaskOption,
22
                         TokenizerPoolConfig, VllmConfig, get_attr_docs)
23
from vllm.executor.executor_base import ExecutorBase
24
from vllm.logger import init_logger
25
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
26
from vllm.plugins import load_general_plugins
27
from vllm.reasoning import ReasoningParserManager
28
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
29
from vllm.transformers_utils.utils import check_gguf_file
30
from vllm.usage.usage_lib import UsageContext
31
from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
32

33
if TYPE_CHECKING:
34
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
35

36
37
# Module-level logger for this file.
logger = init_logger(__name__)

# Module names accepted for detailed tracing.
# NOTE(review): presumably the valid values for `collect_detailed_traces`;
# confirm against where this constant is consumed.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# Supported device identifiers.
# NOTE(review): presumably the choices for the `device` engine argument;
# confirm against where this constant is consumed.
DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "tpu",
    "xpu",
    "hpu",
]


def nullable_str(val: str):
    """Convert a CLI string to ``None`` when it is empty or ``"None"``.

    Used as an argparse ``type=`` converter so that options can be
    explicitly unset; any other value is returned unchanged.
    """
    return val if val and val != "None" else None


def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dict.

    Both keys and values are stripped and lowercased before use, and each
    value must parse as an ``int``. Repeating a key with the same value is
    allowed; repeating it with a different value is an error.

    Args:
        val: String value to be parsed, e.g. ``"image=16,video=2"``.

    Returns:
        Dictionary mapping each key to its integer value, or ``None`` if
        ``val`` is empty.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, a value is not an integer, or a key is given
            conflicting values.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Identical duplicates are tolerated; only conflicting ones fail.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
93
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
94
    """Arguments for vLLM engine."""
95
    model: str = 'facebook/opt-125m'
96
    served_model_name: Optional[Union[str, List[str]]] = None
97
    tokenizer: Optional[str] = None
98
    hf_config_path: Optional[str] = None
99
    task: TaskOption = "auto"
100
    skip_tokenizer_init: bool = False
101
    tokenizer_mode: str = 'auto'
102
    trust_remote_code: bool = False
103
    allowed_local_media_path: str = ""
104
105
    download_dir: Optional[str] = LoadConfig.download_dir
    load_format: str = LoadConfig.load_format
106
    config_format: ConfigFormat = ConfigFormat.AUTO
107
    dtype: str = 'auto'
108
    kv_cache_dtype: str = 'auto'
109
    seed: Optional[int] = None
110
    max_model_len: Optional[int] = None
111
112
113
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
114
115
    distributed_executor_backend: Optional[Union[
        str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
116
    # number of P/D disaggregation (or other disaggregation) workers
117
118
119
120
121
122
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
    data_parallel_size: int = ParallelConfig.data_parallel_size
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
    max_parallel_loading_workers: Optional[
        int] = ParallelConfig.max_parallel_loading_workers
123
    block_size: Optional[int] = None
124
    enable_prefix_caching: Optional[bool] = None
125
    prefix_caching_hash_algo: str = "builtin"
126
    disable_sliding_window: bool = False
127
    disable_cascade_attn: bool = False
128
    use_v2_block_manager: bool = True
129
130
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
131
    gpu_memory_utilization: float = 0.90
132
    max_num_batched_tokens: Optional[int] = None
133
134
135
    max_num_partial_prefills: Optional[int] = 1
    max_long_partial_prefills: Optional[int] = 1
    long_prefill_token_threshold: Optional[int] = 0
136
    max_num_seqs: Optional[int] = None
137
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
138
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
139
    revision: Optional[str] = None
140
    code_revision: Optional[str] = None
141
    rope_scaling: Optional[Dict[str, Any]] = None
142
    rope_theta: Optional[float] = None
143
    hf_token: Optional[Union[bool, str]] = None
144
    hf_overrides: Optional[HfOverrides] = None
145
    tokenizer_revision: Optional[str] = None
146
    quantization: Optional[str] = None
147
    enforce_eager: Optional[bool] = None
148
    max_seq_len_to_capture: int = 8192
149
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
150
    tokenizer_pool_size: int = 0
151
152
153
154
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
155
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
156
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
157
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
158
    disable_mm_preprocessor_cache: bool = False
159
    enable_lora: bool = False
160
    enable_lora_bias: bool = False
161
162
    max_loras: int = 1
    max_lora_rank: int = 16
163
164
165
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
166
    fully_sharded_loras: bool = False
167
    lora_extra_vocab_size: int = 256
168
    long_lora_scaling_factors: Optional[Tuple[float]] = None
169
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
170
    max_cpu_loras: Optional[int] = None
171
172
    merge_lora: bool = False
    lora_target_modules: Optional[List[str]] = None
173
    device: str = 'auto'
174
    num_scheduler_steps: int = 1
175
    multi_step_stream_outputs: bool = True
176
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
177
    num_gpu_blocks_override: Optional[int] = None
178
    num_lookahead_slots: int = 0
179
180
181
182
    model_loader_extra_config: Optional[
        dict] = LoadConfig.model_loader_extra_config
    ignore_patterns: Optional[Union[str,
                                    List[str]]] = LoadConfig.ignore_patterns
183
    preemption_mode: Optional[str] = None
184

185
    scheduler_delay_factor: float = 0.0
186
    enable_chunked_prefill: Optional[bool] = None
187
    disable_chunked_mm_input: bool = False
188

189
    guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
190
    logits_processor_pattern: Optional[str] = None
191

192
    speculative_config: Optional[Dict[str, Any]] = None
zhuwenwen's avatar
zhuwenwen committed
193
    num_speculative_heads: Optional[int] = None
194

195
    qlora_adapter_name_or_path: Optional[str] = None
196
    show_hidden_metrics_for_version: Optional[str] = None
197
    otlp_traces_endpoint: Optional[str] = None
198
    collect_detailed_traces: Optional[str] = None
199
    disable_async_output_proc: bool = False
200
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
201
    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
202

203
    override_neuron_config: Optional[Dict[str, Any]] = None
204
    override_pooler_config: Optional[PoolerConfig] = None
205
    compilation_config: Optional[CompilationConfig] = None
206
207
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
208

209
    kv_transfer_config: Optional[KVTransferConfig] = None
210

211
    generation_config: Optional[str] = "auto"
212
    override_generation_config: Optional[Dict[str, Any]] = None
213
    enable_sleep_mode: bool = False
214
    model_impl: str = "auto"
215

216
    calculate_kv_scales: Optional[bool] = None
217

218
    additional_config: Optional[Dict[str, Any]] = None
219
220
    enable_reasoning: Optional[bool] = None
    reasoning_parser: Optional[str] = None
221
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
王敏's avatar
王敏 committed
222

王敏's avatar
王敏 committed
223

224
    def __post_init__(self):
        """Fill in derived defaults after dataclass initialization.

        - Falls back to ``self.model`` when no tokenizer was given.
        - Converts an ``int`` or ``dict`` ``compilation_config`` into a
          ``CompilationConfig`` via ``CompilationConfig.from_cli``.
        - Loads the general plugins.
        """
        if not self.tokenizer:
            self.tokenizer = self.model

        # Support `EngineArgs(compilation_config={...})` (or an int)
        # without having to manually construct a CompilationConfig object.
        # NOTE(review): the value is round-tripped through `str()`, so this
        # relies on `from_cli` accepting the Python repr of an int/dict —
        # confirm against CompilationConfig.from_cli.
        if isinstance(self.compilation_config, (int, dict)):
            self.compilation_config = CompilationConfig.from_cli(
                str(self.compilation_config))

        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()

    @staticmethod
240
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
241
        """Shared CLI arguments for vLLM engine."""
242

243
244
245
246
        def is_type_in_union(cls: type[Any], type: type[Any]) -> bool:
            """Check if the class is a type in a union type."""
            return get_origin(cls) is Union and type in get_args(cls)

247
248
        def is_optional(cls: type[Any]) -> bool:
            """Check if the class is an optional type."""
249
            return is_type_in_union(cls, type(None))
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267

        def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
            cls_docs = get_attr_docs(cls)
            kwargs = {}
            for field in fields(cls):
                name = field.name
                # One of these will always be present
                default = (field.default_factory
                           if field.default is MISSING else field.default)
                kwargs[name] = {"default": default, "help": cls_docs[name]}
                # When using action="store_true"
                # add_argument doesn't accept type
                if field.type is bool:
                    continue
                # Handle optional fields
                if is_optional(field.type):
                    kwargs[name]["type"] = nullable_str
                    continue
268
269
270
271
                # Handle str in union fields
                if is_type_in_union(field.type, str):
                    kwargs[name]["type"] = str
                    continue
272
273
274
                kwargs[name]["type"] = field.type
            return kwargs

275
        # Model arguments
276
277
278
        parser.add_argument(
            '--model',
            type=str,
279
            default=EngineArgs.model,
280
            help='Name or path of the huggingface model to use.')
281
282
283
284
285
286
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
287
            'multiple tasks. When the model only supports one task, ``"auto"`` '
288
289
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
290
291
        parser.add_argument(
            '--tokenizer',
292
            type=nullable_str,
293
            default=EngineArgs.tokenizer,
294
295
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
296
297
298
299
300
301
        parser.add_argument(
            "--hf-config-path",
            type=nullable_str,
            default=EngineArgs.hf_config_path,
            help='Name or path of the huggingface config to use. '
            'If unspecified, model name or path will be used.')
302
303
304
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
305
306
307
            help='Skip initialization of tokenizer and detokenizer. '
            'Expects valid prompt_token_ids and None for prompt from '
            'the input. The generated output will contain token ids.')
Jasmond L's avatar
Jasmond L committed
308
309
        parser.add_argument(
            '--revision',
310
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
311
            default=None,
312
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
313
314
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
315
316
        parser.add_argument(
            '--code-revision',
317
            type=nullable_str,
318
            default=None,
319
            help='The specific revision to use for the model code on '
320
321
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
322
323
        parser.add_argument(
            '--tokenizer-revision',
324
            type=nullable_str,
325
            default=None,
326
327
328
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
329
330
331
332
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
333
            choices=['auto', 'slow', 'mistral', 'custom'],
334
335
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
336
            'always use the slow tokenizer. \n* '
337
338
339
            '"mistral" will always use the `mistral_common` tokenizer. \n* '
            '"custom" will use --tokenizer to select the '
            'preregistered tokenizer.')
340
341
        parser.add_argument('--trust-remote-code',
                            action='store_true',
342
                            help='Trust remote code from huggingface.')
343
344
345
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
346
347
348
349
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
        load_group.add_argument('--load-format',
                                choices=[f.value for f in LoadFormat],
                                **load_kwargs["load_format"])
        load_group.add_argument('--download-dir',
                                **load_kwargs["download_dir"])
        load_group.add_argument('--model-loader-extra-config',
                                **load_kwargs["model_loader_extra_config"])
        load_group.add_argument('--use-tqdm-on-load',
                                action=argparse.BooleanOptionalAction,
                                **load_kwargs["use_tqdm_on_load"])

367
368
369
370
371
372
373
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
374
375
376
377
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
378
379
380
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
381
382
383
384
385
386
387
388
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
389
390
391
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
392
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
393
            default=EngineArgs.kv_cache_dtype,
394
            help='Data type for kv cache storage. If "auto", will use model '
395
396
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
397
        parser.add_argument('--max-model-len',
398
                            type=human_readable_int,
399
                            default=EngineArgs.max_model_len,
400
                            help='Model context length. If unspecified, will '
401
402
403
404
405
                            'be automatically derived from the model config. '
                            'Supports k/m/g/K/M/G in human-readable format.\n'
                            'Examples:\n'
                            '- 1k → 1000\n'
                            '- 1K → 1024\n')
406
407
408
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
409
            default=DecodingConfig.guided_decoding_backend,
410
            help='Which engine will be used for guided decoding'
411
            ' (JSON schema / regex etc) by default. Currently support '
412
413
414
            'https://github.com/mlc-ai/xgrammar and '
            'https://github.com/guidance-ai/llguidance.'
            'Valid backend values are "xgrammar", "guidance", and "auto". '
415
            'With "auto", we will make opinionated choices based on request '
416
            'contents and what the backend libraries currently support, so '
417
            'the behavior is subject to change in each release.')
418
419
420
421
422
423
424
425
        parser.add_argument(
            '--logits-processor-pattern',
            type=nullable_str,
            default=None,
            help='Optional regex pattern specifying valid logits processor '
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
426
427
428
429
430
431
432
433
434
435
436
437
        parser.add_argument(
            '--model-impl',
            type=str,
            default=EngineArgs.model_impl,
            choices=[f.value for f in ModelImpl],
            help='Which implementation of the model to use.\n\n'
            '* "auto" will try to use the vLLM implementation if it exists '
            'and fall back to the Transformers implementation if no vLLM '
            'implementation is available.\n'
            '* "vllm" will use the vLLM model implementation.\n'
            '* "transformers" will use the Transformers model '
            'implementation.\n')
438
        # Parallel arguments
439
440
441
442
443
444
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
445
            '--distributed-executor-backend',
446
            choices=['ray', 'mp', 'uni', 'external_launcher'],
447
448
449
450
451
452
453
454
455
            **parallel_kwargs["distributed_executor_backend"])
        parallel_group.add_argument(
            '--pipeline-parallel-size', '-pp',
            **parallel_kwargs["pipeline_parallel_size"])
        parallel_group.add_argument('--tensor-parallel-size', '-tp',
                                    **parallel_kwargs["tensor_parallel_size"])
        parallel_group.add_argument('--data-parallel-size', '-dp',
                                    **parallel_kwargs["data_parallel_size"])
        parallel_group.add_argument(
456
457
            '--enable-expert-parallel',
            action='store_true',
458
459
            **parallel_kwargs["enable_expert_parallel"])
        parallel_group.add_argument(
460
            '--max-parallel-loading-workers',
461
462
            **parallel_kwargs["max_parallel_loading_workers"])
        parallel_group.add_argument(
463
464
            '--ray-workers-use-nsight',
            action='store_true',
465
466
467
468
469
            **parallel_kwargs["ray_workers_use_nsight"])
        parallel_group.add_argument(
            '--disable-custom-all-reduce',
            action='store_true',
            **parallel_kwargs["disable_custom_all_reduce"])
470
        # KV cache arguments
471
472
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
473
                            default=EngineArgs.block_size,
474
                            choices=[8, 16, 32, 64, 128],
475
                            help='Token block size for contiguous chunks of '
476
                            'tokens. This is ignored on neuron devices and '
477
                            'set to ``--max-model-len``. On CUDA devices, '
478
479
                            'only block sizes up to 32 are supported. '
                            'On HPU devices, block size defaults to 128.')
480

481
482
483
484
485
        parser.add_argument(
            "--enable-prefix-caching",
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.enable_prefix_caching,
            help="Enables automatic prefix caching. "
486
            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
487
        )
488
489
490
491
492
493
494
        parser.add_argument(
            "--prefix-caching-hash-algo",
            type=str,
            choices=["builtin", "sha256"],
            default=EngineArgs.prefix_caching_hash_algo,
            help="Set the hash algorithm for prefix caching. "
            "Options are 'builtin' (Python's built-in hash) or 'sha256' "
495
            "(collision resistant but with certain overheads).",
496
        )
497
498
499
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
500
                            'capping to sliding window size.')
501
502
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
503
                            default=True,
504
505
506
507
508
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
509
510
511
512
513
514
515
516
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
517

518
519
520
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
521
                            help='Random seed for operations.')
522
        parser.add_argument('--swap-space',
523
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
524
                            default=EngineArgs.swap_space,
525
                            help='CPU swap space size (GiB) per GPU.')
526
527
528
529
530
531
532
533
534
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
535
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
536
            'which requires at least 26GB GPU memory. Note that this '
537
            'requires fast CPU-GPU interconnect, as part of the model is '
538
539
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
540
541
542
543
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
544
545
546
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
547
548
549
550
551
552
            'will use the default value of 0.9. This is a per-instance '
            'limit, and only applies to the current vLLM instance.'
            'It does not matter if you have another vLLM instance running '
            'on the same GPU. For example, if you have two vLLM instances '
            'running on the same GPU, you can set the GPU memory utilization '
            'to 0.5 for each instance.')
553
        parser.add_argument(
554
            '--num-gpu-blocks-override',
555
556
557
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
558
            ' of GPU blocks. Used for testing preemption.')
559
560
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
561
                            default=EngineArgs.max_num_batched_tokens,
562
563
                            help='Maximum number of batched tokens per '
                            'iteration.')
564
565
566
567
568
        parser.add_argument(
            "--max-num-partial-prefills",
            type=int,
            default=EngineArgs.max_num_partial_prefills,
            help="For chunked prefill, the max number of concurrent \
569
            partial prefills.")
570
571
572
573
574
575
576
577
        parser.add_argument(
            "--max-long-partial-prefills",
            type=int,
            default=EngineArgs.max_long_partial_prefills,
            help="For chunked prefill, the maximum number of prompts longer "
            "than --long-prefill-token-threshold that will be prefilled "
            "concurrently. Setting this less than --max-num-partial-prefills "
            "will allow shorter prompts to jump the queue in front of longer "
578
            "prompts in some cases, improving latency.")
579
580
581
582
583
        parser.add_argument(
            "--long-prefill-token-threshold",
            type=float,
            default=EngineArgs.long_prefill_token_threshold,
            help="For chunked prefill, a request is considered long if the "
584
            "prompt is longer than this number of tokens.")
585
586
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
587
                            default=EngineArgs.max_num_seqs,
588
                            help='Maximum number of sequences per iteration.')
589
590
591
592
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
593
594
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
595
596
        parser.add_argument('--disable-log-stats',
                            action='store_true',
597
                            help='Disable logging statistics.')
598
599
600
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
601
                            type=nullable_str,
602
                            choices=[*QUANTIZATION_METHODS, None],
603
                            default=EngineArgs.quantization,
604
605
606
607
608
609
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
610
611
612
613
614
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
615
            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
616
617
618
619
620
621
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
622
623
624
625
626
627
628
629
630
631
        parser.add_argument(
            '--hf-token',
            type=str,
            nargs='?',
            const=True,
            default=None,
            help='The token to use as HTTP bearer authorization'
            ' for remote files. If `True`, will use the token '
            'generated when running `huggingface-cli login` '
            '(stored in `~/.huggingface`).')
632
633
634
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
635
                            help='Extra arguments for the HuggingFace config. '
636
637
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
638
        parser.add_argument('--enforce-eager',
zhuwenwen's avatar
zhuwenwen committed
639
                            action='store_true',
640
641
642
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
643
        parser.add_argument('--max-seq-len-to-capture',
644
645
646
647
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
648
649
650
651
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
652
653
654
655
656
657
658
659
660
661
662
663
664
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
665
                            type=nullable_str,
666
667
668
669
670
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
671
672
673
674
675
676
677

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
678
            # MultiModalConfig.get_default_limit_per_prompt
679
680
681
682
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
683
684
                  'images and 2 videos per prompt. Defaults to '
                  '1 (V0) or 999 (V1) for each modality.'))
685
686
687
688
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
689
            help=('Overrides for the multimodal input mapping/processing, '
690
                  'e.g., image processor. For example: ``{"num_crops": 4}``.'))
691
        parser.add_argument(
692
            '--disable-mm-preprocessor-cache',
693
            action='store_true',
694
695
            help='If true, then disables caching of the multi-modal '
            'preprocessor/mapper. (not recommended)')
696

697
698
699
700
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
701
702
703
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
704
705
706
707
708
709
710
711
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
712
713
714
715
716
717
718
719
        parser.add_argument('--merge-lora',
                            type=bool,
                            default=False,
                            help='If set to True, the weights of the base layer will be merged with the weights of Lora.')
        parser.add_argument('--lora-target-modules',
                            nargs='*',
                            default=None,
                            help='List of lora module name, If not specified, modules will be chosen according to the model architecture.')
720
721
722
723
724
725
726
727
728
729
730
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
731
            choices=['auto', 'float16', 'bfloat16'],
732
733
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
734
735
736
737
738
739
740
741
742
743
744
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
745
746
747
748
749
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
750
                  'Must be >= than max_loras.'))
751
752
753
754
755
756
757
758
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
759
760
761
762
763
764
765
766
767
768
769
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
770
771
772
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
773
                            choices=DEVICE_OPTIONS,
774
                            help='Device type for vLLM execution.')
775
776
777
778
779
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
780

781
782
        parser.add_argument(
            '--multi-step-stream-outputs',
783
784
785
786
787
788
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
789
790
791
792
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
793
            help='Apply a delay (of delay factor multiplied by previous '
794
            'prompt latency) before scheduling next prompt.')
795
796
        parser.add_argument(
            '--enable-chunked-prefill',
797
798
799
800
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
801
            help='If set, the prefill requests can be chunked based on the '
802
            'max_num_batched_tokens.')
803
        parser.add_argument('--speculative-config',
804
                            type=json.loads,
805
806
807
                            default=None,
                            help='The configurations for speculative decoding.'
                            ' Should be a JSON string.')
zhuwenwen's avatar
zhuwenwen committed
808
809
810
811
812
813
        parser.add_argument(
            '--num-speculative-heads',
            type=int,
            default=EngineArgs.num_speculative_heads,
            help='The number of speculative heads to sample from '
                 'the draft model in speculative decoding.')
814
815
816
817
818
819
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
820
            "Default to `original/**/*` to avoid repeated loading of llama's "
821
            "checkpoints.")
822
        parser.add_argument(
823
            '--preemption-mode',
824
825
            type=str,
            default=None,
826
827
828
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
829

830
831
832
833
834
835
836
837
838
839
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
840
            "same as the ``--model`` argument. Noted that this name(s) "
841
            "will also be used in `model_name` tag content of "
842
            "prometheus metrics, if multiple names provided, metrics "
843
            "tag will take the first one.")
844
845
846
847
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
848

849
850
851
852
853
854
855
856
857
858
859
860
        parser.add_argument('--show-hidden-metrics-for-version',
                            type=str,
                            default=None,
                            help='Enable deprecated Prometheus metrics that '
                            'have been hidden since the specified version. '
                            'For example, if a previously deprecated metric '
                            'has been hidden since the v0.7.0 release, you '
                            'use --show-hidden-metrics-for-version=0.7 as a '
                            'temporary escape hatch while you migrate to new '
                            'metrics. The metric is likely to be removed '
                            'completely in an upcoming release.')

861
862
863
864
865
        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
866
867
868
869
870
871
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
872
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
873
874
875
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
876

877
878
879
880
881
882
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
883

884
885
886
887
888
889
890
891
892
893
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

894
895
896
897
898
899
900
        parser.add_argument(
            '--scheduler-cls',
            default=EngineArgs.scheduler_cls,
            help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
            'is the default scheduler. Can be a class directly or the path to '
            'a class of form "mod.custom_class".')

901
902
        parser.add_argument(
            '--override-neuron-config',
903
            type=json.loads,
904
            default=None,
905
            help="Override or set neuron device configuration. "
906
            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
907
        parser.add_argument(
908
909
            '--override-pooler-config',
            type=PoolerConfig.from_json,
910
            default=None,
911
            help="Override or set the pooling method for pooling models. "
912
            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
913

914
915
916
917
918
919
920
921
922
923
924
925
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
926
927
928
929
                            'use a JSON string.\n'
                            'Following the convention of traditional '
                            'compilers, using -O without space is also '
                            'supported. -O3 is equivalent to -O 3.')
930

931
932
933
934
935
936
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

937
938
939
940
941
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')
942
943
944
945
946
947
948
        parser.add_argument(
            '--worker-extension-cls',
            type=str,
            default="",
            help='The worker extension class on top of the worker cls, '
            'it is useful if you just want to add new functions to the worker '
            'class without changing the existing functions.')
949
950
951
        parser.add_argument(
            "--generation-config",
            type=nullable_str,
952
            default="auto",
953
            help="The folder path to the generation config. "
954
955
956
957
958
            "Defaults to 'auto', the generation config will be loaded from "
            "model path. If set to 'vllm', no generation config is loaded, "
            "vLLM defaults will be used. If set to a folder path, the "
            "generation config will be loaded from the specified folder path. "
            "If `max_new_tokens` is specified in generation config, then "
959
960
961
962
963
964
965
966
967
968
969
970
            "it sets a server-wide limit on the number of output tokens "
            "for all requests.")

        parser.add_argument(
            "--override-generation-config",
            type=json.loads,
            default=None,
            help="Overrides or sets generation config in JSON format. "
            "e.g. ``{\"temperature\": 0.5}``. If used with "
            "--generation-config=auto, the override parameters will be merged "
            "with the default config from the model. If generation-config is "
            "None, only the override parameters are used.")
971

972
973
974
975
976
977
        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
                            default=False,
                            help="Enable sleep mode for the engine. "
                            "(only cuda platform is supported)")

978
979
980
981
982
983
984
985
        parser.add_argument(
            '--calculate-kv-scales',
            action='store_true',
            help='This enables dynamic calculation of '
            'k_scale and v_scale when kv-cache-dtype is fp8. '
            'If calculate-kv-scales is false, the scales will '
            'be loaded from the model checkpoint if available. '
            'Otherwise, the scales will default to 1.0.')
986

987
988
989
990
991
992
993
994
        parser.add_argument(
            "--additional-config",
            type=json.loads,
            default=None,
            help="Additional config for specified platform in JSON format. "
            "Different platforms may support different configs. Make sure the "
            "configs are valid for the platform you are using. The input format"
            " is like '{\"config_key\":\"config_value\"}'")
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006

        parser.add_argument(
            "--enable-reasoning",
            action="store_true",
            default=False,
            help="Whether to enable reasoning_content for the model. "
            "If enabled, the model will be able to generate reasoning content."
        )

        parser.add_argument(
            "--reasoning-parser",
            type=str,
1007
            choices=list(ReasoningParserManager.reasoning_parsers),
1008
1009
1010
1011
1012
1013
            default=None,
            help=
            "Select the reasoning parser depending on the model that you're "
            "using. This is used to parse the reasoning content into OpenAI "
            "API format. Required for ``--enable-reasoning``.")

1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
        parser.add_argument(
            "--disable-cascade-attn",
            action="store_true",
            default=False,
            help="Disable cascade attention for V1. While cascade attention "
            "does not change the mathematical correctness, disabling it "
            "could be useful for preventing potential numerical issues. "
            "Note that even if this is set to False, cascade attention will be "
            "only used when the heuristic tells that it's beneficial.")

1024
1025
1026
1027
1028
        parser.add_argument(
            "--disable-chunked-mm-input",
            action=StoreBoolean,
            default=EngineArgs.disable_chunked_mm_input,
            nargs="?",
1029
            const="True",
1030
1031
1032
1033
1034
1035
1036
1037
            help="Disable multimodal input chunking attention for V1. "
            "If set to true and chunked prefill is enabled, we do not want to"
            " partially schedule a multimodal item. This ensures that if a "
            "request has a mixed prompt (like text tokens TTTT followed by "
            "image tokens IIIIIIIIII) where only some image tokens can be "
            "scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled "
            "as TTTT in one step and IIIIIIIIII in the next.")

1038
        return parser
1039
1040

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of this dataclass from parsed CLI arguments.

        Every dataclass field of ``cls`` is looked up by name on ``args``;
        attributes of ``args`` that do not correspond to a field are ignored.

        Args:
            args: Namespace produced by the argument parser.

        Returns:
            A new instance of ``cls`` populated from ``args``.

        Raises:
            AttributeError: If ``args`` lacks an attribute for one of the
                dataclass fields.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` described by these engine arguments.

        Side effects: some fields on ``self`` may be rewritten before the
        config is built:

        * if ``check_gguf_file(self.model)`` is true, both
          ``self.quantization`` and ``self.load_format`` are set to
          ``"gguf"``;
        * when ``envs.VLLM_CI_USE_S3`` is set (and ``self`` is not an
          ``AsyncEngineArgs``), a model listed in ``MODELS_ON_S3`` whose
          ``load_format`` is ``LoadFormat.AUTO`` is redirected to
          ``MODEL_WEIGHTS_S3_BUCKET`` and loaded with
          ``LoadFormat.RUNAI_STREAMER``.

        Returns:
            A new ``ModelConfig`` built from the fields of ``self``.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
                and self.model in MODELS_ON_S3
                and self.load_format == LoadFormat.AUTO):
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
            self.load_format = LoadFormat.RUNAI_STREAMER

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            task=self.task,
            # We know this is not None because we set it in __post_init__
            tokenizer=cast(str, self.tokenizer),
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            # The CLI flag is phrased negatively; ModelConfig wants the
            # positive form.
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
        )

    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` used to load the model weights.

        Side effect: when ``self.quantization`` is ``"bitsandbytes"``,
        ``self.load_format`` is overwritten with ``"bitsandbytes"`` before the
        config is built.

        Returns:
            A new ``LoadConfig`` built from the fields of ``self``.

        Raises:
            ValueError: If ``qlora_adapter_name_or_path`` is set while
                ``quantization`` is anything other than ``"bitsandbytes"``.
        """
        if (self.qlora_adapter_name_or_path is not None
                and self.quantization != "bitsandbytes"):
            raise ValueError(
                "QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        # Force the load format to follow bitsandbytes quantization.
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
        )

    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
        num_speculative_heads: Optional[int] = None,
    ) -> Optional["SpeculativeConfig"]:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        This function utilizes `speculative_config` to create a
        SpeculativeConfig object. The `speculative_config` can either be
        provided as a JSON string input via CLI arguments (parsed into a
        dictionary by the argument parser) or directly as a dictionary from
        the engine.

        The user-supplied `self.speculative_config` mapping is NOT modified;
        the runtime-only parameters below are merged into a fresh dictionary.

        Args:
            target_model_config: Model config of the target model.
            target_parallel_config: Parallel config of the target model.
            enable_chunked_prefill: Whether chunked prefill is enabled.
            disable_log_stats: Whether stats logging is disabled.
            num_speculative_heads: Number of speculative heads to sample
                from; defaults to None so callers that predate this
                parameter keep working.

        Returns:
            None when `self.speculative_config` is None, otherwise the
            result of `SpeculativeConfig.from_dict` on the merged dictionary.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # Merge into a copy instead of calling `.update()` on the user's dict,
        # so EngineArgs state is not polluted with runtime config objects.
        spec_config_dict = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
            "num_speculative_heads": num_speculative_heads,
        }
        speculative_config = SpeculativeConfig.from_dict(spec_config_dict)

        return speculative_config

    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
    ) -> VllmConfig:
        """
        Create the VllmConfig.
1159

1160
1161
1162
        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
1163

1164
1165
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.
1166

1167
1168
1169
        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1170
1171
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()
1172
1173
1174
1175

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
            self._set_default_args_v1(usage_context)
        else:
            self._set_default_args_v0(model_config)
1198

1199
        assert self.enable_chunked_prefill is not None
1200

1201
        cache_config = CacheConfig(
1202
            block_size=self.block_size,
1203
1204
1205
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1206
            is_attention_free=model_config.is_attention_free,
1207
1208
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1209
            enable_prefix_caching=self.enable_prefix_caching,
1210
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1211
            cpu_offload_gb=self.cpu_offload_gb,
1212
            calculate_kv_scales=self.calculate_kv_scales,
1213
        )
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1226
        parallel_config = ParallelConfig(
1227
1228
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1229
            data_parallel_size=self.data_parallel_size,
1230
            enable_expert_parallel=self.enable_expert_parallel,
1231
1232
1233
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
1234
1235
1236
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
1237
            ),
1238
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1239
            placement_group=placement_group,
1240
1241
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1242
            worker_extension_cls=self.worker_extension_cls,
1243
        )
1244

1245
        speculative_config = self.create_speculative_config(
1246
1247
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1248
            enable_chunked_prefill=self.enable_chunked_prefill,
zhuwenwen's avatar
zhuwenwen committed
1249
            disable_log_stats=self.disable_log_stats,    
1250
            num_speculative_heads=self.num_speculative_heads
1251
1252
        )

1253
        # Reminder: Please update docs/source/features/compatibility_matrix.md
1254
        # If the feature combo become valid
1255
1256
1257
1258
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1259
1260
1261
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1262
1263
1264
1265
1266
1267
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1268
1269
1270
1271
1272
1273
1274
1275
1276

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1277
        scheduler_config = SchedulerConfig(
1278
            runner_type=model_config.runner_type,
1279
1280
1281
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1282
            num_lookahead_slots=num_lookahead_slots,
1283
1284
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1285
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1286
            is_multimodal_model=model_config.is_multimodal_model,
1287
            preemption_mode=self.preemption_mode,
1288
            num_scheduler_steps=self.num_scheduler_steps,
1289
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1290
1291
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1292
            policy=self.scheduling_policy,
1293
            scheduler_cls=self.scheduler_cls,
1294
1295
1296
1297
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
        )
1298

1299
        lora_config = LoRAConfig(
1300
            bias_enabled=self.enable_lora_bias,
1301
1302
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1303
            fully_sharded_loras=self.fully_sharded_loras,
1304
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1305
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1306
1307
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
1308
1309
1310
            and self.max_cpu_loras > 0 else None,
            merge_lora=self.merge_lora,
            lora_target_modules=self.lora_target_modules) if self.enable_lora else None
1311

1312
1313
1314
1315
1316
1317
1318
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

1319
1320
1321
1322
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1323
        load_config = self.create_load_config()
1324

1325
1326
1327
1328
1329
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1330
        decoding_config = DecodingConfig(
1331
1332
1333
1334
            guided_decoding_backend=self.guided_decoding_backend,
            reasoning_backend=self.reasoning_parser
            if self.enable_reasoning else None,
        )
1335

1336
1337
1338
1339
        show_hidden_metrics = False
        if self.show_hidden_metrics_for_version is not None:
            show_hidden_metrics = version._prev_minor_version_was(
                self.show_hidden_metrics_for_version)
1340

1341
1342
1343
1344
1345
1346
1347
1348
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
1349
        observability_config = ObservabilityConfig(
1350
            show_hidden_metrics=show_hidden_metrics,
1351
1352
1353
1354
1355
1356
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )
1357

1358
        config = VllmConfig(
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1369
            prompt_adapter_config=prompt_adapter_config,
1370
            compilation_config=self.compilation_config,
1371
            kv_transfer_config=self.kv_transfer_config,
1372
            additional_config=self.additional_config,
1373
        )
1374

1375
1376
        return config

1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        The checks below run in a fixed order and the first one that fails
        decides the outcome (and which message the user sees).

        Unsupported features go through ``_raise_or_fallback``. It raises
        ``NotImplementedError`` when ``VLLM_USE_V1`` is explicitly set and
        truthy; otherwise it logs a warning, and this method returns False.

        Experimental features go through ``_warn_or_fallback``. It only
        requests a fallback (returns True) when V1 was not explicitly
        requested.

        Args:
            model_config: The model config the engine will be built with.

        Returns:
            True if the V1 engine should be used, False to fall back to V0.

        Raises:
            NotImplementedError: Propagated from ``_raise_or_fallback`` when
                ``VLLM_USE_V1=1`` is combined with an unsupported feature.
        """

        #############################################################
        # Unsupported Feature Flags on V1.

        if (self.load_format == LoadFormat.TENSORIZER.value
                or self.load_format == LoadFormat.SHARDED_STATE.value):
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

        if self.preemption_mode != EngineArgs.preemption_mode:
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

        if self.scheduling_policy != EngineArgs.scheduling_policy:
            _raise_or_fallback(feature_name="--scheduling-policy",
                               recommend_to_remove=False)
            return False

        if self.num_scheduler_steps != EngineArgs.num_scheduler_steps:
            _raise_or_fallback(feature_name="--num-scheduler-steps",
                               recommend_to_remove=True)
            return False

        if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor:
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

        if self.additional_config != EngineArgs.additional_config:
            _raise_or_fallback(feature_name="--additional-config",
                               recommend_to_remove=False)
            return False

        # Xgrammar and Guidance are supported.
        SUPPORTED_GUIDED_DECODING = [
            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
            "guidance:disable-any-whitespace", "auto"
        ]
        if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
            _raise_or_fallback(feature_name="--guided-decoding-backend",
                               recommend_to_remove=False)
            return False

        # Need at least Ampere for now (FA support required).
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
        from vllm.platforms import current_platform
        if (current_platform.is_cuda()
                and current_platform.get_device_capability()
                and current_platform.get_device_capability().major < 8):
            _raise_or_fallback(feature_name="Compute Capability < 8.0",
                               recommend_to_remove=False)
            return False

        # No Fp8 KV cache so far.
        if self.kv_cache_dtype != "auto":
            # Exception: an fp8 KV cache is accepted when FlashAttention will
            # be used and flash_attn_supports_fp8() reports support.
            # FlashAttention is assumed on CUDA when no
            # VLLM_ATTENTION_BACKEND override is set, or when
            # FLASH_ATTN_VLLM_V1 is requested explicitly.
            fp8_attention = self.kv_cache_dtype.startswith("fp8")
            will_use_fa = (
                current_platform.is_cuda()
                and not envs.is_set("VLLM_ATTENTION_BACKEND")
            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
            supported = False
            if fp8_attention and will_use_fa:
                from vllm.vllm_flash_attn.fa_utils import (
                    flash_attn_supports_fp8)
                supported = flash_attn_supports_fp8()
            if not supported:
                _raise_or_fallback(feature_name="--kv-cache-dtype",
                                   recommend_to_remove=False)
                return False

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

        # Only Fp16 and Bf16 dtypes since we only support FA.
        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
        if model_config.dtype not in V1_SUPPORTED_DTYPES:
            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
                               recommend_to_remove=False)
            return False

        # Some quantization is not compatible with torch.compile.
        V1_UNSUPPORTED_QUANT = ["gguf"]
        if model_config.quantization in V1_UNSUPPORTED_QUANT:
            _raise_or_fallback(
                feature_name=f"--quantization {model_config.quantization}",
                recommend_to_remove=False)
            return False

        # No Embedding Models so far.
        if model_config.task not in ["generate"]:
            _raise_or_fallback(feature_name=f"--task {model_config.task}",
                               recommend_to_remove=False)
            return False

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            _raise_or_fallback(feature_name=model_config.architectures,
                               recommend_to_remove=False)
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
                != EngineArgs.max_num_partial_prefills
                or self.max_long_partial_prefills
                != EngineArgs.max_long_partial_prefills):
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # No OTLP observability so far.
        if (self.otlp_traces_endpoint or self.collect_detailed_traces):
            _raise_or_fallback(feature_name="--otlp-traces-endpoint",
                               recommend_to_remove=False)
            return False

        # Only ngram and EAGLE speculative decoding so far; both are
        # experimental and are gated further below.
        is_ngram_enabled = False
        is_eagle_enabled = False
        if self.speculative_config is not None:
            # An explicit "method" takes precedence; otherwise ngram can
            # also be selected through the "model" field.
            speculative_method = self.speculative_config.get("method")
            if speculative_method:
                if speculative_method in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
                elif speculative_method == "eagle":
                    is_eagle_enabled = True
            else:
                speculative_model = self.speculative_config.get("model")
                if speculative_model in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
            if not (is_ngram_enabled or is_eagle_enabled):
                # Other speculative decoding methods are not supported yet.
                _raise_or_fallback(feature_name="Speculative Decoding",
                                   recommend_to_remove=False)
                return False

        # No Disaggregated Prefill so far.
        if self.kv_transfer_config != EngineArgs.kv_transfer_config:
            _raise_or_fallback(feature_name="--kv-transfer-config",
                               recommend_to_remove=False)
            return False

        # No FlashInfer or XFormers so far.
        V1_BACKENDS = [
            "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        # Platforms must decide if they can support v1 for this model
        if not current_platform.supports_v1(model_config=model_config):
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

        # Signal Handlers requires running in main thread.
        if (threading.current_thread() != threading.main_thread()
                and _warn_or_fallback("Engine in background thread")):
            return False

        # PP is supported on V1 with Ray distributed executor,
        # but off for MP distributed executor for now.
        if (self.pipeline_parallel_size > 1
                and self.distributed_executor_backend != "ray"):
            name = "Pipeline Parallelism without Ray distributed executor"
            _raise_or_fallback(feature_name=name, recommend_to_remove=False)
            return False

        # ngram is supported on V1, but off by default for now.
        if is_ngram_enabled and _warn_or_fallback("ngram"):
            return False

        # Eagle is under development, so we don't support it yet.
        if is_eagle_enabled and _warn_or_fallback("Eagle"):
            return False

        # Non-CUDA is supported on V1, but off by default for now.
        not_cuda = not current_platform.is_cuda()
        if not_cuda and _warn_or_fallback(  # noqa: SIM103
                current_platform.device_name):
            return False
        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        This method does three things:

        - Resolves ``enable_chunked_prefill`` when the user left it unset.
          It is turned on for long-context (> 32K) CUDA models that use no
          feature known to conflict with it; otherwise it is turned off.
        - Applies the V0 prefix-caching restrictions.
        - Defaults ``max_num_seqs`` to 256.

        Args:
            model_config: The model config the engine will be built with.

        Raises:
            ValueError: If chunked prefill ends up enabled for a pooling
                model, or if prefix caching is enabled with the ``sha256``
                hash algorithm, which V0 does not support.
        """
        max_model_len = model_config.max_model_len
        # "Long context" means a max_model_len strictly above 32K tokens.
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                from vllm.platforms import current_platform
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_config is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # None of the rules above enabled it: default to disabled.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # if using prefix caching, we must set a hash algo
        if self.enable_prefix_caching:
            # Disable prefix caching for multimodal models for VLLM_V0.
            if model_config.is_multimodal_model:
                logger.warning(
                    "--enable-prefix-caching is not supported for multimodal "
                    "models in V0 and has been disabled.")
                self.enable_prefix_caching = False

            # VLLM_V0 only supports builtin hash algo for prefix caching.
            # NOTE: these checks still run when the multimodal branch above
            # has just disabled prefix caching.
            if self.prefix_caching_hash_algo is None:
                self.prefix_caching_hash_algo = "builtin"
            elif self.prefix_caching_hash_algo == "sha256":
                raise ValueError(
                    "sha256 is not supported for prefix caching in V0 engine. "
                    "Please use 'builtin'.")

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

    def _set_default_args_v1(self, usage_context: UsageContext) -> None:
        """Set Default Arguments for V1 Engine.

        This method:

        - Always enables chunked prefill.
        - Enables prefix caching unless the user explicitly set it, and
          picks the ``builtin`` hash algorithm when none was chosen.
        - Swaps in the V1 scheduler when the V0 default is still selected.
        - Fills in ``max_num_batched_tokens`` and ``max_num_seqs`` from
          hardware- and usage-context-dependent defaults when unset.

        Args:
            usage_context: The context the engine is created in. It selects
                the ``max_num_batched_tokens`` default; contexts without an
                entry in the table leave that value unset.
        """
        # V1 always uses chunked prefills.
        self.enable_chunked_prefill = True

        # V1 enables prefix caching by default.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = True

        # if using prefix caching, we must set a hash algo
        if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
            self.prefix_caching_hash_algo = "builtin"

        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # When no user override, set the default values based on the usage
        # context.
        # Use different default values for different hardware.

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
            from vllm.platforms import current_platform
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # device_name only selects the default max_num_batched_tokens
            # and max_num_seqs below, so a best-effort fallback is enough.
            device_name = "no-device"

        if "h100" in device_name or "h200" in device_name:
            # For H100 and H200, we use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
            default_max_num_seqs = 1024
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            default_max_num_seqs = 256

        use_context_value = usage_context.value if usage_context else None
        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
            self.max_num_batched_tokens = default_max_num_batched_tokens[
                usage_context]
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, use_context_value)

        if self.max_num_seqs is None:
            self.max_num_seqs = default_max_num_seqs

            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, use_context_value)
1721

1722

1723
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # Disable logging requests (CLI flag: --disable-log-requests).
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI arguments on ``parser``.

        Args:
            parser: The argument parser to extend.
            async_args_only: If True, skip ``EngineArgs.add_cli_args`` and
                register only the async-specific arguments.

        Returns:
            The parser with the arguments registered.
        """
        # Initialize plugins so they can update the parser; for example, a
        # plugin may add a new quantization method to the --quantization
        # argument or a new device to the --device argument.
        load_general_plugins()
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        # Presumably lets the active platform register or adjust its own
        # arguments last; see pre_register_and_update.
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update(parser)
        return parser
1743
1744


1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Reject ``feature_name`` for V1, or log that V0 will be used instead.

    If the user explicitly opted into V1 (``VLLM_USE_V1`` is set and truthy),
    raise ``NotImplementedError``. Otherwise log a warning that the engine
    falls back to V0, optionally recommending removal of the feature.
    """
    if not (envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1):
        pieces = [
            f"{feature_name} is not supported by the V1 Engine. ",
            "Falling back to V0. ",
        ]
        if recommend_to_remove:
            pieces += [
                f"We recommend to remove {feature_name} from your config ",
                "in favor of the V1 Engine.",
            ]
        logger.warning("".join(pieces))
        return
    raise NotImplementedError(
        f"VLLM_USE_V1=1 is not supported with {feature_name}.")


def _warn_or_fallback(feature_name: str) -> bool:
    """Decide whether an experimental V1 feature forces a fallback to V0.

    When ``VLLM_USE_V1`` is explicitly set and truthy, the feature is kept
    on V1 with an "experimental" warning. Otherwise an info message
    announces the fallback.

    Returns:
        True if the caller should fall back to V0, False to stay on V1.
    """
    if not (envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1):
        logger.info(
            "%s is experimental on VLLM_USE_V1=1. "
            "Falling back to V0 Engine.", feature_name)
        return True
    logger.warning(
        "Detected VLLM_USE_V1=1 with %s. Usage should "
        "be considered experimental. Please report any "
        "issues on Github.", feature_name)
    return False


1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
def human_readable_int(value: str) -> int:
    """Parse human-readable integers like '1k', '2M', etc.

    Lowercase suffixes (k, m, g, t) are decimal multipliers (powers of 1000)
    and accept a fractional number. Uppercase suffixes (K, M, G, T) are
    binary multipliers (powers of 1024) and require an integer number. A
    value without a suffix is parsed with ``int()``.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '2t' -> 2,000,000,000,000
    - '1T' -> 1,099,511,627,776

    Raises:
        argparse.ArgumentTypeError: If a fractional number is combined with
            a binary suffix (e.g. '1.5K').
        ValueError: If the value is neither a suffixed number nor a plain
            integer.
    """
    # Local import so the module-level import block stays unchanged.
    from decimal import Decimal

    value = value.strip()
    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
    if match:
        # Every suffix accepted by the regex must appear in one of these
        # tables; otherwise it would fall through to int(value) and fail.
        decimal_multiplier = {
            'k': 10**3,
            'm': 10**6,
            'g': 10**9,
            't': 10**12,
        }
        binary_multiplier = {
            'K': 2**10,
            'M': 2**20,
            'G': 2**30,
            'T': 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            # Use Decimal instead of float: float('1.001') * 1000 equals
            # 1000.9999999999999, which int() would truncate to 1000.
            return int(Decimal(number) * decimal_multiplier[suffix])

        # The regex guarantees that a non-decimal suffix is a binary one.
        mult = binary_multiplier[suffix]
        # Do not allow decimals with binary multipliers.
        try:
            return int(number) * mult
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                "Decimals are not allowed "
                f"with binary suffixes like {suffix}. Did you mean to use "
                f"{number}{suffix.lower()} instead?") from e

    # Regular plain number.
    return int(value)


1813
1814
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with all ``EngineArgs`` CLI options."""
    return EngineArgs.add_cli_args(FlexibleArgumentParser())
1816
1817
1818


def _async_engine_args_parser():
    """Return a fresh parser with only the async-engine CLI options.

    ``async_args_only=True`` skips the shared ``EngineArgs`` options, so
    only the arguments added by ``AsyncEngineArgs.add_cli_args`` itself
    are present.
    """
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                        async_args_only=True)