arg_utils.py 81.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import argparse
4
import dataclasses
5
import json
6
import threading
7
from dataclasses import dataclass
8
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
9
                    Tuple, Type, Union, cast, get_args)
10

11
12
import torch

13
import vllm.envs as envs
14
from vllm import version
15
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
16
17
                         DecodingConfig, DeviceConfig, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
18
19
20
21
                         ModelConfig, ModelImpl, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
                         SchedulerConfig, SpeculativeConfig, TaskOption,
                         TokenizerPoolConfig, VllmConfig)
22
from vllm.executor.executor_base import ExecutorBase
23
from vllm.logger import init_logger
24
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
25
from vllm.plugins import load_general_plugins
26
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
27
from vllm.transformers_utils.utils import check_gguf_file
28
from vllm.usage.usage_lib import UsageContext
29
from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
30

31
if TYPE_CHECKING:
32
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
33

34
35
logger = init_logger(__name__)

36
37
# NOTE(review): presumably the module names accepted when requesting
# detailed traces -- the consumer is not visible in this part of the file.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# NOTE(review): presumably the device names selectable by the user -- the
# consumer is not visible in this part of the file.
DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "tpu",
    "xpu",
    "hpu",
]

48

49
50
51
52
53
54
def nullable_str(val: str):
    """Map an empty value or the literal string "None" to ``None``.

    Any other value is returned unchanged.
    """
    is_null = not val or val == "None"
    return None if is_null else val


55
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dictionary.

    Keys and values are lower-cased and stripped of surrounding whitespace
    before use; every value must be parseable as an ``int``.

    Args:
        val: String value to be parsed, e.g. ``"image=16,video=2"``.

    Returns:
        Dictionary with parsed values, or ``None`` if ``val`` is empty.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, if a value is not an integer, or if the same key
            appears with two different values.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key with the same value is tolerated; only a
        # contradictory repeat is rejected.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


90
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
91
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
92
    """Arguments for vLLM engine."""
93
    model: str = 'facebook/opt-125m'
94
    served_model_name: Optional[Union[str, List[str]]] = None
95
    tokenizer: Optional[str] = None
96
    hf_config_path: Optional[str] = None
97
    task: TaskOption = "auto"
98
    skip_tokenizer_init: bool = False
99
    tokenizer_mode: str = 'auto'
100
    trust_remote_code: bool = False
101
    allowed_local_media_path: str = ""
102
    download_dir: Optional[str] = None
103
    load_format: str = 'auto'
104
    config_format: ConfigFormat = ConfigFormat.AUTO
105
    dtype: str = 'auto'
106
    kv_cache_dtype: str = 'auto'
107
    seed: Optional[int] = None
108
    max_model_len: Optional[int] = None
109
110
111
112
113
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
114
    # number of P/D disaggregation (or other disaggregation) workers
115
116
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
117
    enable_expert_parallel: bool = False
118
    max_parallel_loading_workers: Optional[int] = None
119
    block_size: Optional[int] = None
120
    enable_prefix_caching: Optional[bool] = None
121
    disable_sliding_window: bool = False
122
    disable_cascade_attn: bool = False
123
    use_v2_block_manager: bool = True
124
125
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
126
    gpu_memory_utilization: float = 0.90
127
    max_num_batched_tokens: Optional[int] = None
128
129
130
    max_num_partial_prefills: Optional[int] = 1
    max_long_partial_prefills: Optional[int] = 1
    long_prefill_token_threshold: Optional[int] = 0
131
    max_num_seqs: Optional[int] = None
132
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
133
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
134
    revision: Optional[str] = None
135
    code_revision: Optional[str] = None
136
    rope_scaling: Optional[Dict[str, Any]] = None
137
    rope_theta: Optional[float] = None
138
    hf_overrides: Optional[HfOverrides] = None
139
    tokenizer_revision: Optional[str] = None
140
    quantization: Optional[str] = None
141
    enforce_eager: Optional[bool] = None
142
    max_seq_len_to_capture: int = 8192
143
    disable_custom_all_reduce: bool = False
144
    tokenizer_pool_size: int = 0
145
146
147
148
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
149
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
150
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
151
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
152
    disable_mm_preprocessor_cache: bool = False
153
    enable_lora: bool = False
154
    enable_lora_bias: bool = False
155
156
    max_loras: int = 1
    max_lora_rank: int = 16
157
158
159
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
160
    fully_sharded_loras: bool = False
161
    lora_extra_vocab_size: int = 256
162
    long_lora_scaling_factors: Optional[Tuple[float]] = None
163
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
164
    max_cpu_loras: Optional[int] = None
165
    device: str = 'auto'
166
    num_scheduler_steps: int = 1
167
    multi_step_stream_outputs: bool = True
168
    ray_workers_use_nsight: bool = False
169
    num_gpu_blocks_override: Optional[int] = None
170
    num_lookahead_slots: int = 0
171
    model_loader_extra_config: Optional[dict] = None
172
    ignore_patterns: Optional[Union[str, List[str]]] = None
173
    preemption_mode: Optional[str] = None
174

175
    scheduler_delay_factor: float = 0.0
176
    enable_chunked_prefill: Optional[bool] = None
177

178
    guided_decoding_backend: str = 'xgrammar'
179
    logits_processor_pattern: Optional[str] = None
180
181
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
182
    speculative_model_quantization: Optional[str] = None
183
    speculative_draft_tensor_parallel_size: Optional[int] = None
184
    num_speculative_tokens: Optional[int] = None
185
    speculative_disable_mqa_scorer: Optional[bool] = False
186
    speculative_max_model_len: Optional[int] = None
