arg_utils.py 68 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import argparse
4
import dataclasses
5
import json
6
from dataclasses import dataclass
7
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
8
                    Tuple, Type, Union, cast, get_args)
9

10
11
import torch

12
import vllm.envs as envs
13
from vllm import version
14
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
15
16
                         DecodingConfig, DeviceConfig, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
17
18
19
20
                         ModelConfig, ModelImpl, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
                         SchedulerConfig, SpeculativeConfig, TaskOption,
                         TokenizerPoolConfig, VllmConfig)
21
from vllm.executor.executor_base import ExecutorBase
22
from vllm.logger import init_logger
23
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
24
from vllm.plugins import load_general_plugins
25
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
26
from vllm.transformers_utils.utils import check_gguf_file
27
from vllm.usage.usage_lib import UsageContext
28
from vllm.utils import FlexibleArgumentParser, StoreBoolean
29

30
if TYPE_CHECKING:
31
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
32

33
34
# Logger for this module, created via vLLM's init_logger with the module's
# ``__name__``.
logger = init_logger(__name__)

35
36
# Module names accepted for detailed trace collection (presumably used to
# validate ``collect_detailed_traces``; confirm at the usage site).
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# Candidate values for the ``device`` engine argument (presumably the
# argparse ``choices`` for ``--device`` further down this file).
DEVICE_OPTIONS = [
    "auto", "cuda", "neuron", "cpu", "openvino", "tpu", "xpu", "hpu",
]

48

49
50
51
52
53
54
def nullable_str(val: str):
    """Map empty/"None" strings to ``None``; pass anything else through.

    Used as an argparse ``type=`` converter so that options like
    ``--tokenizer ""`` or ``--tokenizer None`` mean "not specified".

    Args:
        val: Raw string taken from the command line.

    Returns:
        ``None`` when ``val`` is falsy or exactly ``"None"``, otherwise
        ``val`` unchanged.
    """
    is_null = (not val) or val == "None"
    return None if is_null else val


55
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dict.

    Used as an argparse ``type=`` converter (e.g. ``--limit-mm-per-prompt``).
    Keys are lower-cased, and whitespace around keys and values is
    stripped, so ``"Image = 16"`` yields ``{"image": 16}``. Repeating a key
    with the same value is allowed; repeating it with a different value is
    an error.

    Args:
        val: String value to be parsed, e.g. ``"image=16,video=2"``.

    Returns:
        Dictionary mapping each key to its integer value, or ``None`` if
        ``val`` is empty (or ``None``).

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE`` (including an empty key), a value is not an
            integer, or a key is given conflicting values.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        # An item like "=5" would otherwise register a meaningless empty key.
        if len(kv_parts) != 2 or not kv_parts[0]:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Duplicates are tolerated only when they agree.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


90
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
91
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
92
    """Arguments for vLLM engine."""
93
    model: str = 'facebook/opt-125m'
94
    served_model_name: Optional[Union[str, List[str]]] = None
95
    tokenizer: Optional[str] = None
96
    hf_config_path: Optional[str] = None
97
    task: TaskOption = "auto"
98
    skip_tokenizer_init: bool = False
99
    tokenizer_mode: str = 'auto'
100
    trust_remote_code: bool = False
101
    allowed_local_media_path: str = ""
102
    download_dir: Optional[str] = None
103
    load_format: str = 'auto'
104
    config_format: ConfigFormat = ConfigFormat.AUTO
105
    dtype: str = 'auto'
106
    kv_cache_dtype: str = 'auto'
107
    seed: int = 0
108
    max_model_len: Optional[int] = None
109
110
111
112
113
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
114
    # number of P/D disaggregation (or other disaggregation) workers
115
116
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
117
    max_parallel_loading_workers: Optional[int] = None
118
    block_size: Optional[int] = None
119
    enable_prefix_caching: Optional[bool] = None
120
    disable_sliding_window: bool = False
121
    use_v2_block_manager: bool = True
122
123
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
124
    gpu_memory_utilization: float = 0.90
125
    max_num_batched_tokens: Optional[int] = None
126
127
128
    max_num_partial_prefills: Optional[int] = 1
    max_long_partial_prefills: Optional[int] = 1
    long_prefill_token_threshold: Optional[int] = 0
129
    max_num_seqs: Optional[int] = None
130
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
131
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
132
    revision: Optional[str] = None
133
    code_revision: Optional[str] = None
