# SPDX-License-Identifier: Apache-2.0

3
import argparse
4
import dataclasses
5
import json
6
import threading
7
from dataclasses import dataclass
8
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
9
                    Tuple, Type, Union, cast, get_args)
10

11
12
import torch

13
import vllm.envs as envs
14
from vllm import version
15
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
16
17
                         DecodingConfig, DeviceConfig, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
18
19
20
21
                         ModelConfig, ModelImpl, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
                         SchedulerConfig, SpeculativeConfig, TaskOption,
                         TokenizerPoolConfig, VllmConfig)
22
from vllm.executor.executor_base import ExecutorBase
23
from vllm.logger import init_logger
24
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
25
from vllm.plugins import load_general_plugins
26
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
27
from vllm.transformers_utils.utils import check_gguf_file
28
from vllm.usage.usage_lib import UsageContext
29
from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
30

31
if TYPE_CHECKING:
32
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
33

34
35
logger = init_logger(__name__)

36
37
# Module names accepted for detailed tracing.
# NOTE(review): presumably validated against the `collect_detailed_traces`
# engine argument; the check itself is outside this view — confirm.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

38
39
40
41
42
43
44
45
# Device type names recognized by this module.
# NOTE(review): presumably used as the `choices` of the `--device` CLI flag;
# that usage is outside this view — confirm.
DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "openvino",
    "tpu",
    "xpu",
    "hpu",
]

49

50
51
52
53
54
55
def nullable_str(val: str):
    """Map an empty value or the literal text ``"None"`` to ``None``.

    Any other value is handed back unchanged, which lets a command-line
    option be explicitly "unset" by passing ``None`` as its text.
    """
    return None if (not val or val == "None") else val


56
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dictionary.

    Keys and values are lower-cased and stripped of surrounding whitespace
    before parsing. Each value must be an integer.

    Args:
        val: String value to be parsed, e.g. ``"image=16,video=2"``.

    Returns:
        Dictionary mapping each key to its integer value, or ``None`` when
        ``val`` is empty.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, if a value cannot be parsed as an integer, or if
            the same key is given more than once with different values.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key is tolerated only when the value is identical.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


91
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
92
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
93
    """Arguments for vLLM engine."""
94
    model: str = 'facebook/opt-125m'
95
    served_model_name: Optional[Union[str, List[str]]] = None
96
    tokenizer: Optional[str] = None
97
    hf_config_path: Optional[str] = None
98
    task: TaskOption = "auto"
99
    skip_tokenizer_init: bool = False
100
    tokenizer_mode: str = 'auto'
101
    trust_remote_code: bool = False
102
    allowed_local_media_path: str = ""
103
    download_dir: Optional[str] = None
104
    load_format: str = 'auto'
105
    config_format: ConfigFormat = ConfigFormat.AUTO
106
    dtype: str = 'auto'
107
    kv_cache_dtype: str = 'auto'
108
    seed: Optional[int] = None
109
    max_model_len: Optional[int] = None
110
111
112
113
114
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
115
    # number of P/D disaggregation (or other disaggregation) workers
116
117
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
118
    enable_expert_parallel: bool = False
119
    max_parallel_loading_workers: Optional[int] = None
120
    block_size: Optional[int] = None
121
    enable_prefix_caching: Optional[bool] = None
122
    disable_sliding_window: bool = False
123
    disable_cascade_attn: bool = False
124
    use_v2_block_manager: bool = True
125
126
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
127
    gpu_memory_utilization: float = 0.90
128
    max_num_batched_tokens: Optional[int] = None
129
130
131
    max_num_partial_prefills: Optional[int] = 1
    max_long_partial_prefills: Optional[int] = 1
    long_prefill_token_threshold: Optional[int] = 0
132
    max_num_seqs: Optional[int] = None
133
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
134
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
135
    revision: Optional[str] = None
136
    code_revision: Optional[str] = None
137
    rope_scaling: Optional[Dict[str, Any]] = None
138
    rope_theta: Optional[float] = None
139
    hf_overrides: Optional[HfOverrides] = None
140
    tokenizer_revision: Optional[str] = None
141
    quantization: Optional[str] = None
142
    enforce_eager: Optional[bool] = None
143
    max_seq_len_to_capture: int = 8192
144
    disable_custom_all_reduce: bool = False
145
    tokenizer_pool_size: int = 0
146
147
148
149
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
150
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
151
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
152
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
153
    disable_mm_preprocessor_cache: bool = False
154
    enable_lora: bool = False
155
    enable_lora_bias: bool = False
156
157
    max_loras: int = 1
    max_lora_rank: int = 16
158
159
160
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
161
    fully_sharded_loras: bool = False
162
    lora_extra_vocab_size: int = 256
163
    long_lora_scaling_factors: Optional[Tuple[float]] = None
164
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
165
    max_cpu_loras: Optional[int] = None
166
    device: str = 'auto'
167
    num_scheduler_steps: int = 1
168
    multi_step_stream_outputs: bool = True
169
    ray_workers_use_nsight: bool = False
170
    num_gpu_blocks_override: Optional[int] = None
171
    num_lookahead_slots: int = 0
172
    model_loader_extra_config: Optional[dict] = None
173
    ignore_patterns: Optional[Union[str, List[str]]] = None
174
    preemption_mode: Optional[str] = None
175

176
    scheduler_delay_factor: float = 0.0
177
    enable_chunked_prefill: Optional[bool] = None
178