187
    speculative_disable_by_batch_size: Optional[int] = None
188
189
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
190
191
192
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
193
    qlora_adapter_name_or_path: Optional[str] = None
194
    disable_logprobs_during_spec_decoding: Optional[bool] = None
195

196
    show_hidden_metrics_for_version: Optional[str] = None
197
    otlp_traces_endpoint: Optional[str] = None
198
    collect_detailed_traces: Optional[str] = None
199
    disable_async_output_proc: bool = False
200
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
201
    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
202

203
204
    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
205
    compilation_config: Optional[CompilationConfig] = None
206
    worker_cls: str = "auto"
207
    worker_extension_cls: str = ""
208

209
210
    kv_transfer_config: Optional[KVTransferConfig] = None

211
    generation_config: Optional[str] = "auto"
212
    override_generation_config: Optional[Dict[str, Any]] = None
213
    enable_sleep_mode: bool = False
214
    model_impl: str = "auto"
215

216
217
    calculate_kv_scales: Optional[bool] = None

218
    additional_config: Optional[Dict[str, Any]] = None
219
220
    enable_reasoning: Optional[bool] = None
    reasoning_parser: Optional[str] = None
221
    use_tqdm_on_load: bool = True
222

223
    def __post_init__(self):
        """Normalize fields after dataclass initialization.

        * Falls back to ``model`` when no ``tokenizer`` is given.
        * Converts an ``int`` or ``dict`` ``compilation_config`` into a
          ``CompilationConfig`` object.
        * Loads the general plugins.
        """
        if not self.tokenizer:
            self.tokenizer = self.model

        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, (int, dict)):
            # `from_cli` takes a string, so the value is stringified first.
            self.compilation_config = CompilationConfig.from_cli(
                str(self.compilation_config))

        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
237
238

    @staticmethod
239
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
240
        """Shared CLI arguments for vLLM engine."""
241
        # Model arguments
242
243
244
        parser.add_argument(
            '--model',
            type=str,
245
            default=EngineArgs.model,
246
            help='Name or path of the huggingface model to use.')
247
248
249
250
251
252
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
253
            'multiple tasks. When the model only supports one task, ``"auto"`` '
254
255
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
256
257
        parser.add_argument(
            '--tokenizer',
258
            type=nullable_str,
259
            default=EngineArgs.tokenizer,
260
261
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
262
263
264
265
266
267
        parser.add_argument(
            "--hf-config-path",
            type=nullable_str,
            default=EngineArgs.hf_config_path,
            help='Name or path of the huggingface config to use. '
            'If unspecified, model name or path will be used.')
268
269
270
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
271
272
273
            help='Skip initialization of tokenizer and detokenizer. '
            'Expects valid prompt_token_ids and None for prompt from '
            'the input. The generated output will contain token ids.')
Jasmond L's avatar
Jasmond L committed
274
275
        parser.add_argument(
            '--revision',
276
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
277
            default=None,
278
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
279
280
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
281
282
        parser.add_argument(
            '--code-revision',
283
            type=nullable_str,
284
            default=None,
285
            help='The specific revision to use for the model code on '
286
287
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
288
289
        parser.add_argument(
            '--tokenizer-revision',
290
            type=nullable_str,
291
            default=None,
292
293
294
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
295
296
297
298
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
299
            choices=['auto', 'slow', 'mistral', 'custom'],
300
301
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
302
            'always use the slow tokenizer. \n* '
303
304
305
            '"mistral" will always use the `mistral_common` tokenizer. \n* '
            '"custom" will use --tokenizer to select the '
            'preregistered tokenizer.')
306
307
        parser.add_argument('--trust-remote-code',
                            action='store_true',
308
                            help='Trust remote code from huggingface.')
309
310
311
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
312
313
314
315
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
316
        parser.add_argument('--download-dir',
317
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
318
                            default=EngineArgs.download_dir,
319
                            help='Directory to download and load the weights, '
320
                            'default to the default cache dir of '
321
                            'huggingface.')
322
323
324
325
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
326
            choices=[f.value for f in LoadFormat],
327
328
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
329
            'and fall back to the pytorch bin format if safetensors format '
330
331
332
333
334
335
336
337
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
338
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
339
            'section for more information.\n'
340
            '* "runai_streamer" will load the Safetensors weights using Run:ai'
341
            'Model Streamer.\n'
342
            '* "bitsandbytes" will load the weights using bitsandbytes '
343
344
345
346
347
348
349
            'quantization.\n'
            '* "sharded_state" will load weights from pre-sharded checkpoint '
            'files, supporting efficient loading of tensor-parallel models\n'
            '* "gguf" will load weights from GGUF format files (details '
            'specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n'
            '* "mistral" will load weights from consolidated safetensors files '
            'used by Mistral models.\n')
350
351
352
353
354
355
356
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
357
358
359
360
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
361
362
363
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
364
365
366
367
368
369
370
371
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
372
373
374
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
375
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
376
            default=EngineArgs.kv_cache_dtype,
377
            help='Data type for kv cache storage. If "auto", will use model '
378
379
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
380
381
        parser.add_argument('--max-model-len',
                            type=int,
382
                            default=EngineArgs.max_model_len,
383
384
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
385
386
387
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
388
            default='xgrammar',
389
            help='Which engine will be used for guided decoding'
390
            ' (JSON schema / regex etc) by default. Currently support '
391
            'https://github.com/outlines-dev/outlines, '
392
            'https://github.com/mlc-ai/xgrammar, and '
393
394
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
395
            ' parameter.\n'
396
            'Backend-specific options can be supplied in a comma-separated '
397
398
            'list following a colon after the backend name. Valid backends and '
            'all available options are: [xgrammar:no-fallback, '
399
            'xgrammar:disable-any-whitespace, '
400
            'outlines:no-fallback, lm-format-enforcer:no-fallback]')
401
402
403
404
405
406
407
408
        parser.add_argument(
            '--logits-processor-pattern',
            type=nullable_str,
            default=None,
            help='Optional regex pattern specifying valid logits processor '
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
409
410
411
412
413
414
415
416
417
418
419
420
        parser.add_argument(
            '--model-impl',
            type=str,
            default=EngineArgs.model_impl,
            choices=[f.value for f in ModelImpl],
            help='Which implementation of the model to use.\n\n'
            '* "auto" will try to use the vLLM implementation if it exists '
            'and fall back to the Transformers implementation if no vLLM '
            'implementation is available.\n'
            '* "vllm" will use the vLLM model implementation.\n'
            '* "transformers" will use the Transformers model '
            'implementation.\n')
421
        # Parallel arguments
422
423
        parser.add_argument(
            '--distributed-executor-backend',
424
            choices=['ray', 'mp', 'uni', 'external_launcher'],
425
            default=EngineArgs.distributed_executor_backend,
426
427
428
429
430
431
            help='Backend to use for distributed model '
            'workers, either "ray" or "mp" (multiprocessing). If the product '
            'of pipeline_parallel_size and tensor_parallel_size is less than '
            'or equal to the number of GPUs available, "mp" will be used to '
            'keep processing on a single host. Otherwise, this will default '
            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
432
            'only supports Ray for distributed inference.')
433

434
435
436
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
437
                            default=EngineArgs.pipeline_parallel_size,
438
                            help='Number of pipeline stages.')
439
440
441
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
442
                            default=EngineArgs.tensor_parallel_size,
443
                            help='Number of tensor parallel replicas.')