134
    rope_scaling: Optional[Dict[str, Any]] = None
135
    rope_theta: Optional[float] = None
136
    hf_overrides: Optional[HfOverrides] = None
137
    tokenizer_revision: Optional[str] = None
138
    quantization: Optional[str] = None
139
    enforce_eager: Optional[bool] = None
140
    max_seq_len_to_capture: int = 8192
141
    disable_custom_all_reduce: bool = False
142
    tokenizer_pool_size: int = 0
143
144
145
146
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
147
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
148
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
149
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
150
    disable_mm_preprocessor_cache: bool = False
151
    enable_lora: bool = False
152
    enable_lora_bias: bool = False
153
154
    max_loras: int = 1
    max_lora_rank: int = 16
155
156
157
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
158
    fully_sharded_loras: bool = False
159
    lora_extra_vocab_size: int = 256
160
    long_lora_scaling_factors: Optional[Tuple[float]] = None
161
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
162
    max_cpu_loras: Optional[int] = None
163
    device: str = 'auto'
164
    num_scheduler_steps: int = 1
165
    multi_step_stream_outputs: bool = True
166
    ray_workers_use_nsight: bool = False
167
    num_gpu_blocks_override: Optional[int] = None
168
    num_lookahead_slots: int = 0
169
    model_loader_extra_config: Optional[dict] = None
170
    ignore_patterns: Optional[Union[str, List[str]]] = None
171
    preemption_mode: Optional[str] = None
172

173
    scheduler_delay_factor: float = 0.0
174
    enable_chunked_prefill: Optional[bool] = None
175

176
    guided_decoding_backend: str = 'xgrammar'
177
    logits_processor_pattern: Optional[str] = None
178
179
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
180
    speculative_model_quantization: Optional[str] = None
181
    speculative_draft_tensor_parallel_size: Optional[int] = None
182
    num_speculative_tokens: Optional[int] = None
183
    speculative_disable_mqa_scorer: Optional[bool] = False
184
    speculative_max_model_len: Optional[int] = None
185
    speculative_disable_by_batch_size: Optional[int] = None
186
187
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
188
189
190
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
191
    qlora_adapter_name_or_path: Optional[str] = None
192
    disable_logprobs_during_spec_decoding: Optional[bool] = None
193

194
    show_hidden_metrics_for_version: Optional[str] = None
195
    otlp_traces_endpoint: Optional[str] = None
196
    collect_detailed_traces: Optional[str] = None
197
    disable_async_output_proc: bool = False
198
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
199
    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
200

201
202
    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
203
    compilation_config: Optional[CompilationConfig] = None
204
    worker_cls: str = "auto"
205

206
207
    kv_transfer_config: Optional[KVTransferConfig] = None

208
    generation_config: Optional[str] = None
209
    override_generation_config: Optional[Dict[str, Any]] = None
210
    enable_sleep_mode: bool = False
211
    model_impl: str = "auto"
212

213
214
    calculate_kv_scales: Optional[bool] = None

215
    additional_config: Optional[Dict[str, Any]] = None
216
217
    enable_reasoning: Optional[bool] = None
    reasoning_parser: Optional[str] = None
218

219
    def __post_init__(self):
        """Resolve defaults the user left unset, then load general plugins.

        Called automatically by the dataclass-generated ``__init__``.
        """
        # No tokenizer given: fall back to the model name/path.
        self.tokenizer = self.tokenizer or self.model

        # Defaults that depend on ``envs.VLLM_USE_V1``; applied only when
        # the user left the field as ``None``.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = bool(envs.VLLM_USE_V1)
        if self.max_num_seqs is None:
            self.max_num_seqs = 1024 if envs.VLLM_USE_V1 else 256

        # Accept ``EngineArgs(compilation_config={...})`` (or an int) by
        # stringifying it and parsing with CompilationConfig.from_cli, so
        # callers need not build a CompilationConfig themselves.
        if isinstance(self.compilation_config, (int, dict)):
            as_text = str(self.compilation_config)
            self.compilation_config = CompilationConfig.from_cli(as_text)

        # Looked up at call time (the name is also imported at module top).
        from vllm.plugins import load_general_plugins
        load_general_plugins()
242
243

    @staticmethod
244
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
245
        """Shared CLI arguments for vLLM engine."""
246

247
        # Model arguments