179
    guided_decoding_backend: str = 'xgrammar'
180
    logits_processor_pattern: Optional[str] = None
181
182
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
183
    speculative_model_quantization: Optional[str] = None
184
    speculative_draft_tensor_parallel_size: Optional[int] = None
185
    num_speculative_tokens: Optional[int] = None
186
    speculative_disable_mqa_scorer: Optional[bool] = False
187
    speculative_max_model_len: Optional[int] = None
188
    speculative_disable_by_batch_size: Optional[int] = None
189
190
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
191
192
193
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
194
    qlora_adapter_name_or_path: Optional[str] = None
195
    disable_logprobs_during_spec_decoding: Optional[bool] = None
196

197
    show_hidden_metrics_for_version: Optional[str] = None
198
    otlp_traces_endpoint: Optional[str] = None
199
    collect_detailed_traces: Optional[str] = None
200
    disable_async_output_proc: bool = False
201
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
202
    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
203

204
205
    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
206
    compilation_config: Optional[CompilationConfig] = None
207
    worker_cls: str = "auto"
208
    worker_extension_cls: str = ""
209

210
211
    kv_transfer_config: Optional[KVTransferConfig] = None

212
    generation_config: Optional[str] = "auto"
213
    override_generation_config: Optional[Dict[str, Any]] = None
214
    enable_sleep_mode: bool = False
215
    model_impl: str = "auto"
216

217
218
    calculate_kv_scales: Optional[bool] = None

219
    additional_config: Optional[Dict[str, Any]] = None
220
221
    enable_reasoning: Optional[bool] = None
    reasoning_parser: Optional[str] = None
222
    use_tqdm_on_load: bool = True
223

224
    def __post_init__(self):
        """Normalize fields once the dataclass has been initialized.

        * Uses ``model`` as the tokenizer when no tokenizer was given.
        * Converts an ``int`` or ``dict`` ``compilation_config`` into a
          ``CompilationConfig`` via ``CompilationConfig.from_cli``.
        * Loads the general plugins.
        """
        # Fall back to the model name/path when the tokenizer is unset/empty.
        if not self.tokenizer:
            self.tokenizer = self.model

        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, (int, dict)):
            # `from_cli` takes the textual form, so the value is stringified.
            self.compilation_config = CompilationConfig.from_cli(
                str(self.compilation_config))

        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
238
239

    @staticmethod
240
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
241
        """Shared CLI arguments for vLLM engine."""
242
        # Model arguments
243
244
245
        parser.add_argument(
            '--model',
            type=str,
246
            default=EngineArgs.model,
247
            help='Name or path of the huggingface model to use.')
248
249
250
251
252
253
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
254
            'multiple tasks. When the model only supports one task, ``"auto"`` '
255
256
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
257
258
        parser.add_argument(
            '--tokenizer',
259
            type=nullable_str,
260
            default=EngineArgs.tokenizer,
261
262
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
263
264
265
266
267
268
        parser.add_argument(
            "--hf-config-path",
            type=nullable_str,
            default=EngineArgs.hf_config_path,
            help='Name or path of the huggingface config to use. '
            'If unspecified, model name or path will be used.')
269
270
271
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
272
273
274
            help='Skip initialization of tokenizer and detokenizer. '
            'Expects valid prompt_token_ids and None for prompt from '
            'the input. The generated output will contain token ids.')