444
445
446
447
448
        parser.add_argument(
            '--enable-expert-parallel',
            action='store_true',
            help='Use expert parallelism instead of tensor parallelism '
            'for MoE layers.')
449
450
451
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
452
            default=EngineArgs.max_parallel_loading_workers,
453
            help='Load model sequentially in multiple batches, '
454
            'to avoid RAM OOM when using tensor '
455
            'parallel and large models.')
456
457
458
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
459
            help='If specified, use nsight to profile Ray workers.')
460
        # KV cache arguments
461
462
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
463
                            default=EngineArgs.block_size,
464
                            choices=[8, 16, 32, 64, 128],
465
                            help='Token block size for contiguous chunks of '
466
                            'tokens. This is ignored on neuron devices and '
467
                            'set to ``--max-model-len``. On CUDA devices, '
468
469
                            'only block sizes up to 32 are supported. '
                            'On HPU devices, block size defaults to 128.')
470

471
472
473
474
475
        parser.add_argument(
            "--enable-prefix-caching",
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.enable_prefix_caching,
            help="Enables automatic prefix caching. "
476
            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
477
        )
478
479
480
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
481
                            'capping to sliding window size.')
482
483
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
484
                            default=True,
485
486
487
488
489
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
490
491
492
493
494
495
496
497
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
498

499
500
501
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
502
                            help='Random seed for operations.')
503
        parser.add_argument('--swap-space',
504
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
505
                            default=EngineArgs.swap_space,
506
                            help='CPU swap space size (GiB) per GPU.')
507
508
509
510
511
512
513
514
515
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
516
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
517
            'which requires at least 26GB GPU memory. Note that this '
518
            'requires fast CPU-GPU interconnect, as part of the model is '
519
520
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
521
522
523
524
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
525
526
527
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
528
529
530
531
532
533
            'will use the default value of 0.9. This is a per-instance '
            'limit, and only applies to the current vLLM instance.'
            'It does not matter if you have another vLLM instance running '
            'on the same GPU. For example, if you have two vLLM instances '
            'running on the same GPU, you can set the GPU memory utilization '
            'to 0.5 for each instance.')
534
        parser.add_argument(
535
            '--num-gpu-blocks-override',
536
537
538
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
539
            ' of GPU blocks. Used for testing preemption.')
540
541
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
542
                            default=EngineArgs.max_num_batched_tokens,
543
544
                            help='Maximum number of batched tokens per '
                            'iteration.')
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
        parser.add_argument(
            "--max-num-partial-prefills",
            type=int,
            default=EngineArgs.max_num_partial_prefills,
            help="For chunked prefill, the max number of concurrent \
            partial prefills."
            "Defaults to 1",
        )
        parser.add_argument(
            "--max-long-partial-prefills",
            type=int,
            default=EngineArgs.max_long_partial_prefills,
            help="For chunked prefill, the maximum number of prompts longer "
            "than --long-prefill-token-threshold that will be prefilled "
            "concurrently. Setting this less than --max-num-partial-prefills "
            "will allow shorter prompts to jump the queue in front of longer "
            "prompts in some cases, improving latency. Defaults to 1.")
        parser.add_argument(
            "--long-prefill-token-threshold",
            type=float,
            default=EngineArgs.long_prefill_token_threshold,
            help="For chunked prefill, a request is considered long if the "
            "prompt is longer than this number of tokens. Defaults to 4%% of "
            "the model's context length.",
        )
570
571
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
572
                            default=EngineArgs.max_num_seqs,
573
                            help='Maximum number of sequences per iteration.')
574
575
576
577
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
578
579
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
580
581
        parser.add_argument('--disable-log-stats',
                            action='store_true',
582
                            help='Disable logging statistics.')