248
249
250
        parser.add_argument(
            '--model',
            type=str,
251
            default=EngineArgs.model,
252
            help='Name or path of the huggingface model to use.')
253
254
255
256
257
258
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
259
            'multiple tasks. When the model only supports one task, ``"auto"`` '
260
261
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
262
263
        parser.add_argument(
            '--tokenizer',
264
            type=nullable_str,
265
            default=EngineArgs.tokenizer,
266
267
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
268
269
270
271
272
273
        parser.add_argument(
            "--hf-config-path",
            type=nullable_str,
            default=EngineArgs.hf_config_path,
            help='Name or path of the huggingface config to use. '
            'If unspecified, model name or path will be used.')
274
275
276
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
277
            help='Skip initialization of tokenizer and detokenizer.')
Jasmond L's avatar
Jasmond L committed
278
279
        parser.add_argument(
            '--revision',
280
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
281
            default=None,
282
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
283
284
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
285
286
        parser.add_argument(
            '--code-revision',
287
            type=nullable_str,
288
            default=None,
289
            help='The specific revision to use for the model code on '
290
291
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
292
293
        parser.add_argument(
            '--tokenizer-revision',
294
            type=nullable_str,
295
            default=None,
296
297
298
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
299
300
301
302
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
303
            choices=['auto', 'slow', 'mistral', 'custom'],
304
305
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
306
            'always use the slow tokenizer. \n* '
307
308
309
            '"mistral" will always use the `mistral_common` tokenizer. \n* '
            '"custom" will use --tokenizer to select the '
            'preregistered tokenizer.')
310
311
        parser.add_argument('--trust-remote-code',
                            action='store_true',
312
                            help='Trust remote code from huggingface.')
313
314
315
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
316
317
318
319
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
320
        parser.add_argument('--download-dir',
321
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
322
                            default=EngineArgs.download_dir,
323
                            help='Directory to download and load the weights, '
324
                            'default to the default cache dir of '
325
                            'huggingface.')
326
327
328
329
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
330
            choices=[f.value for f in LoadFormat],
331
332
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
333
            'and fall back to the pytorch bin format if safetensors format '
334
335
336
337
338
339
340
341
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
342
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
343
            'section for more information.\n'
344
345
            '* "runai_streamer" will load the Safetensors weights using Run:ai'
            'Model Streamer \n'
346
347
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
348
349
350
351
352
353
354
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
355
356
357
358
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
359
360
361
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
362
363
364
365
366
367
368
369
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
370
371
372
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
373
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
374
            default=EngineArgs.kv_cache_dtype,
375
            help='Data type for kv cache storage. If "auto", will use model '
376
377
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
378
379
        parser.add_argument('--max-model-len',
                            type=int,
380
                            default=EngineArgs.max_model_len,
381
382
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
383
384
385
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
386
            default='xgrammar',
387
            help='Which engine will be used for guided decoding'
388
            ' (JSON schema / regex etc) by default. Currently support '
389
            'https://github.com/outlines-dev/outlines, '
390
            'https://github.com/mlc-ai/xgrammar, and '
391
392
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
393
            ' parameter.\n'
394
            'Backend-specific options can be supplied in a comma-separated '
395
396
            'list following a colon after the backend name. Valid backends and '
            'all available options are: [xgrammar:no-fallback, '
397
            'xgrammar:disable-any-whitespace, '
398
            'outlines:no-fallback, lm-format-enforcer:no-fallback]')
399
400
401
402
403
404
405
406
        parser.add_argument(
            '--logits-processor-pattern',
            type=nullable_str,
            default=None,
            help='Optional regex pattern specifying valid logits processor '
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
407
408
409
410
411
412
413
414
415
416
417
418
        parser.add_argument(
            '--model-impl',
            type=str,
            default=EngineArgs.model_impl,
            choices=[f.value for f in ModelImpl],
            help='Which implementation of the model to use.\n\n'
            '* "auto" will try to use the vLLM implementation if it exists '
            'and fall back to the Transformers implementation if no vLLM '
            'implementation is available.\n'
            '* "vllm" will use the vLLM model implementation.\n'
            '* "transformers" will use the Transformers model '
            'implementation.\n')
419
        # Parallel arguments
420
421
        parser.add_argument(
            '--distributed-executor-backend',
422
            choices=['ray', 'mp', 'uni', 'external_launcher'],
423
            default=EngineArgs.distributed_executor_backend,
424
425
426
427
428
429
            help='Backend to use for distributed model '
            'workers, either "ray" or "mp" (multiprocessing). If the product '
            'of pipeline_parallel_size and tensor_parallel_size is less than '
            'or equal to the number of GPUs available, "mp" will be used to '
            'keep processing on a single host. Otherwise, this will default '
            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
430
            'only supports Ray for distributed inference.')
431

432
433
434
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
435
                            default=EngineArgs.pipeline_parallel_size,
436
                            help='Number of pipeline stages.')
437
438
439
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
440
                            default=EngineArgs.tensor_parallel_size,
441
                            help='Number of tensor parallel replicas.')
442
443
444
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
445
            default=EngineArgs.max_parallel_loading_workers,
446
            help='Load model sequentially in multiple batches, '
447
            'to avoid RAM OOM when using tensor '
448
            'parallel and large models.')