Jasmond L's avatar
Jasmond L committed
275
276
        parser.add_argument(
            '--revision',
277
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
278
            default=None,
279
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
280
281
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
282
283
        parser.add_argument(
            '--code-revision',
284
            type=nullable_str,
285
            default=None,
286
            help='The specific revision to use for the model code on '
287
288
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
289
290
        parser.add_argument(
            '--tokenizer-revision',
291
            type=nullable_str,
292
            default=None,
293
294
295
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
296
297
298
299
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
300
            choices=['auto', 'slow', 'mistral', 'custom'],
301
302
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
303
            'always use the slow tokenizer. \n* '
304
305
306
            '"mistral" will always use the `mistral_common` tokenizer. \n* '
            '"custom" will use --tokenizer to select the '
            'preregistered tokenizer.')
307
308
        parser.add_argument('--trust-remote-code',
                            action='store_true',
309
                            help='Trust remote code from huggingface.')
310
311
312
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
313
314
315
316
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
317
        parser.add_argument('--download-dir',
318
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
319
                            default=EngineArgs.download_dir,
320
                            help='Directory to download and load the weights, '
321
                            'default to the default cache dir of '
322
                            'huggingface.')
323
324
325
326
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
327
            choices=[f.value for f in LoadFormat],
328
329
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
330
            'and fall back to the pytorch bin format if safetensors format '
331
332
333
334
335
336
337
338
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
339
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
340
            'section for more information.\n'
341
            '* "runai_streamer" will load the Safetensors weights using Run:ai'
342
            'Model Streamer.\n'
343
            '* "bitsandbytes" will load the weights using bitsandbytes '
344
345
346
347
348
349
350
            'quantization.\n'
            '* "sharded_state" will load weights from pre-sharded checkpoint '
            'files, supporting efficient loading of tensor-parallel models\n'
            '* "gguf" will load weights from GGUF format files (details '
            'specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n'
            '* "mistral" will load weights from consolidated safetensors files '
            'used by Mistral models.\n')
351
352
353
354
355
356
357
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
358
359
360
361
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
362
363
364
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
365
366
367
368
369
370
371
372
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
373
374
375
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
376
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
377
            default=EngineArgs.kv_cache_dtype,
378
            help='Data type for kv cache storage. If "auto", will use model '
379
380
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
381
382
        parser.add_argument('--max-model-len',
                            type=int,
383
                            default=EngineArgs.max_model_len,
384
385
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
386
387
388
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
389
            default='xgrammar',
390
            help='Which engine will be used for guided decoding'
391
            ' (JSON schema / regex etc) by default. Currently support '
392
            'https://github.com/outlines-dev/outlines, '
393
            'https://github.com/mlc-ai/xgrammar, and '
394
395
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
396
            ' parameter.\n'
397
            'Backend-specific options can be supplied in a comma-separated '
398
399
            'list following a colon after the backend name. Valid backends and '
            'all available options are: [xgrammar:no-fallback, '
400
            'xgrammar:disable-any-whitespace, '
401
            'outlines:no-fallback, lm-format-enforcer:no-fallback]')
402
403
404
405
406
407
408
409
        parser.add_argument(
            '--logits-processor-pattern',
            type=nullable_str,
            default=None,
            help='Optional regex pattern specifying valid logits processor '
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
410
411
412
413
414
415
416
417
418
419
420
421
        parser.add_argument(
            '--model-impl',
            type=str,
            default=EngineArgs.model_impl,
            choices=[f.value for f in ModelImpl],
            help='Which implementation of the model to use.\n\n'
            '* "auto" will try to use the vLLM implementation if it exists '
            'and fall back to the Transformers implementation if no vLLM '
            'implementation is available.\n'
            '* "vllm" will use the vLLM model implementation.\n'
            '* "transformers" will use the Transformers model '
            'implementation.\n')
422
        # Parallel arguments
423
424
        parser.add_argument(
            '--distributed-executor-backend',
425
            choices=['ray', 'mp', 'uni', 'external_launcher'],
426
            default=EngineArgs.distributed_executor_backend,
427
428
429
430
431
432
            help='Backend to use for distributed model '
            'workers, either "ray" or "mp" (multiprocessing). If the product '
            'of pipeline_parallel_size and tensor_parallel_size is less than '
            'or equal to the number of GPUs available, "mp" will be used to '
            'keep processing on a single host. Otherwise, this will default '
            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
433
            'only supports Ray for distributed inference.')
434

435
436
437
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
438
                            default=EngineArgs.pipeline_parallel_size,
439
                            help='Number of pipeline stages.')
440
441
442
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
443
                            default=EngineArgs.tensor_parallel_size,
444
                            help='Number of tensor parallel replicas.')
445
446
447
448
449
        parser.add_argument(
            '--enable-expert-parallel',
            action='store_true',
            help='Use expert parallelism instead of tensor parallelism '
            'for MoE layers.')
450
451
452
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
453
            default=EngineArgs.max_parallel_loading_workers,
454
            help='Load model sequentially in multiple batches, '
455
            'to avoid RAM OOM when using tensor '
456
            'parallel and large models.')
457
458
459
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
460
            help='If specified, use nsight to profile Ray workers.')
461
        # KV cache arguments
462
463
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
464
                            default=EngineArgs.block_size,
465
                            choices=[8, 16, 32, 64, 128],
466
                            help='Token block size for contiguous chunks of '
467
                            'tokens. This is ignored on neuron devices and '
468
                            'set to ``--max-model-len``. On CUDA devices, '
469
470
                            'only block sizes up to 32 are supported. '
                            'On HPU devices, block size defaults to 128.')
471

472
473
474
475
476
        parser.add_argument(
            "--enable-prefix-caching",
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.enable_prefix_caching,
            help="Enables automatic prefix caching. "
477
            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
478
        )
479
480
481
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
482
                            'capping to sliding window size.')
483
484
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
485
                            default=True,
486
487
488
489
490
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
491
492
493
494
495
496
497
498
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
499

500
501
502
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
503
                            help='Random seed for operations.')
504
        parser.add_argument('--swap-space',
505
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
506
                            default=EngineArgs.swap_space,
507
                            help='CPU swap space size (GiB) per GPU.')
508
509
510
511
512
513
514
515
516
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
517
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
518
            'which requires at least 26GB GPU memory. Note that this '
519
            'requires fast CPU-GPU interconnect, as part of the model is '
520
521
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
522
523
524
525
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
526
527
528
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
529
530
531
532
533
534
            'will use the default value of 0.9. This is a per-instance '
            'limit, and only applies to the current vLLM instance.'
            'It does not matter if you have another vLLM instance running '
            'on the same GPU. For example, if you have two vLLM instances '
            'running on the same GPU, you can set the GPU memory utilization '
            'to 0.5 for each instance.')
535
        parser.add_argument(
536
            '--num-gpu-blocks-override',
537
538
539
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
540
            ' of GPU blocks. Used for testing preemption.')
541
542
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
543
                            default=EngineArgs.max_num_batched_tokens,