583
584
585
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
586
                            type=nullable_str,
587
                            choices=[*QUANTIZATION_METHODS, None],
588
                            default=EngineArgs.quantization,
589
590
591
592
593
594
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
595
596
597
598
599
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
600
            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
601
602
603
604
605
606
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
607
608
609
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
610
                            help='Extra arguments for the HuggingFace config. '
611
612
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
613
614
615
616
617
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
618
        parser.add_argument('--max-seq-len-to-capture',
619
620
621
622
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
623
624
625
626
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
627
628
629
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
630
                            help='See ParallelConfig.')
631
632
633
634
635
636
637
638
639
640
641
642
643
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
644
                            type=nullable_str,
645
646
647
648
649
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
650
651
652
653
654
655
656
657
658
659
660
661
662
663

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
664
665
666
667
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
668
            help=('Overrides for the multimodal input mapping/processing, '
669
                  'e.g., image processor. For example: ``{"num_crops": 4}``.'))
670
        parser.add_argument(
671
            '--disable-mm-preprocessor-cache',
672
            action='store_true',
673
674
            help='If true, then disables caching of the multi-modal '
            'preprocessor/mapper. (not recommended)')
675

676
677
678
679
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
680
681
682
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
702
            choices=['auto', 'float16', 'bfloat16'],
703
704
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
705
706
707
708
709
710
711
712
713
714
715
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
716
717
718
719
720
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
721
722
                  'Must be >= than max_loras. '
                  'Defaults to max_loras.'))
723
724
725
726
727
728
729
730
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
731
732
733
734
735
736
737
738
739
740
741
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
742
743
744
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
745
                            choices=DEVICE_OPTIONS,
746
                            help='Device type for vLLM execution.')
747
748
749
750
751
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
752
753
754
755
756
757
758
759
        parser.add_argument(
            '--use-tqdm-on-load',
            dest='use_tqdm_on_load',
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.use_tqdm_on_load,
            help='Whether to enable/disable progress bar '
            'when loading model weights.',
        )
760

761
762
        parser.add_argument(
            '--multi-step-stream-outputs',
763
764
765
766
767
768
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
769
770
771
772
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
773
            help='Apply a delay (of delay factor multiplied by previous '
774
            'prompt latency) before scheduling next prompt.')
775
776
        parser.add_argument(
            '--enable-chunked-prefill',
777
778
779
780
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
781
            help='If set, the prefill requests can be chunked based on the '
782
            'max_num_batched_tokens.')
783
784
785

        parser.add_argument(
            '--speculative-model',
786
            type=nullable_str,
787
            default=EngineArgs.speculative_model,
788
789
            help=
            'The name of the draft model to be used in speculative decoding.')
790
791
792
793
794
795
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
796
            help='Method used to quantize the weights of speculative model. '
797
798
799
800
801
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
802
803
804
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
805
            default=EngineArgs.num_speculative_tokens,
806
            help='The number of speculative tokens to sample from '
807
            'the draft model in speculative decoding.')
808
809
810
811
812
813
        parser.add_argument(
            '--speculative-disable-mqa-scorer',
            action='store_true',
            help=
            'If set to True, the MQA scorer will be disabled in speculative '
            ' and fall back to batch expansion')
814
815
816
817
818
819
820
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
821