449
450
451
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
452
            help='If specified, use nsight to profile Ray workers.')
453
        # KV cache arguments
454
455
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
456
                            default=EngineArgs.block_size,
457
                            choices=[8, 16, 32, 64, 128],
458
                            help='Token block size for contiguous chunks of '
459
                            'tokens. This is ignored on neuron devices and '
460
                            'set to ``--max-model-len``. On CUDA devices, '
461
462
                            'only block sizes up to 32 are supported. '
                            'On HPU devices, block size defaults to 128.')
463

464
465
466
467
468
        parser.add_argument(
            "--enable-prefix-caching",
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.enable_prefix_caching,
            help="Enables automatic prefix caching. "
469
            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
470
        )
471
472
473
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
474
                            'capping to sliding window size.')
475
476
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
477
                            default=True,
478
479
480
481
482
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
483
484
485
486
487
488
489
490
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
491

492
493
494
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
495
                            help='Random seed for operations.')
496
        parser.add_argument('--swap-space',
497
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
498
                            default=EngineArgs.swap_space,
499
                            help='CPU swap space size (GiB) per GPU.')
500
501
502
503
504
505
506
507
508
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
509
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
510
            'which requires at least 26GB GPU memory. Note that this '
511
            'requires fast CPU-GPU interconnect, as part of the model is '
512
513
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
514
515
516
517
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
518
519
520
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
521
522
523
524
525
526
            'will use the default value of 0.9. This is a per-instance '
            'limit, and only applies to the current vLLM instance.'
            'It does not matter if you have another vLLM instance running '
            'on the same GPU. For example, if you have two vLLM instances '
            'running on the same GPU, you can set the GPU memory utilization '
            'to 0.5 for each instance.')
527
        parser.add_argument(
528
            '--num-gpu-blocks-override',
529
530
531
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
532
            ' of GPU blocks. Used for testing preemption.')
533
534
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
535
                            default=EngineArgs.max_num_batched_tokens,
536
537
                            help='Maximum number of batched tokens per '
                            'iteration.')
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
        parser.add_argument(
            "--max-num-partial-prefills",
            type=int,
            default=EngineArgs.max_num_partial_prefills,
            help="For chunked prefill, the max number of concurrent \
            partial prefills."
            "Defaults to 1",
        )
        parser.add_argument(
            "--max-long-partial-prefills",
            type=int,
            default=EngineArgs.max_long_partial_prefills,
            help="For chunked prefill, the maximum number of prompts longer "
            "than --long-prefill-token-threshold that will be prefilled "
            "concurrently. Setting this less than --max-num-partial-prefills "
            "will allow shorter prompts to jump the queue in front of longer "
            "prompts in some cases, improving latency. Defaults to 1.")
        parser.add_argument(
            "--long-prefill-token-threshold",
            type=float,
            default=EngineArgs.long_prefill_token_threshold,
            help="For chunked prefill, a request is considered long if the "
            "prompt is longer than this number of tokens. Defaults to 4%% of "
            "the model's context length.",
        )
563
564
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
565
                            default=EngineArgs.max_num_seqs,
566
                            help='Maximum number of sequences per iteration.')
567
568
569
570
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
571
572
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
573
574
        parser.add_argument('--disable-log-stats',
                            action='store_true',
575
                            help='Disable logging statistics.')
576
577
578
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
579
                            type=nullable_str,
580
                            choices=[*QUANTIZATION_METHODS, None],
581
                            default=EngineArgs.quantization,
582
583
584
585
586
587
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
588
589
590
591
592
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
593
            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