544
545
                            help='Maximum number of batched tokens per '
                            'iteration.')
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
        parser.add_argument(
            "--max-num-partial-prefills",
            type=int,
            default=EngineArgs.max_num_partial_prefills,
            help="For chunked prefill, the max number of concurrent \
            partial prefills."
            "Defaults to 1",
        )
        parser.add_argument(
            "--max-long-partial-prefills",
            type=int,
            default=EngineArgs.max_long_partial_prefills,
            help="For chunked prefill, the maximum number of prompts longer "
            "than --long-prefill-token-threshold that will be prefilled "
            "concurrently. Setting this less than --max-num-partial-prefills "
            "will allow shorter prompts to jump the queue in front of longer "
            "prompts in some cases, improving latency. Defaults to 1.")
        parser.add_argument(
            "--long-prefill-token-threshold",
            type=float,
            default=EngineArgs.long_prefill_token_threshold,
            help="For chunked prefill, a request is considered long if the "
            "prompt is longer than this number of tokens. Defaults to 4%% of "
            "the model's context length.",
        )
571
572
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
573
                            default=EngineArgs.max_num_seqs,
574
                            help='Maximum number of sequences per iteration.')
575
576
577
578
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
579
580
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
581
582
        parser.add_argument('--disable-log-stats',
                            action='store_true',
583
                            help='Disable logging statistics.')
584
585
586
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
587
                            type=nullable_str,
588
                            choices=[*QUANTIZATION_METHODS, None],
589
                            default=EngineArgs.quantization,
590
591
592
593
594
595
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
596
597
598
599
600
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
601
            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
602
603
604
605
606
607
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
608
609
610
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
611
                            help='Extra arguments for the HuggingFace config. '
612
613
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
614
615
616
617
618
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
619
        parser.add_argument('--max-seq-len-to-capture',
620
621
622
623
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
624
625
626
627
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
628
629
630
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
631
                            help='See ParallelConfig.')
632
633
634
635
636
637
638
639
640
641
642
643
644
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
645
                            type=nullable_str,
646
647
648
649
650
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
651
652
653
654
655
656
657
658
659
660
661
662
663
664

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
665
666
667
668
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
669
            help=('Overrides for the multimodal input mapping/processing, '
670
                  'e.g., image processor. For example: ``{"num_crops": 4}``.'))
671
        parser.add_argument(
672
            '--disable-mm-preprocessor-cache',
673
            action='store_true',
674
675
            help='If true, then disables caching of the multi-modal '
            'preprocessor/mapper. (not recommended)')
676

677
678
679
680
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
681
682
683
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
703
            choices=['auto', 'float16', 'bfloat16'],
704
705
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
706
707
708
709
710
711
712
713
714
715
716
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
717
718
719
720
721
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
722
723
                  'Must be >= than max_loras. '
                  'Defaults to max_loras.'))
724
725
726
727
728
729
730
731
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
732
733
734
735
736
737
738
739
740
741
742
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
743
744
745
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
746
                            choices=DEVICE_OPTIONS,
747
                            help='Device type for vLLM execution.')
748
749
750
751
752
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
753
754
755
756
757
758
759
760
        parser.add_argument(
            '--use-tqdm-on-load',
            dest='use_tqdm_on_load',
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.use_tqdm_on_load,
            help='Whether to enable/disable progress bar '
            'when loading model weights.',
        )
761

762
763
        parser.add_argument(
            '--multi-step-stream-outputs',
764
765
766
767
768
769
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
770
771
772
773
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
774
            help='Apply a delay (of delay factor multiplied by previous '
775
            'prompt latency) before scheduling next prompt.')
776
777
        parser.add_argument(
            '--enable-chunked-prefill',
778
779
780
781
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
782
            help='If set, the prefill requests can be chunked based on the '
783
            'max_num_batched_tokens.')
784
785
786

        parser.add_argument(
            '--speculative-model',
787
            type=nullable_str,
788
            default=EngineArgs.speculative_model,
789
790
            help=
            'The name of the draft model to be used in speculative decoding.')
791
792
793
794
795
796
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
797
            help='Method used to quantize the weights of speculative model. '
798
799
800
801
802
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
803
804
805
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
806
            default=EngineArgs.num_speculative_tokens,
807
            help='The number of speculative tokens to sample from '
808
            'the draft model in speculative decoding.')
809
810
811
812
813
814
        parser.add_argument(
            '--speculative-disable-mqa-scorer',
            action='store_true',
            help=
            'If set to True, the MQA scorer will be disabled in speculative '
            ' and fall back to batch expansion')
815
816
817
818
819
820
821
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
822