822
823
        parser.add_argument(
            '--speculative-max-model-len',
824
            type=int,
825
826
827
828
829
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

830
831
832
833
834
835
836
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

837
838
839
840
841
842
843
844
845
846
847
848
849
850
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

883
884
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
885
            action=StoreBoolean,
886
            default=EngineArgs.disable_logprobs_during_spec_decoding,
887
888
            nargs="?",
            const="True",
889
890
891
892
893
894
895
896
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

897
        parser.add_argument('--model-loader-extra-config',
898
                            type=nullable_str,
899
900
901
902
903
904
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
905
906
907
908
909
910
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
911
            "Default to `original/**/*` to avoid repeated loading of llama's "
912
            "checkpoints.")
913
        parser.add_argument(
914
            '--preemption-mode',
915
916
            type=str,
            default=None,
917
918
919
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
920

921
922
923
924
925
926
927
928
929
930
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
931
            "same as the ``--model`` argument. Noted that this name(s) "
932
            "will also be used in `model_name` tag content of "
933
            "prometheus metrics, if multiple names provided, metrics "
934
            "tag will take the first one.")
935
936
937
938
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
939

940
941
942
943
944
945
946
947
948
949
950
951
        parser.add_argument('--show-hidden-metrics-for-version',
                            type=str,
                            default=None,
                            help='Enable deprecated Prometheus metrics that '
                            'have been hidden since the specified version. '
                            'For example, if a previously deprecated metric '
                            'has been hidden since the v0.7.0 release, you '
                            'use --show-hidden-metrics-for-version=0.7 as a '
                            'temporary escape hatch while you migrate to new '
                            'metrics. The metric is likely to be removed '
                            'completely in an upcoming release.')

952
953
954
955
956
        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
957
958
959
960
961
962
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
963
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
964
965
966
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
967

968
969
970
971
972
973
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
974

975
976
977
978
979
980
981
982
983
984
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

985
986
987
988
989
990
991
        parser.add_argument(
            '--scheduler-cls',
            default=EngineArgs.scheduler_cls,
            help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
            'is the default scheduler. Can be a class directly or the path to '
            'a class of form "mod.custom_class".')

992
        parser.add_argument(
993
994
            '--override-neuron-config',
            type=json.loads,
995
            default=None,
996
            help="Override or set neuron device configuration. "
997
            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
998
        parser.add_argument(
999
1000
            '--override-pooler-config',
            type=PoolerConfig.from_json,
1001
            default=None,
1002
            help="Override or set the pooling method for pooling models. "
1003
            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
1004

1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
1017
1018
1019
1020
                            'use a JSON string.\n'
                            'Following the convention of traditional '
                            'compilers, using -O without space is also '
                            'supported. -O3 is equivalent to -O 3.')
1021

1022
1023
1024
1025
1026
1027
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

1028
1029
1030
1031
1032
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')
1033
1034
1035
1036
1037
1038
1039
        parser.add_argument(
            '--worker-extension-cls',
            type=str,
            default="",
            help='The worker extension class on top of the worker cls, '
            'it is useful if you just want to add new functions to the worker '
            'class without changing the existing functions.')
1040
1041
1042
        parser.add_argument(
            "--generation-config",
            type=nullable_str,
1043
            default="auto",
1044
            help="The folder path to the generation config. "
1045
1046
1047
1048
1049
            "Defaults to 'auto', the generation config will be loaded from "
            "model path. If set to 'vllm', no generation config is loaded, "
            "vLLM defaults will be used. If set to a folder path, the "
            "generation config will be loaded from the specified folder path. "
            "If `max_new_tokens` is specified in generation config, then "
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
            "it sets a server-wide limit on the number of output tokens "
            "for all requests.")

        parser.add_argument(
            "--override-generation-config",
            type=json.loads,
            default=None,
            help="Overrides or sets generation config in JSON format. "
            "e.g. ``{\"temperature\": 0.5}``. If used with "
            "--generation-config=auto, the override parameters will be merged "
            "with the default config from the model. If generation-config is "
            "None, only the override parameters are used.")
1062

1063
1064
1065
1066
1067
1068
        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
                            default=False,
                            help="Enable sleep mode for the engine. "
                            "(only cuda platform is supported)")

1069
1070
1071
1072
1073
1074
1075
1076
1077
        parser.add_argument(
            '--calculate-kv-scales',
            action='store_true',
            help='This enables dynamic calculation of '
            'k_scale and v_scale when kv-cache-dtype is fp8. '
            'If calculate-kv-scales is false, the scales will '
            'be loaded from the model checkpoint if available. '
            'Otherwise, the scales will default to 1.0.')

1078
1079
1080
1081
1082
1083
1084
1085
        parser.add_argument(
            "--additional-config",
            type=json.loads,
            default=None,
            help="Additional config for specified platform in JSON format. "
            "Different platforms may support different configs. Make sure the "
            "configs are valid for the platform you are using. The input format"
            " is like '{\"config_key\":\"config_value\"}'")
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104

        parser.add_argument(
            "--enable-reasoning",
            action="store_true",
            default=False,
            help="Whether to enable reasoning_content for the model. "
            "If enabled, the model will be able to generate reasoning content."
        )

        parser.add_argument(
            "--reasoning-parser",
            type=str,
            choices=["deepseek_r1"],
            default=None,
            help=
            "Select the reasoning parser depending on the model that you're "
            "using. This is used to parse the reasoning content into OpenAI "
            "API format. Required for ``--enable-reasoning``.")

1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
        parser.add_argument(
            "--disable-cascade-attn",
            action="store_true",
            default=False,
            help="Disable cascade attention for V1. While cascade attention "
            "does not change the mathematical correctness, disabling it "
            "could be useful for preventing potential numerical issues. "
            "Note that even if this is set to False, cascade attention will be "
            "only used when the heuristic tells that it's beneficial.")

1115
        return parser
1116
1117

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of ``cls`` from a parsed argparse namespace.

        Each dataclass field declared on ``cls`` is looked up by name on
        ``args`` and forwarded to the constructor as a keyword argument.
        Attributes of ``args`` that are not dataclass fields are ignored; a
        field that is absent from ``args`` raises ``AttributeError``.
        """
        # One constructor keyword per dataclass field, read from the namespace.
        init_kwargs = {}
        for spec in dataclasses.fields(cls):
            init_kwargs[spec.name] = getattr(args, spec.name)
        return cls(**init_kwargs)
1124

1125
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` from the model-related engine arguments.

        Side effects on ``self`` before the config is built:

        * If ``check_gguf_file(self.model)`` is true, ``quantization`` and
          ``load_format`` are both set to ``"gguf"``.
        * For non-async engine args with ``VLLM_CI_USE_S3`` enabled, a model
          listed in ``MODELS_ON_S3`` whose load format is still
          ``LoadFormat.AUTO`` is redirected to the S3 bucket and loaded with
          ``LoadFormat.RUNAI_STREAMER``.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = "gguf"
            self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3:
            if (self.model in MODELS_ON_S3
                    and self.load_format == LoadFormat.AUTO):
                self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
                self.load_format = LoadFormat.RUNAI_STREAMER

        # We know the tokenizer is not None because we set it in __post_init__
        tokenizer_name = cast(str, self.tokenizer)
        # The CLI exposes a "disable" switch; ModelConfig wants the positive.
        async_output_proc = not self.disable_async_output_proc

        return ModelConfig(
            # Model identity and HuggingFace configuration.
            model=self.model,
            hf_config_path=self.hf_config_path,
            task=self.task,
            revision=self.revision,
            code_revision=self.code_revision,
            trust_remote_code=self.trust_remote_code,
            config_format=self.config_format,
            hf_overrides=self.hf_overrides,
            served_model_name=self.served_model_name,
            model_impl=self.model_impl,
            # Tokenizer.
            tokenizer=tokenizer_name,
            tokenizer_mode=self.tokenizer_mode,
            tokenizer_revision=self.tokenizer_revision,
            skip_tokenizer_init=self.skip_tokenizer_init,
            # Numerics, context length and RoPE.
            dtype=self.dtype,
            seed=self.seed,
            quantization=self.quantization,
            max_model_len=self.max_model_len,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            # Execution behaviour.
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            use_async_output_proc=async_output_proc,
            enable_sleep_mode=self.enable_sleep_mode,
            # Multimodal inputs.
            allowed_local_media_path=self.allowed_local_media_path,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            # Overrides and generation settings.
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
        )
1176

1177
1178
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` describing how model weights are loaded.

        Side effect: when ``quantization`` is ``"bitsandbytes"``,
        ``self.load_format`` is overwritten with ``"bitsandbytes"`` before the
        config is built.

        Raises:
            ValueError: if a QLoRA adapter is configured while ``quantization``
                is anything other than ``"bitsandbytes"``.
        """
        uses_bitsandbytes = self.quantization == "bitsandbytes"
        has_qlora_adapter = self.qlora_adapter_name_or_path is not None

        # A QLoRA adapter is only accepted together with bitsandbytes.
        if has_qlora_adapter and not uses_bitsandbytes:
            raise ValueError(
                "QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        # bitsandbytes quantization forces the matching load format.
        if uses_bitsandbytes:
            self.load_format = "bitsandbytes"

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
        )
1194

1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1205

1206
1207
1208
1209
1210
1211
        This function sets VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1212
1213
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()
1214

1215
        device_config = DeviceConfig(device=self.device)
1216
1217
        model_config = self.create_model_config()

1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
            self._set_default_args_v1(usage_context)
        else:
            self._set_default_args_v0(model_config)
1240

1241
        cache_config = CacheConfig(
1242
            block_size=self.block_size,
1243
1244
1245
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1246
            is_attention_free=model_config.is_attention_free,
1247
1248
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1249
1250
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
1251
            calculate_kv_scales=self.calculate_kv_scales,
1252
        )
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1265
        parallel_config = ParallelConfig(
1266
1267
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1268
            enable_expert_parallel=self.enable_expert_parallel,
1269
1270
1271
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
1272
1273
1274
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
1275
            ),
1276
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1277
            placement_group=placement_group,
1278
1279
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1280
            worker_extension_cls=self.worker_extension_cls,
1281
        )