594
595
596
597
598
599
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
600
601
602
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
603
                            help='Extra arguments for the HuggingFace config. '
604
605
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
606
607
608
609
610
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
611
        parser.add_argument('--max-seq-len-to-capture',
612
613
614
615
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
616
617
618
619
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
620
621
622
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
623
                            help='See ParallelConfig.')
624
625
626
627
628
629
630
631
632
633
634
635
636
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
637
                            type=nullable_str,
638
639
640
641
642
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
643
644
645
646
647
648
649
650
651
652
653
654
655
656

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
657
658
659
660
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
661
            help=('Overrides for the multimodal input mapping/processing, '
662
                  'e.g., image processor. For example: ``{"num_crops": 4}``.'))
663
        parser.add_argument(
664
            '--disable-mm-preprocessor-cache',
665
            action='store_true',
666
667
            help='If true, then disables caching of the multi-modal '
            'preprocessor/mapper. (not recommended)')
668

669
670
671
672
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
673
674
675
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
695
            choices=['auto', 'float16', 'bfloat16'],
696
697
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
698
699
700
701
702
703
704
705
706
707
708
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
709
710
711
712
713
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
714
715
                  'Must be >= than max_loras. '
                  'Defaults to max_loras.'))
716
717
718
719
720
721
722
723
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
724
725
726
727
728
729
730
731
732
733
734
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
735
736
737
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
738
                            choices=DEVICE_OPTIONS,
739
                            help='Device type for vLLM execution.')
740
741
742
743
744
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
745

746
747
        parser.add_argument(
            '--multi-step-stream-outputs',
748
749
750
751
752
753
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
754
755
756
757
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
758
            help='Apply a delay (of delay factor multiplied by previous '
759
            'prompt latency) before scheduling next prompt.')
760
761
        parser.add_argument(
            '--enable-chunked-prefill',
762
763
764
765
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
766
            help='If set, the prefill requests can be chunked based on the '
767
            'max_num_batched_tokens.')
768
769
770

        parser.add_argument(
            '--speculative-model',
771
            type=nullable_str,
772
            default=EngineArgs.speculative_model,
773
774
            help=
            'The name of the draft model to be used in speculative decoding.')
775
776
777
778
779
780
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
781
            help='Method used to quantize the weights of speculative model. '
782
783
784
785
786
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
787
788
789
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
790
            default=EngineArgs.num_speculative_tokens,
791
            help='The number of speculative tokens to sample from '
792
            'the draft model in speculative decoding.')
793
794
795
796
797
798
        parser.add_argument(
            '--speculative-disable-mqa-scorer',
            action='store_true',
            help=
            'If set to True, the MQA scorer will be disabled in speculative '
            ' and fall back to batch expansion')
799
800
801
802
803
804
805
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
806