823
824
        parser.add_argument(
            '--speculative-max-model-len',
825
            type=int,
826
827
828
829
830
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

831
832
833
834
835
836
837
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

838
839
840
841
842
843
844
845
846
847
848
849
850
851
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

884
885
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
886
            action=StoreBoolean,
887
            default=EngineArgs.disable_logprobs_during_spec_decoding,
888
889
            nargs="?",
            const="True",
890
891
892
893
894
895
896
897
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

898
        parser.add_argument('--model-loader-extra-config',
899
                            type=nullable_str,
900
901
902
903
904
905
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
906
907
908
909
910
911
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
912
            "Default to `original/**/*` to avoid repeated loading of llama's "
913
            "checkpoints.")
914
        parser.add_argument(
915
            '--preemption-mode',
916
917
            type=str,
            default=None,
918
919
920
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
921

922
923
924
925
926
927
928
929
930
931
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
932
            "same as the ``--model`` argument. Noted that this name(s) "
933
            "will also be used in `model_name` tag content of "
934
            "prometheus metrics, if multiple names provided, metrics "
935
            "tag will take the first one.")
936
937
938
939
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
940

941
942
943
944
945
946
947
948
949
950
951
952
        parser.add_argument('--show-hidden-metrics-for-version',
                            type=str,
                            default=None,
                            help='Enable deprecated Prometheus metrics that '
                            'have been hidden since the specified version. '
                            'For example, if a previously deprecated metric '
                            'has been hidden since the v0.7.0 release, you '
                            'use --show-hidden-metrics-for-version=0.7 as a '
                            'temporary escape hatch while you migrate to new '
                            'metrics. The metric is likely to be removed '
                            'completely in an upcoming release.')

953
954
955
956
957
        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
958
959
960
961
962
963
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
964
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
965
966
967
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
968

969
970
971
972
973
974
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
975

976
977
978
979
980
981
982
983
984
985
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

986
987
988
989
990
991
992
        parser.add_argument(
            '--scheduler-cls',
            default=EngineArgs.scheduler_cls,
            help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
            'is the default scheduler. Can be a class directly or the path to '
            'a class of form "mod.custom_class".')

993
        parser.add_argument(
994
995
            '--override-neuron-config',
            type=json.loads,
996
            default=None,
997
            help="Override or set neuron device configuration. "
998
            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
999
        parser.add_argument(
1000
1001
            '--override-pooler-config',
            type=PoolerConfig.from_json,
1002
            default=None,
1003
            help="Override or set the pooling method for pooling models. "
1004
            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
1005

1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
1018
1019
1020
1021
                            'use a JSON string.\n'
                            'Following the convention of traditional '
                            'compilers, using -O without space is also '
                            'supported. -O3 is equivalent to -O 3.')
1022

1023
1024
1025
1026
1027
1028
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

1029
1030
1031
1032
1033
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')
1034
1035
1036
1037
1038
1039
1040
        parser.add_argument(
            '--worker-extension-cls',
            type=str,
            default="",
            help='The worker extension class on top of the worker cls, '
            'it is useful if you just want to add new functions to the worker '
            'class without changing the existing functions.')
1041
1042
1043
        parser.add_argument(
            "--generation-config",
            type=nullable_str,
1044
            default="auto",
1045
            help="The folder path to the generation config. "
1046
1047
1048
1049
1050
            "Defaults to 'auto', the generation config will be loaded from "
            "model path. If set to 'vllm', no generation config is loaded, "
            "vLLM defaults will be used. If set to a folder path, the "
            "generation config will be loaded from the specified folder path. "
            "If `max_new_tokens` is specified in generation config, then "
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
            "it sets a server-wide limit on the number of output tokens "
            "for all requests.")

        parser.add_argument(
            "--override-generation-config",
            type=json.loads,
            default=None,
            help="Overrides or sets generation config in JSON format. "
            "e.g. ``{\"temperature\": 0.5}``. If used with "
            "--generation-config=auto, the override parameters will be merged "
            "with the default config from the model. If generation-config is "
            "None, only the override parameters are used.")
1063

1064
1065
1066
1067
1068
1069
        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
                            default=False,
                            help="Enable sleep mode for the engine. "
                            "(only cuda platform is supported)")

1070
1071
1072
1073
1074
1075
1076
1077
1078
        parser.add_argument(
            '--calculate-kv-scales',
            action='store_true',
            help='This enables dynamic calculation of '
            'k_scale and v_scale when kv-cache-dtype is fp8. '
            'If calculate-kv-scales is false, the scales will '
            'be loaded from the model checkpoint if available. '
            'Otherwise, the scales will default to 1.0.')

1079
1080
1081
1082
1083
1084
1085
1086
        parser.add_argument(
            "--additional-config",
            type=json.loads,
            default=None,
            help="Additional config for specified platform in JSON format. "
            "Different platforms may support different configs. Make sure the "
            "configs are valid for the platform you are using. The input format"
            " is like '{\"config_key\":\"config_value\"}'")
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105

        parser.add_argument(
            "--enable-reasoning",
            action="store_true",
            default=False,
            help="Whether to enable reasoning_content for the model. "
            "If enabled, the model will be able to generate reasoning content."
        )

        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=["deepseek_r1"],
            default=None,
            help=
            "Select the reasoning parser depending on the model that you're "
            "using. This is used to parse the reasoning content into OpenAI "
            "API format. Required for ``--enable-reasoning``.")

1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
        parser.add_argument(
            "--disable-cascade-attn",
            action="store_true",
            default=False,
            help="Disable cascade attention for V1. While cascade attention "
            "does not change the mathematical correctness, disabling it "
            "could be useful for preventing potential numerical issues. "
            "Note that even if this is set to False, cascade attention will be "
            "only used when the heuristic tells that it's beneficial.")

1116
        return parser