1282
1283
1284
1285
1286
1287

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
1288
1289
            speculative_model_quantization = \
                self.speculative_model_quantization,
1290
1291
            speculative_draft_tensor_parallel_size = \
                self.speculative_draft_tensor_parallel_size,
1292
            num_speculative_tokens=self.num_speculative_tokens,
1293
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
1294
1295
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
1296
1297
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
1298
            disable_log_stats=self.disable_log_stats,
1299
1300
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
1301
1302
1303
1304
1305
1306
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
1307
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
1308
1309
        )

1310
        # Reminder: Please update docs/source/features/compatibility_matrix.md
1311
        # If the feature combo become valid
1312
1313
1314
1315
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1316
1317
1318
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1319
1320
1321
1322
1323
1324
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1325
1326
1327
1328
1329
1330
1331
1332
1333

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1334
        scheduler_config = SchedulerConfig(
1335
            runner_type=model_config.runner_type,
1336
1337
1338
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1339
            num_lookahead_slots=num_lookahead_slots,
1340
1341
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1342
            is_multimodal_model=model_config.is_multimodal_model,
1343
            preemption_mode=self.preemption_mode,
1344
            num_scheduler_steps=self.num_scheduler_steps,
1345
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1346
1347
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1348
            policy=self.scheduling_policy,
1349
            scheduler_cls=self.scheduler_cls,
1350
1351
1352
1353
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
        )
1354