807
808
        parser.add_argument(
            '--speculative-max-model-len',
809
            type=int,
810
811
812
813
814
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

815
816
817
818
819
820
821
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

822
823
824
825
826
827
828
829
830
831
832
833
834
835
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

868
869
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
870
            action=StoreBoolean,
871
            default=EngineArgs.disable_logprobs_during_spec_decoding,
872
873
            nargs="?",
            const="True",
874
875
876
877
878
879
880
881
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

882
        parser.add_argument('--model-loader-extra-config',
883
                            type=nullable_str,
884
885
886
887
888
889
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
890
891
892
893
894
895
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
896
            "Default to `original/**/*` to avoid repeated loading of llama's "
897
            "checkpoints.")
898
        parser.add_argument(
899
            '--preemption-mode',
900
901
            type=str,
            default=None,
902
903
904
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
905

906
907
908
909
910
911
912
913
914
915
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
916
            "same as the ``--model`` argument. Noted that this name(s) "
917
            "will also be used in `model_name` tag content of "
918
            "prometheus metrics, if multiple names provided, metrics "
919
            "tag will take the first one.")
920
921
922
923
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
924

925
926
927
928
929
930
931
932
933
934
935
936
        parser.add_argument('--show-hidden-metrics-for-version',
                            type=str,
                            default=None,
                            help='Enable deprecated Prometheus metrics that '
                            'have been hidden since the specified version. '
                            'For example, if a previously deprecated metric '
                            'has been hidden since the v0.7.0 release, you '
                            'use --show-hidden-metrics-for-version=0.7 as a '
                            'temporary escape hatch while you migrate to new '
                            'metrics. The metric is likely to be removed '
                            'completely in an upcoming release.')

937
938
939
940
941
        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
942
943
944
945
946
947
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
948
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
949
950
951
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
952

953
954
955
956
957
958
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
959

960
961
962
963
964
965
966
967
968
969
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

970
971
972
973
974
975
976
        parser.add_argument(
            '--scheduler-cls',
            default=EngineArgs.scheduler_cls,
            help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
            'is the default scheduler. Can be a class directly or the path to '
            'a class of form "mod.custom_class".')

977
        parser.add_argument(
978
979
            '--override-neuron-config',
            type=json.loads,
980
            default=None,
981
            help="Override or set neuron device configuration. "
982
            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
983
        parser.add_argument(
984
985
            '--override-pooler-config',
            type=PoolerConfig.from_json,
986
            default=None,
987
            help="Override or set the pooling method for pooling models. "
988
            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
989

990
991
992
993
994
995
996
997
998
999
1000
1001
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
1002
1003
1004
1005
                            'use a JSON string.\n'
                            'Following the convention of traditional '
                            'compilers, using -O without space is also '
                            'supported. -O3 is equivalent to -O 3.')
1006

1007
1008
1009
1010
1011
1012
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

1013
1014
1015
1016
1017
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')
1018
1019
1020
1021
1022
        parser.add_argument(
            "--generation-config",
            type=nullable_str,
            default=None,
            help="The folder path to the generation config. "
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
            "Defaults to None, no generation config is loaded, vLLM defaults "
            "will be used. If set to 'auto', the generation config will be "
            "loaded from model path. If set to a folder path, the generation "
            "config will be loaded from the specified folder path. If "
            "`max_new_tokens` is specified in generation config, then "
            "it sets a server-wide limit on the number of output tokens "
            "for all requests.")

        parser.add_argument(
            "--override-generation-config",
            type=json.loads,
            default=None,
            help="Overrides or sets generation config in JSON format. "
            "e.g. ``{\"temperature\": 0.5}``. If used with "
            "--generation-config=auto, the override parameters will be merged "
            "with the default config from the model. If generation-config is "
            "None, only the override parameters are used.")
1040

1041
1042
1043
1044
1045
1046
        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
                            default=False,
                            help="Enable sleep mode for the engine. "
                            "(only cuda platform is supported)")

1047
1048
1049
1050
1051
1052
1053
1054
1055
        parser.add_argument(
            '--calculate-kv-scales',
            action='store_true',
            help='This enables dynamic calculation of '
            'k_scale and v_scale when kv-cache-dtype is fp8. '
            'If calculate-kv-scales is false, the scales will '
            'be loaded from the model checkpoint if available. '
            'Otherwise, the scales will default to 1.0.')

1056
1057
1058
1059
1060
1061
1062
1063
        parser.add_argument(
            "--additional-config",
            type=json.loads,
            default=None,
            help="Additional config for specified platform in JSON format. "
            "Different platforms may support different configs. Make sure the "
            "configs are valid for the platform you are using. The input format"
            " is like '{\"config_key\":\"config_value\"}'")
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082

        parser.add_argument(
            "--enable-reasoning",
            action="store_true",
            default=False,
            help="Whether to enable reasoning_content for the model. "
            "If enabled, the model will be able to generate reasoning content."
        )

        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=["deepseek_r1"],
            default=None,
            help=
            "Select the reasoning parser depending on the model that you're "
            "using. This is used to parse the reasoning content into OpenAI "
            "API format. Required for ``--enable-reasoning``.")

1083
        return parser
1084
1085

    @classmethod
1086
    def from_cli_args(cls, args: argparse.Namespace):
1087
1088
1089
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
Zhuohan Li's avatar
Zhuohan Li committed
1090
1091
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args
1092

1093
    def create_model_config(self) -> ModelConfig:
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
                and self.model in MODELS_ON_S3
                and self.load_format == LoadFormat.AUTO):  # noqa: E501
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
            self.load_format = LoadFormat.RUNAI_STREAMER