1117
1118

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance from a parsed argparse namespace.

        Every dataclass field of ``cls`` is read from the attribute of the
        same name on ``args``. Fields that ``args`` does not define are
        omitted so that the dataclass default applies, which lets callers
        pass namespaces produced by parsers that expose only a subset of
        the engine arguments. Attributes of ``args`` that are not fields
        of ``cls`` are ignored.

        Args:
            args: The parsed command-line namespace.

        Returns:
            A new instance of ``cls``.

        Raises:
            TypeError: If ``args`` lacks a field that has no default.
        """
        # Only forward the fields the namespace actually carries; anything
        # missing falls back to the default declared on the dataclass.
        field_values = {
            spec.name: getattr(args, spec.name)
            for spec in dataclasses.fields(cls) if hasattr(args, spec.name)
        }
        return cls(**field_values)
1125

1126
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` described by these engine arguments.

        Before constructing the config this may mutate ``self``:

        * when ``check_gguf_file`` accepts ``self.model``, both
          ``quantization`` and ``load_format`` are set to ``"gguf"``;
        * in the CI S3 case (see the NOTE below), ``model`` is prefixed
          with the S3 bucket and ``load_format`` is switched to
          ``LoadFormat.RUNAI_STREAMER``.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = "gguf"
            self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        is_sync_engine = not isinstance(self, AsyncEngineArgs)
        if is_sync_engine and envs.VLLM_CI_USE_S3:
            if (self.model in MODELS_ON_S3
                    and self.load_format == LoadFormat.AUTO):
                self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
                self.load_format = LoadFormat.RUNAI_STREAMER

        # We know this is not None because we set it in __post_init__
        tokenizer = cast(str, self.tokenizer)
        # The CLI flag is phrased as "disable", the config as "use".
        use_async_output_proc = not self.disable_async_output_proc

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            task=self.task,
            tokenizer=tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            use_async_output_proc=use_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
        )
1177

1178
1179
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` described by these engine arguments.

        When ``quantization`` is ``"bitsandbytes"``, ``self.load_format``
        is overwritten with ``"bitsandbytes"`` before the config is built.

        Raises:
            ValueError: If a QLoRA adapter is given while ``quantization``
                is anything other than ``"bitsandbytes"``.
        """
        uses_bitsandbytes = self.quantization == "bitsandbytes"

        # A QLoRA adapter is rejected unless bitsandbytes quantization is on.
        if self.qlora_adapter_name_or_path is not None \
                and not uses_bitsandbytes:
            raise ValueError(
                "QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        if uses_bitsandbytes:
            self.load_format = "bitsandbytes"

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
        )
1195

1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1206

1207
1208
1209
1210
1211
1212
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1213
1214
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()
1215

1216
        device_config = DeviceConfig(device=self.device)
1217
1218
        model_config = self.create_model_config()

1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
            self._set_default_args_v1(usage_context)
        else:
            self._set_default_args_v0(model_config)
1241

1242
        cache_config = CacheConfig(
1243
            block_size=self.block_size,
1244
1245
1246
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1247
            is_attention_free=model_config.is_attention_free,
1248
1249
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1250
1251
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
1252
            calculate_kv_scales=self.calculate_kv_scales,
1253
        )
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1266
        parallel_config = ParallelConfig(
1267
1268
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1269
            enable_expert_parallel=self.enable_expert_parallel,
1270
1271
1272
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
1273
1274
1275
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
1276
            ),
1277
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1278
            placement_group=placement_group,
1279
1280
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1281
            worker_extension_cls=self.worker_extension_cls,
1282
        )
1283
1284
1285
1286
1287
1288

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
1289
1290
            speculative_model_quantization = \
                self.speculative_model_quantization,
1291
1292
            speculative_draft_tensor_parallel_size = \
                self.speculative_draft_tensor_parallel_size,
1293
            num_speculative_tokens=self.num_speculative_tokens,
1294
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
1295
1296
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
1297
1298
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
1299
            disable_log_stats=self.disable_log_stats,
1300
1301
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
1302
1303
1304
1305
1306
1307
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
1308
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
1309
1310
        )

1311
        # Reminder: Please update docs/source/features/compatibility_matrix.md
1312
        # If the feature combo become valid
1313
1314
1315
1316
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1317
1318
1319
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1320
1321
1322
1323
1324
1325
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1326
1327
1328
1329
1330
1331
1332
1333
1334

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1335
        scheduler_config = SchedulerConfig(
1336
            runner_type=model_config.runner_type,
1337
1338
1339
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1340
            num_lookahead_slots=num_lookahead_slots,
1341
1342
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1343
            is_multimodal_model=model_config.is_multimodal_model,
1344
            preemption_mode=self.preemption_mode,
1345
            num_scheduler_steps=self.num_scheduler_steps,
1346
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1347
1348
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1349
            policy=self.scheduling_policy,
1350
            scheduler_cls=self.scheduler_cls,
1351
1352
1353
1354
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
        )
1355