1355
        lora_config = LoRAConfig(
1356
            bias_enabled=self.enable_lora_bias,
1357
1358
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1359
            fully_sharded_loras=self.fully_sharded_loras,
1360
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1361
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1362
1363
1364
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
1365

1366
1367
1368
1369
1370
1371
1372
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

1373
        load_config = self.create_load_config()
1374

1375
1376
1377
1378
1379
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1380
        decoding_config = DecodingConfig(
1381
1382
1383
1384
            guided_decoding_backend=self.guided_decoding_backend,
            reasoning_backend=self.reasoning_parser
            if self.enable_reasoning else None,
        )
1385

1386
1387
1388
1389
1390
        show_hidden_metrics = False
        if self.show_hidden_metrics_for_version is not None:
            show_hidden_metrics = version._prev_minor_version_was(
                self.show_hidden_metrics_for_version)

1391
1392
1393
1394
1395
1396
1397
1398
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
1399
        observability_config = ObservabilityConfig(
1400
            show_hidden_metrics=show_hidden_metrics,
1401
1402
1403
1404
1405
1406
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )
1407

1408
        config = VllmConfig(
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1419
            prompt_adapter_config=prompt_adapter_config,
1420
            compilation_config=self.compilation_config,
1421
            kv_transfer_config=self.kv_transfer_config,
1422
            additional_config=self.additional_config,
1423
        )
1424

1425
1426
        return config

1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        The checks run top to bottom and the first one that trips decides
        the outcome (and the message the user sees), so their order is
        significant.

        * Unsupported features are reported through ``_raise_or_fallback``,
          which raises if the user explicitly enabled VLLM_USE_V1 and
          otherwise logs a warning before we return False.
        * Experimental features are reported through ``_warn_or_fallback``,
          which returns True (fall back) unless the user explicitly
          enabled VLLM_USE_V1.

        Args:
            model_config: The already-created model config. Its dtype,
                quantization, task, architectures, V1 compatibility flag
                and model_impl are inspected.

        Returns:
            True if the V1 Engine can be used, False if the caller should
            fall back to the V0 Engine.

        Raises:
            NotImplementedError: Propagated from ``_raise_or_fallback`` when
                VLLM_USE_V1 is explicitly enabled together with a feature
                that V1 does not support.
        """

        #############################################################
        # Unsupported Feature Flags on V1.

        if self.load_format in (LoadFormat.TENSORIZER.value,
                                LoadFormat.SHARDED_STATE.value):
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        # For the flags below, "set by the user" is detected by comparing
        # the instance value against the class-level default.
        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

        if self.preemption_mode != EngineArgs.preemption_mode:
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

        if self.scheduling_policy != EngineArgs.scheduling_policy:
            _raise_or_fallback(feature_name="--scheduling-policy",
                               recommend_to_remove=False)
            return False

        if self.num_scheduler_steps != EngineArgs.num_scheduler_steps:
            _raise_or_fallback(feature_name="--num-scheduler-steps",
                               recommend_to_remove=True)
            return False

        if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor:
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

        if self.additional_config != EngineArgs.additional_config:
            _raise_or_fallback(feature_name="--additional-config",
                               recommend_to_remove=False)
            return False

        # Only support Xgrammar for guided decoding so far.
        SUPPORTED_GUIDED_DECODING = [
            "xgrammar", "xgrammar:disable-any-whitespace"
        ]
        if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
            _raise_or_fallback(feature_name="--guided-decoding-backend",
                               recommend_to_remove=False)
            return False

        # Need at least Ampere for now (FA support required).
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
        from vllm.platforms import current_platform
        if current_platform.is_cuda():
            # Query the capability once; it is only asked for on CUDA.
            capability = current_platform.get_device_capability()
            if capability and capability.major < 8:
                _raise_or_fallback(feature_name="Compute Capability < 8.0",
                                   recommend_to_remove=False)
                return False

        # No Fp8 KV cache so far.
        if self.kv_cache_dtype != "auto":
            _raise_or_fallback(feature_name="--kv-cache-dtype",
                               recommend_to_remove=False)
            return False

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

        # No CPU offloading yet.
        if self.cpu_offload_gb != EngineArgs.cpu_offload_gb:
            _raise_or_fallback(feature_name="--cpu-offload-gb",
                               recommend_to_remove=False)
            return False

        # Only Fp16 and Bf16 dtypes since we only support FA.
        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
        if model_config.dtype not in V1_SUPPORTED_DTYPES:
            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
                               recommend_to_remove=False)
            return False

        # Some quantization is not compatible with torch.compile.
        V1_UNSUPPORTED_QUANT = ["bitsandbytes", "gguf"]
        if model_config.quantization in V1_UNSUPPORTED_QUANT:
            _raise_or_fallback(
                feature_name=f"--quantization {model_config.quantization}",
                recommend_to_remove=False)
            return False

        # No Embedding Models so far.
        if model_config.task != "generate":
            _raise_or_fallback(feature_name=f"--task {model_config.task}",
                               recommend_to_remove=False)
            return False

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            # feature_name is typed as str; convert explicitly. The rendered
            # message is the same as formatting the value directly.
            _raise_or_fallback(feature_name=str(model_config.architectures),
                               recommend_to_remove=False)
            return False

        # No TransformersModel support so far.
        if (model_config.model_impl == ModelImpl.TRANSFORMERS
                or model_config.model_impl == "transformers"):
            _raise_or_fallback(
                feature_name=f"model_impl={model_config.model_impl}",
                recommend_to_remove=False)
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
                != EngineArgs.max_num_partial_prefills
                or self.max_long_partial_prefills
                != EngineArgs.max_long_partial_prefills
                or self.long_prefill_token_threshold
                != EngineArgs.long_prefill_token_threshold):
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # No OTLP observability so far.
        if self.otlp_traces_endpoint or self.collect_detailed_traces:
            _raise_or_fallback(feature_name="--otlp-traces-endpoint",
                               recommend_to_remove=False)
            return False

        # Only Ngram speculative decoding so far. Ngram itself is supported
        # but experimental, so it is handled in the experimental section.
        spec_decode_requested = (self.speculative_model is not None
                                 or self.num_speculative_tokens is not None)
        if spec_decode_requested and self.speculative_model != "[ngram]":
            _raise_or_fallback(feature_name="Speculative Decoding",
                               recommend_to_remove=False)
            return False

        # No Disaggregated Prefill so far.
        if self.kv_transfer_config != EngineArgs.kv_transfer_config:
            _raise_or_fallback(feature_name="--kv-transfer-config",
                               recommend_to_remove=False)
            return False

        # No FlashInfer or XFormers so far.
        V1_BACKENDS = [
            "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        # No support for device type other than CUDA, AMD (experimental) or
        # TPU (experimental) so far.
        if not (current_platform.is_cuda_alike() or current_platform.is_tpu()):
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

        # Signal Handlers requires running in main thread.
        if (threading.current_thread() != threading.main_thread()
                and _warn_or_fallback("Engine in background thread")):
            return False

        # LoRA is supported on V1, but off by default for now.
        if self.enable_lora and _warn_or_fallback("LORA"):
            return False

        # PP is supported on V1, but off by default for now.
        if self.pipeline_parallel_size > 1 and _warn_or_fallback("PP"):
            return False

        # ngram is supported on V1, but off by default for now.
        if self.speculative_model == "[ngram]" and _warn_or_fallback("ngram"):
            return False

        # Non-CUDA is supported on V1, but off by default for now.
        not_cuda = not current_platform.is_cuda()
        if not_cuda and _warn_or_fallback(  # noqa: SIM103
                current_platform.device_type):
            return False
        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        Resolves ``enable_chunked_prefill`` when the user left it unset,
        disables prefix caching for multimodal models, and fills in the
        default ``max_num_seqs``. The instance is mutated in place.

        Args:
            model_config: The already-created model config; its
                max_model_len, multimodal/MLA flags, sliding window and
                runner type drive the decisions below.

        Raises:
            ValueError: If chunked prefill ends up enabled for a model whose
                runner type is "pooling".
        """

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                from vllm.platforms import current_platform
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # Nothing above made a decision: default to disabled.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            # Adjacent literals are concatenated verbatim, so each fragment
            # must carry its own trailing space.
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # Disable prefix caching for multimodal models for VLLM_V0.
        if (model_config.is_multimodal_model and self.enable_prefix_caching):
            logger.warning(
                "--enable-prefix-caching is not supported for multimodal "
                "models in V0 and has been disabled.")
            self.enable_prefix_caching = False

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

    def _set_default_args_v1(self, usage_context: UsageContext) -> None:
        """Apply the V1 Engine defaults to this argument set, in place.

        Chunked prefill is forced on. Prefix caching, the scheduler class,
        ``max_num_batched_tokens`` and ``max_num_seqs`` are only filled in
        when the user did not choose a value.

        Args:
            usage_context: How the engine is being used; selects the default
                ``max_num_batched_tokens``. May be falsy, in which case no
                batched-token default is applied here.
        """
        # Chunked prefill is not optional on V1.
        self.enable_chunked_prefill = True

        # Prefix caching is on by default unless explicitly configured.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = True

        # Switch to the V1 scheduler, but only when the field still holds
        # the class-level (V0) default.
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # The device name only influences the batched-token defaults below.
        # Querying it can fail when the importing host differs from the
        # host vLLM runs on (e.g. scaling with Ray on a GPU-less node), so
        # fall back to a placeholder that selects the generic defaults.
        try:
            from vllm.platforms import current_platform
            device_name = current_platform.get_device_name().lower()
        except Exception:
            device_name = "no-device"

        # H100 / H200 get larger defaults than everything else.
        uses_large_defaults = any(tag in device_name
                                  for tag in ("h100", "h200"))
        if uses_large_defaults:
            batched_token_defaults = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            batched_token_defaults = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }

        # Human-readable context label, used for logging only.
        context_label = usage_context.value if usage_context else None

        if self.max_num_batched_tokens is None:
            chosen_tokens = batched_token_defaults.get(usage_context)
            if chosen_tokens is not None:
                self.max_num_batched_tokens = chosen_tokens
                logger.debug(
                    "Setting max_num_batched_tokens to %d for %s usage context.",
                    self.max_num_batched_tokens, context_label)

        if self.max_num_seqs is None:
            self.max_num_seqs = 1024
            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, context_label)