1105
        return ModelConfig(
1106
            model=self.model,
1107
            hf_config_path=self.hf_config_path,
1108
            task=self.task,
1109
1110
            # We know this is not None because we set it in __post_init__
            tokenizer=cast(str, self.tokenizer),
1111
1112
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
1113
            allowed_local_media_path=self.allowed_local_media_path,
1114
1115
1116
1117
1118
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
1119
            rope_theta=self.rope_theta,
1120
            hf_overrides=self.hf_overrides,
1121
1122
1123
1124
1125
1126
1127
1128
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
1129
            served_model_name=self.served_model_name,
1130
            limit_mm_per_prompt=self.limit_mm_per_prompt,
1131
            use_async_output_proc=not self.disable_async_output_proc,
1132
            config_format=self.config_format,
1133
            mm_processor_kwargs=self.mm_processor_kwargs,
1134
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
1135
1136
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
1137
            logits_processor_pattern=self.logits_processor_pattern,
1138
            generation_config=self.generation_config,
1139
            override_generation_config=self.override_generation_config,
1140
            enable_sleep_mode=self.enable_sleep_mode,
1141
            model_impl=self.model_impl,
1142
        )
1143

1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
    def create_load_config(self) -> LoadConfig:
        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
           self.qlora_adapter_name_or_path is not None) and \
           self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes" or
            self.qlora_adapter_name_or_path is not None) and \
            self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

1161
1162
1163
1164
1165
1166
        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
        )
1167

1168
1169
1170
1171
1172
    def create_engine_config(self,
                             usage_context: Optional[UsageContext] = None
                             ) -> VllmConfig:
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()
1173

1174
1175
        if envs.VLLM_USE_V1:
            self._override_v1_engine_args(usage_context)
1176

1177
        device_config = DeviceConfig(device=self.device)
1178
1179
        model_config = self.create_model_config()

1180
1181
1182
1183
1184
        if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
                and self.enable_prefix_caching):
            logger.warning("--enable-prefix-caching is currently not "
                           "supported for multimodal models in v0 and "
                           "has been disabled.")
1185
1186
            self.enable_prefix_caching = False

1187
        cache_config = CacheConfig(
1188
            block_size=self.block_size,
1189
1190
1191
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1192
            is_attention_free=model_config.is_attention_free,
1193
1194
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1195
1196
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
1197
            calculate_kv_scales=self.calculate_kv_scales,
1198
        )
1199
        parallel_config = ParallelConfig(
1200
1201
1202
1203
1204
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
1205
1206
1207
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
1208
            ),
1209
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1210
1211
1212
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
        )
1213

1214
1215
1216
1217
1218
1219
        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.
1220

1221
1222
1223
            # For multimodal models and models with MLA, chunked prefill is
            # disabled by default in V0, but enabled by design in V1
            if model_config.is_multimodal_model or model_config.use_mla:
1224
1225
1226
                self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)

            elif use_long_context:
1227
1228
1229
1230
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
1231
                from vllm.platforms import current_platform
1232
1233
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
1234
                        and not self.enable_prompt_adapter
1235
1236
                        and model_config.runner_type != "pooling"
                        and not current_platform.is_rocm()):
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
1253
1254
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
1255
            msg = "Chunked prefill is not supported for pooling models"
1256
            raise ValueError(msg)
1257

1258
1259
1260
1261
1262
        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
1263
1264
            speculative_model_quantization = \
                self.speculative_model_quantization,
1265
1266
            speculative_draft_tensor_parallel_size = \
                self.speculative_draft_tensor_parallel_size,
1267
            num_speculative_tokens=self.num_speculative_tokens,
1268
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
1269
1270
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
1271
1272
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
1273
            disable_log_stats=self.disable_log_stats,
1274
1275
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
1276
1277
1278
1279
1280
1281
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
1282
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
1283
1284
        )

1285
        # Reminder: Please update docs/source/features/compatibility_matrix.md
1286
        # If the feature combo become valid
1287
1288
1289
1290
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1291
1292
1293
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1294
1295
1296
1297
1298
1299
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1300
1301
1302
1303
1304
1305
1306
1307
1308

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1309
        scheduler_config = SchedulerConfig(
1310
            runner_type=model_config.runner_type,
1311
1312
1313
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1314
            num_lookahead_slots=num_lookahead_slots,
1315
1316
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1317
            is_multimodal_model=model_config.is_multimodal_model,
1318
            preemption_mode=self.preemption_mode,
1319
            num_scheduler_steps=self.num_scheduler_steps,
1320
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1321
1322
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1323
            policy=self.scheduling_policy,
1324
            scheduler_cls=self.scheduler_cls,
1325
1326
1327
1328
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
        )
1329