1356
        lora_config = LoRAConfig(
1357
            bias_enabled=self.enable_lora_bias,
1358
1359
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1360
            fully_sharded_loras=self.fully_sharded_loras,
1361
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1362
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1363
1364
1365
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
1366

1367
1368
1369
1370
1371
1372
1373
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

1374
        load_config = self.create_load_config()
1375

1376
1377
1378
1379
1380
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1381
        decoding_config = DecodingConfig(
1382
1383
1384
1385
            guided_decoding_backend=self.guided_decoding_backend,
            reasoning_backend=self.reasoning_parser
            if self.enable_reasoning else None,
        )
1386

1387
1388
1389
1390
1391
        show_hidden_metrics = False
        if self.show_hidden_metrics_for_version is not None:
            show_hidden_metrics = version._prev_minor_version_was(
                self.show_hidden_metrics_for_version)

1392
1393
1394
1395
1396
1397
1398
1399
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
1400
        observability_config = ObservabilityConfig(
1401
            show_hidden_metrics=show_hidden_metrics,
1402
1403
1404
1405
1406
1407
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )
1408

1409
        config = VllmConfig(
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1420
            prompt_adapter_config=prompt_adapter_config,
1421
            compilation_config=self.compilation_config,
1422
            kv_transfer_config=self.kv_transfer_config,
1423
            additional_config=self.additional_config,
1424
        )
1425

1426
1427
        return config

1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        The checks run in two phases:

        * Unsupported features: each one is reported through
          ``_raise_or_fallback`` and makes this method return ``False``.
        * Experimental features: each one is reported through
          ``_warn_or_fallback``; this method returns ``False`` only when
          that helper asks for a fallback.

        Args:
            model_config: The already-created model config, whose attributes
                (dtype, quantization, task, ...) take part in the decision.

        Returns:
            ``True`` if the V1 Engine can be used with the current arguments,
            ``False`` if the caller should fall back to the V0 Engine.

        Raises:
            NotImplementedError: Raised by ``_raise_or_fallback`` when the
                user explicitly set ``VLLM_USE_V1=1`` together with a feature
                that V1 does not support.
        """

        #############################################################
        # Unsupported Feature Flags on V1.

        if (self.load_format == LoadFormat.TENSORIZER.value
                or self.load_format == LoadFormat.SHARDED_STATE.value):
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

        if self.preemption_mode != EngineArgs.preemption_mode:
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

        if self.scheduling_policy != EngineArgs.scheduling_policy:
            _raise_or_fallback(feature_name="--scheduling-policy",
                               recommend_to_remove=False)
            return False

        if self.worker_cls != EngineArgs.worker_cls:
            _raise_or_fallback(feature_name="--worker-cls",
                               recommend_to_remove=False)
            return False

        if self.worker_extension_cls != EngineArgs.worker_extension_cls:
            _raise_or_fallback(feature_name="--worker-extension-cls",
                               recommend_to_remove=False)
            return False

        if self.num_scheduler_steps != EngineArgs.num_scheduler_steps:
            _raise_or_fallback(feature_name="--num-scheduler-steps",
                               recommend_to_remove=True)
            return False

        if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor:
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

        if self.additional_config != EngineArgs.additional_config:
            _raise_or_fallback(feature_name="--additional-config",
                               recommend_to_remove=False)
            return False

        # Only support Xgrammar for guided decoding so far.
        SUPPORTED_GUIDED_DECODING = ["xgrammar", "xgrammar:nofallback"]
        if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
            _raise_or_fallback(feature_name="--guided-decoding-backend",
                               recommend_to_remove=False)
            return False

        # Need at least Ampere for now (FA support required).
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
        from vllm.platforms import current_platform
        if current_platform.is_cuda():
            # Query the capability once instead of once per sub-condition.
            device_capability = current_platform.get_device_capability()
            if device_capability and device_capability.major < 8:
                _raise_or_fallback(feature_name="Compute Capability < 8.0",
                                   recommend_to_remove=False)
                return False

        # No Fp8 KV cache so far.
        if self.kv_cache_dtype != "auto":
            _raise_or_fallback(feature_name="--kv-cache-dtype",
                               recommend_to_remove=False)
            return False

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

        # No CPU offloading yet.
        if self.cpu_offload_gb != EngineArgs.cpu_offload_gb:
            _raise_or_fallback(feature_name="--cpu-offload-gb",
                               recommend_to_remove=False)
            return False

        # Only Fp16 and Bf16 dtypes since we only support FA.
        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
        if model_config.dtype not in V1_SUPPORTED_DTYPES:
            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
                               recommend_to_remove=False)
            return False

        # Some quantization is not compatible with torch.compile.
        V1_UNSUPPORTED_QUANT = ["bitsandbytes", "gguf"]
        if model_config.quantization in V1_UNSUPPORTED_QUANT:
            _raise_or_fallback(
                feature_name=f"--quantization {model_config.quantization}",
                recommend_to_remove=False)
            return False

        # No Embedding Models so far.
        if model_config.task not in ["generate"]:
            _raise_or_fallback(feature_name=f"--task {model_config.task}",
                               recommend_to_remove=False)
            return False

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            # feature_name is typed as str; str() renders exactly the text
            # that f-string interpolation of the raw value produced before.
            _raise_or_fallback(
                feature_name=str(model_config.architectures),
                recommend_to_remove=False)
            return False

        # No TransformersModel support so far.
        if (model_config.model_impl == ModelImpl.TRANSFORMERS
                or model_config.model_impl == "transformers"):
            _raise_or_fallback(
                feature_name=f"model_impl={model_config.model_impl}",
                recommend_to_remove=False)
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
                != EngineArgs.max_num_partial_prefills
                or self.max_long_partial_prefills
                != EngineArgs.max_long_partial_prefills
                or self.long_prefill_token_threshold
                != EngineArgs.long_prefill_token_threshold):
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # No OTLP observability so far.
        if (self.otlp_traces_endpoint or self.collect_detailed_traces):
            _raise_or_fallback(feature_name="--otlp-traces-endpoint",
                               recommend_to_remove=False)
            return False

        # Only Ngram speculative decoding so far.
        if (self.speculative_model is not None
                or self.num_speculative_tokens is not None):
            # This is supported but experimental (handled below).
            if self.speculative_model == "[ngram]":
                pass
            else:
                _raise_or_fallback(feature_name="Speculative Decoding",
                                   recommend_to_remove=False)
                return False

        # No Disaggregated Prefill so far.
        if self.kv_transfer_config != EngineArgs.kv_transfer_config:
            _raise_or_fallback(feature_name="--kv-transfer-config",
                               recommend_to_remove=False)
            return False

        # No FlashInfer or XFormers so far.
        V1_BACKENDS = [
            "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        # No support for device type other than CUDA, AMD (experimental) or
        # TPU (experimental) so far.
        if not (current_platform.is_cuda_alike() or current_platform.is_tpu()):
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

        # Signal Handlers requires running in main thread.
        if (threading.current_thread() != threading.main_thread()
                and _warn_or_fallback("Engine in background thread")):
            return False

        # LoRA is supported on V1, but off by default for now.
        if self.enable_lora and _warn_or_fallback("LORA"):
            return False

        # PP is supported on V1, but off by default for now.
        if self.pipeline_parallel_size > 1 and _warn_or_fallback("PP"):
            return False

        # ngram is supported on V1, but off by default for now.
        if self.speculative_model == "[ngram]" and _warn_or_fallback("ngram"):
            return False

        # Non-CUDA is supported on V1, but off by default for now.
        not_cuda = not current_platform.is_cuda()
        if not_cuda and _warn_or_fallback(  # noqa: SIM103
                current_platform.device_type):
            return False
        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        Fills in, in place on ``self``, the arguments the user left
        unspecified: ``enable_chunked_prefill`` and ``max_num_seqs``. It also
        turns ``enable_prefix_caching`` off for multimodal models.

        Args:
            model_config: The model config; its ``max_model_len``,
                ``is_multimodal_model``, ``use_mla``, ``runner_type`` and
                sliding-window setting drive the defaults.

        Raises:
            ValueError: If chunked prefill ends up enabled for a model whose
                ``runner_type`` is ``"pooling"``.
        """

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                from vllm.platforms import current_platform
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # Nothing above made a decision: default to disabled.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            # The trailing space after "cause" matters: adjacent literals are
            # concatenated verbatim, and without it the log read "causeOOM".
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # Disable prefix caching for multimodal models for VLLM_V0.
        if (model_config.is_multimodal_model and self.enable_prefix_caching):
            logger.warning(
                "--enable-prefix-caching is not supported for multimodal "
                "models in V0 and has been disabled.")
            self.enable_prefix_caching = False

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

    def _set_default_args_v1(self, usage_context: UsageContext) -> None:
        """Set Default Arguments for V1 Engine.

        Mutates ``self`` in place: forces ``enable_chunked_prefill`` on,
        and fills in ``enable_prefix_caching``, ``scheduler_cls``,
        ``max_num_batched_tokens`` and ``max_num_seqs`` when the user did not
        override them.

        Args:
            usage_context: Used to pick the default
                ``max_num_batched_tokens``; a falsy value or a context with no
                table entry leaves that argument untouched.
        """

        # V1 always uses chunked prefills.
        self.enable_chunked_prefill = True

        # V1 enables prefix caching by default.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = True

        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # When no user override, set the default values based on the usage
        # context.
        # Use different default values for different hardware.

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
            from vllm.platforms import current_platform
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # Deliberately best-effort: the name is only used to set
            # default_max_num_batched_tokens, so any failure is tolerated.
            device_name = "no-device"

        if "h100" in device_name or "h200" in device_name:
            # For H100 and H200, we use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }

        # Only used for the debug log lines below.
        use_context_value = usage_context.value if usage_context else None
        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
            self.max_num_batched_tokens = default_max_num_batched_tokens[
                usage_context]
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, use_context_value)

        default_max_num_seqs = 1024
        if self.max_num_seqs is None:
            self.max_num_seqs = default_max_num_seqs

            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, use_context_value)