1747

1748

1749
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments for the asynchronous vLLM engine.

    Adds the request-logging switch on top of ``EngineArgs``.
    """
    # Populated by the --disable-log-requests CLI flag (store_true).
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async engine CLI options on ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: When True, skip the ``EngineArgs`` options and
                register only the async-specific ones.

        Returns:
            The parser with the options registered.
        """
        # Plugins are loaded first because they may extend what the parser
        # accepts, e.g. a new quantization method for --quantization or a
        # new device for --device.
        load_general_plugins()

        if async_args_only:
            extended = parser
        else:
            extended = EngineArgs.add_cli_args(parser)

        extended.add_argument(
            "--disable-log-requests",
            action="store_true",
            help="Disable logging requests.",
        )

        # Imported here rather than at module level, matching the other
        # uses of current_platform in this file.
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update(extended)
        return extended
1769
1770


1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Report a feature that the V1 Engine does not support.

    Args:
        feature_name: Label of the unsupported feature, used in the message.
        recommend_to_remove: When True, the warning also advises dropping
            the feature from the config in favor of the V1 Engine.

    Raises:
        NotImplementedError: If VLLM_USE_V1 was explicitly set and enabled,
            since the request for V1 cannot be honored.
    """
    v1_explicitly_on = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_explicitly_on:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")

    # Otherwise the caller falls back to V0; just tell the user why.
    fragments = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        fragments.append(
            f"We recommend to remove {feature_name} from your config ")
        fragments.append("in favor of the V1 Engine.")
    logger.warning("".join(fragments))


def _warn_or_fallback(feature_name: str) -> bool:
    """Report an experimental V1 feature and say whether to fall back.

    Args:
        feature_name: Label of the experimental feature, used in the log.

    Returns:
        False when VLLM_USE_V1 was explicitly set and enabled (the user
        opted in, so only a warning is logged); True otherwise, meaning the
        caller should fall back to the V0 Engine.
    """
    opted_in = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if not opted_in:
        logger.info(
            "%s is experimental on VLLM_USE_V1=1. "
            "Falling back to V0 Engine.", feature_name)
        return True

    logger.warning(
        "Detected VLLM_USE_V1=1 with %s. Usage should "
        "be considered experimental. Please report any "
        "issues on Github.", feature_name)
    return False


1798
1799
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with the ``EngineArgs`` options."""
    fresh_parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(fresh_parser)
1801
1802
1803


def _async_engine_args_parser():
    """Return a fresh parser holding only the async-specific options."""
    fresh_parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(fresh_parser, async_args_only=True)