1330
        lora_config = LoRAConfig(
1331
            bias_enabled=self.enable_lora_bias,
1332
1333
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1334
            fully_sharded_loras=self.fully_sharded_loras,
1335
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1336
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1337
1338
1339
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
1340

1341
1342
1343
1344
1345
1346
1347
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

1348
        load_config = self.create_load_config()
1349

1350
1351
1352
1353
1354
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1355
        decoding_config = DecodingConfig(
1356
1357
1358
1359
            guided_decoding_backend=self.guided_decoding_backend,
            reasoning_backend=self.reasoning_parser
            if self.enable_reasoning else None,
        )
1360

1361
1362
1363
1364
1365
        show_hidden_metrics = False
        if self.show_hidden_metrics_for_version is not None:
            show_hidden_metrics = version._prev_minor_version_was(
                self.show_hidden_metrics_for_version)

1366
1367
1368
1369
1370
1371
1372
1373
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
1374
        observability_config = ObservabilityConfig(
1375
            show_hidden_metrics=show_hidden_metrics,
1376
1377
1378
1379
1380
1381
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )
1382

1383
        config = VllmConfig(
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1394
            prompt_adapter_config=prompt_adapter_config,
1395
            compilation_config=self.compilation_config,
1396
            kv_transfer_config=self.kv_transfer_config,
1397
            additional_config=self.additional_config,
1398
        )
1399

1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
        if envs.VLLM_USE_V1:
            self._override_v1_engine_config(config)
        return config

    def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
        """Adjust these engine args in place for the V1 engine.

        Chunked prefill is always switched on. If the user left
        ``max_num_batched_tokens`` unset, it is filled in from a
        per-usage-context default table whose values depend on the name of
        the current device.

        Args:
            usage_context: How the engine is being launched. It selects the
                default ``max_num_batched_tokens``; contexts that are not in
                the default table leave the value unchanged.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"

        # The V1 engine unconditionally runs with chunked prefill.
        self.enable_chunked_prefill = True

        # Pick hardware-dependent token budgets. They are only consulted
        # below when the user did not supply max_num_batched_tokens.
        from vllm.platforms import current_platform
        gpu_name = current_platform.get_device_name().lower()
        large_budget_gpu = any(tag in gpu_name for tag in ("h100", "h200"))
        if large_budget_gpu:
            # H100 / H200 get larger default budgets.
            llm_budget, server_budget = 16384, 8192
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            llm_budget, server_budget = 8192, 2048
        budget_by_context = {
            UsageContext.LLM_CLASS: llm_budget,
            UsageContext.OPENAI_API_SERVER: server_budget,
        }

        # An explicit user setting always wins.
        if self.max_num_batched_tokens is not None:
            return
        # Unknown usage contexts keep whatever default applies downstream.
        if usage_context not in budget_by_context:
            return
        self.max_num_batched_tokens = budget_by_context[usage_context]
        logger.warning(
            "Setting max_num_batched_tokens to %d for %s usage context.",
            self.max_num_batched_tokens, usage_context.value)

    def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
        """Hook for V1-specific adjustments to an already-built VllmConfig.

        No overrides are applied at present. The method only checks that the
        V1 engine is enabled and leaves ``engine_config`` untouched.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"
        return None

1445
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # When True, incoming requests are not logged.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine command-line arguments on ``parser``.

        Args:
            parser: The argument parser to extend.
            async_args_only: If True, only the async-specific flags are
                added. Otherwise the full ``EngineArgs`` flag set is
                registered first.

        Returns:
            The parser with all arguments registered. This is the object
            returned by ``EngineArgs.add_cli_args`` when that is called.
        """
        # Initialize plugins BEFORE any engine arguments are built. A plugin
        # may add a new kind of quantization method to the --quantization
        # argument or a new device to the --device argument. It therefore has
        # to be registered before EngineArgs.add_cli_args defines those
        # arguments; that method presumably snapshots the available choices
        # (TODO confirm). Loading afterwards risks dropping plugin-added
        # choices.
        load_general_plugins()
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        # Let the current platform adjust the fully populated parser.
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update(parser)
        return parser


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Build a fresh parser populated with every ``EngineArgs`` CLI flag."""
    fresh_parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(fresh_parser)


def _async_engine_args_parser():
    """Build a fresh parser holding only the async-engine-specific flags."""
    blank_parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(blank_parser, async_args_only=True)