1756

1757

1758
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # Set by the --disable-log-requests CLI flag registered below.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI arguments on ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: When True, skip the shared ``EngineArgs``
                arguments and add only the async-specific ones.

        Returns:
            The parser with the arguments added.
        """
        # Initialize plugins so they can update the parser. For example, a
        # plugin may add a new kind of quantization method to the
        # --quantization argument or a new device to the --device argument.
        load_general_plugins()
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update(parser)
        return parser
1778
1779


1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Report a feature that the V1 Engine does not support.

    Args:
        feature_name: Name of the unsupported feature, interpolated into
            the error or warning text.
        recommend_to_remove: When True, the warning also advises removing
            the feature from the config.

    Raises:
        NotImplementedError: If VLLM_USE_V1 is explicitly set and truthy.
    """
    v1_explicitly_on = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_explicitly_on:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")

    # No explicit opt-in to V1: only warn about the fallback.
    fragments = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        fragments.append(
            f"We recommend to remove {feature_name} from your config ")
        fragments.append("in favor of the V1 Engine.")
    logger.warning("".join(fragments))


def _warn_or_fallback(feature_name: str) -> bool:
    """Report an experimental V1 feature and say whether to fall back.

    Args:
        feature_name: Name of the experimental feature, used in the log text.

    Returns:
        False when VLLM_USE_V1 is explicitly set and truthy (a warning is
        logged); True otherwise (an info message about the fallback to V0 is
        logged).
    """
    opted_in = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if not opted_in:
        logger.info(
            "%s is experimental on VLLM_USE_V1=1. "
            "Falling back to V0 Engine.", feature_name)
        return True

    logger.warning(
        "Detected VLLM_USE_V1=1 with %s. Usage should "
        "be considered experimental. Please report any "
        "issues on Github.", feature_name)
    return False


1807
1808
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with the ``EngineArgs`` CLI options."""
    return EngineArgs.add_cli_args(FlexibleArgumentParser())
1810
1811
1812


def _async_engine_args_parser():
    """Return a fresh parser holding only the async-specific CLI options."""
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                        async_args_only=True)