arg_utils.py 78.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import argparse
4
import dataclasses
5
import json
6
import threading
7
from dataclasses import dataclass
8
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
9
                    Tuple, Type, Union, cast, get_args)
10

11
12
import torch

13
import vllm.envs as envs
14
from vllm import version
15
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
16
17
                         DecodingConfig, DeviceConfig, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
18
19
20
21
                         ModelConfig, ModelImpl, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
                         SchedulerConfig, SpeculativeConfig, TaskOption,
                         TokenizerPoolConfig, VllmConfig)
22
from vllm.executor.executor_base import ExecutorBase
23
from vllm.logger import init_logger
24
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
25
from vllm.plugins import load_general_plugins
26
from vllm.reasoning import ReasoningParserManager
27
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
28
from vllm.transformers_utils.utils import check_gguf_file
29
from vllm.usage.usage_lib import UsageContext
30
from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
31

32
if TYPE_CHECKING:
33
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
34

35
36
# Module-level logger for this file, created through vLLM's init_logger helper.
logger = init_logger(__name__)

37
38
# Module names that may be selected for detailed tracing. Presumably these are
# the accepted values of the ``collect_detailed_traces`` option — confirm
# against the argument parser that consumes this list.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# Device identifiers accepted for device selection; "auto" leaves the choice
# to vLLM. Presumably used as the choices of a ``--device`` CLI flag (the
# ``EngineArgs.device`` field defaults to "auto") — confirm at the use site.
DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "tpu",
    "xpu",
    "hpu",
]

49

50
51
52
53
54
55
def nullable_str(val: str):
    """Map "no value" CLI strings to ``None``.

    An empty/falsy value or the literal text ``"None"`` becomes ``None``;
    anything else is passed through untouched. Intended for use as an
    argparse ``type=`` converter.
    """
    is_missing = (not val) or val == "None"
    return None if is_missing else val


56
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse comma-separated ``KEY=VALUE`` pairs into a ``{str: int}`` dict.

    Keys are lower-cased, and surrounding whitespace is stripped from both
    keys and values, e.g. ``"image=2, Video=1"`` -> ``{"image": 2,
    "video": 1}``. A key may be repeated only if every occurrence carries
    the same value.

    Args:
        val: String value to be parsed.

    Returns:
        Dictionary with parsed values, or ``None`` if ``val`` is empty.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, has an empty key, has a value that is not an
            integer, or repeats a key with a different value.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts
        # An empty key (e.g. "=2") would otherwise silently create an
        # unusable "" entry in the result.
        if not key:
            raise argparse.ArgumentTypeError(
                f"Missing key in item {item.strip()!r}; "
                "each item should be in the form KEY=VALUE")

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key is tolerated only when the values agree.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


91
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
92
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
93
    """Arguments for vLLM engine."""
94
    model: str = 'facebook/opt-125m'
95
    served_model_name: Optional[Union[str, List[str]]] = None
96
    tokenizer: Optional[str] = None
97
    hf_config_path: Optional[str] = None
98
    task: TaskOption = "auto"
99
    skip_tokenizer_init: bool = False
100
    tokenizer_mode: str = 'auto'
101
    trust_remote_code: bool = False
102
    allowed_local_media_path: str = ""
103
    download_dir: Optional[str] = None
104
    load_format: str = 'auto'
105
    config_format: ConfigFormat = ConfigFormat.AUTO
106
    dtype: str = 'auto'
107
    kv_cache_dtype: str = 'auto'
108
    seed: Optional[int] = None
109
    max_model_len: Optional[int] = None
110
111
112
113
114
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
115
    # number of P/D disaggregation (or other disaggregation) workers
116
117
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
118
    data_parallel_size: int = 1
119
    enable_expert_parallel: bool = False
120
    max_parallel_loading_workers: Optional[int] = None
121
    block_size: Optional[int] = None
122
    enable_prefix_caching: Optional[bool] = None
123
    prefix_caching_hash_algo: str = "builtin"
124
    disable_sliding_window: bool = False
125
    disable_cascade_attn: bool = False
126
    use_v2_block_manager: bool = True
127
128
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
129
    gpu_memory_utilization: float = 0.90
130
    max_num_batched_tokens: Optional[int] = None
131
132
133
    max_num_partial_prefills: Optional[int] = 1
    max_long_partial_prefills: Optional[int] = 1
    long_prefill_token_threshold: Optional[int] = 0
134
    max_num_seqs: Optional[int] = None
135
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
136
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
137
    revision: Optional[str] = None
138
    code_revision: Optional[str] = None
139
    rope_scaling: Optional[Dict[str, Any]] = None
140
    rope_theta: Optional[float] = None
141
    hf_overrides: Optional[HfOverrides] = None
142
    tokenizer_revision: Optional[str] = None
143
    quantization: Optional[str] = None
144
    enforce_eager: Optional[bool] = None
145
    max_seq_len_to_capture: int = 8192
146
    disable_custom_all_reduce: bool = False
147
    tokenizer_pool_size: int = 0
148
149
150
151
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
152
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
153
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
154
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
155
    disable_mm_preprocessor_cache: bool = False
156
    enable_lora: bool = False
157
    enable_lora_bias: bool = False
158
159
    max_loras: int = 1
    max_lora_rank: int = 16
160
161
162
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
163
    fully_sharded_loras: bool = False
164
    lora_extra_vocab_size: int = 256
165
    long_lora_scaling_factors: Optional[Tuple[float]] = None
166
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
167
    max_cpu_loras: Optional[int] = None
168
    device: str = 'auto'
169
    num_scheduler_steps: int = 1
170
    multi_step_stream_outputs: bool = True
171
    ray_workers_use_nsight: bool = False
172
    num_gpu_blocks_override: Optional[int] = None
173
    num_lookahead_slots: int = 0
174
    model_loader_extra_config: Optional[dict] = None
175
    ignore_patterns: Optional[Union[str, List[str]]] = None
176
    preemption_mode: Optional[str] = None
177

178
    scheduler_delay_factor: float = 0.0
179
    enable_chunked_prefill: Optional[bool] = None
180

181
    guided_decoding_backend: str = 'xgrammar'
182
    logits_processor_pattern: Optional[str] = None
183

184
    speculative_config: Optional[Dict[str, Any]] = None
185

186
    qlora_adapter_name_or_path: Optional[str] = None
187
    show_hidden_metrics_for_version: Optional[str] = None
188
    otlp_traces_endpoint: Optional[str] = None
189
    collect_detailed_traces: Optional[str] = None
190
    disable_async_output_proc: bool = False
191
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
192
    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
193

194
195
    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
196
    compilation_config: Optional[CompilationConfig] = None
197
    worker_cls: str = "auto"
198
    worker_extension_cls: str = ""
199

200
201
    kv_transfer_config: Optional[KVTransferConfig] = None

202
    generation_config: Optional[str] = "auto"
203
    override_generation_config: Optional[Dict[str, Any]] = None
204
    enable_sleep_mode: bool = False
205
    model_impl: str = "auto"
206

207
208
    calculate_kv_scales: Optional[bool] = None

209
    additional_config: Optional[Dict[str, Any]] = None
210
211
    enable_reasoning: Optional[bool] = None
    reasoning_parser: Optional[str] = None
212
    use_tqdm_on_load: bool = True
213

214
    def __post_init__(self):
        """Normalize fields after dataclass construction.

        * When no tokenizer is given, reuse the model name/path.
        * Accept ``compilation_config`` given as an ``int`` or ``dict``
          (e.g. ``EngineArgs(compilation_config={...})``) and convert it
          into a ``CompilationConfig`` via ``CompilationConfig.from_cli``
          applied to its ``str()`` form.
        * Load the general vLLM plugins.
        """
        if not self.tokenizer:
            self.tokenizer = self.model

        raw_compilation = self.compilation_config
        # Shorthand forms spare callers from building a CompilationConfig
        # object by hand; from_cli is fed the string representation.
        if isinstance(raw_compilation, (int, dict)):
            self.compilation_config = CompilationConfig.from_cli(
                str(raw_compilation))

        # Setup plugins
        from vllm.plugins import load_general_plugins as _load_plugins
        _load_plugins()
228
229

    @staticmethod
230
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
231
        """Shared CLI arguments for vLLM engine."""
232
        # Model arguments
233
234
235
        parser.add_argument(
            '--model',
            type=str,
236
            default=EngineArgs.model,
237
            help='Name or path of the huggingface model to use.')
238
239
240
241
242
243
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
244
            'multiple tasks. When the model only supports one task, ``"auto"`` '
245
246
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
247
248
        parser.add_argument(
            '--tokenizer',
249
            type=nullable_str,
250
            default=EngineArgs.tokenizer,
251
252
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
253
254
255
256
257
258
        parser.add_argument(
            "--hf-config-path",
            type=nullable_str,
            default=EngineArgs.hf_config_path,
            help='Name or path of the huggingface config to use. '
            'If unspecified, model name or path will be used.')
259
260
261
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
262
263
264
            help='Skip initialization of tokenizer and detokenizer. '
            'Expects valid prompt_token_ids and None for prompt from '
            'the input. The generated output will contain token ids.')
Jasmond L's avatar
Jasmond L committed
265
266
        parser.add_argument(
            '--revision',
267
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
268
            default=None,
269
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
270
271
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
272
273
        parser.add_argument(
            '--code-revision',
274
            type=nullable_str,
275
            default=None,
276
            help='The specific revision to use for the model code on '
277
278
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
279
280
        parser.add_argument(
            '--tokenizer-revision',
281
            type=nullable_str,
282
            default=None,
283
284
285
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
286
287
288
289
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
290
            choices=['auto', 'slow', 'mistral', 'custom'],
291
292
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
293
            'always use the slow tokenizer. \n* '
294
295
296
            '"mistral" will always use the `mistral_common` tokenizer. \n* '
            '"custom" will use --tokenizer to select the '
            'preregistered tokenizer.')
297
298
        parser.add_argument('--trust-remote-code',
                            action='store_true',
299
                            help='Trust remote code from huggingface.')
300
301
302
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
303
304
305
306
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
307
        parser.add_argument('--download-dir',
308
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
309
                            default=EngineArgs.download_dir,
310
                            help='Directory to download and load the weights.')
311
312
313
314
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
315
            choices=[f.value for f in LoadFormat],
316
317
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
318
            'and fall back to the pytorch bin format if safetensors format '
319
320
321
322
323
324
325
326
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
327
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
328
            'section for more information.\n'
329
            '* "runai_streamer" will load the Safetensors weights using Run:ai'
330
            'Model Streamer.\n'
331
            '* "bitsandbytes" will load the weights using bitsandbytes '
332
333
334
335
336
337
338
            'quantization.\n'
            '* "sharded_state" will load weights from pre-sharded checkpoint '
            'files, supporting efficient loading of tensor-parallel models\n'
            '* "gguf" will load weights from GGUF format files (details '
            'specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n'
            '* "mistral" will load weights from consolidated safetensors files '
            'used by Mistral models.\n')
339
340
341
342
343
344
345
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
346
347
348
349
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
350
351
352
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
353
354
355
356
357
358
359
360
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
361
362
363
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
364
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
365
            default=EngineArgs.kv_cache_dtype,
366
            help='Data type for kv cache storage. If "auto", will use model '
367
368
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
369
370
        parser.add_argument('--max-model-len',
                            type=int,
371
                            default=EngineArgs.max_model_len,
372
373
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
374
375
376
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
377
            default='xgrammar',
378
            help='Which engine will be used for guided decoding'
379
            ' (JSON schema / regex etc) by default. Currently support '
380
381
382
383
384
            'https://github.com/mlc-ai/xgrammar and '
            'https://github.com/guidance-ai/llguidance.'
            'Valid backend values are "xgrammar", "guidance", and "auto". '
            'With "auto", we will make opinionated choices based on request'
            'contents and what the backend libraries currently support, so '
385
            'the behavior is subject to change in each release.')
386
387
388
389
390
391
392
393
        parser.add_argument(
            '--logits-processor-pattern',
            type=nullable_str,
            default=None,
            help='Optional regex pattern specifying valid logits processor '
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
394
395
396
397
398
399
400
401
402
403
404
405
        parser.add_argument(
            '--model-impl',
            type=str,
            default=EngineArgs.model_impl,
            choices=[f.value for f in ModelImpl],
            help='Which implementation of the model to use.\n\n'
            '* "auto" will try to use the vLLM implementation if it exists '
            'and fall back to the Transformers implementation if no vLLM '
            'implementation is available.\n'
            '* "vllm" will use the vLLM model implementation.\n'
            '* "transformers" will use the Transformers model '
            'implementation.\n')
406
        # Parallel arguments
407
408
        parser.add_argument(
            '--distributed-executor-backend',
409
            choices=['ray', 'mp', 'uni', 'external_launcher'],
410
            default=EngineArgs.distributed_executor_backend,
411
412
413
414
415
416
            help='Backend to use for distributed model '
            'workers, either "ray" or "mp" (multiprocessing). If the product '
            'of pipeline_parallel_size and tensor_parallel_size is less than '
            'or equal to the number of GPUs available, "mp" will be used to '
            'keep processing on a single host. Otherwise, this will default '
            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
417
            'only supports Ray for distributed inference.')
418

419
420
421
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
422
                            default=EngineArgs.pipeline_parallel_size,
423
                            help='Number of pipeline stages.')
424
425
426
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
427
                            default=EngineArgs.tensor_parallel_size,
428
                            help='Number of tensor parallel replicas.')
429
430
431
432
433
434
435
436
        parser.add_argument('--data-parallel-size',
                            '-dp',
                            type=int,
                            default=EngineArgs.data_parallel_size,
                            help='Number of data parallel replicas. '
                            'MoE layers will be sharded according to the '
                            'product of the tensor-parallel-size and '
                            'data-parallel-size.')
437
438
439
440
441
        parser.add_argument(
            '--enable-expert-parallel',
            action='store_true',
            help='Use expert parallelism instead of tensor parallelism '
            'for MoE layers.')
442
443
444
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
445
            default=EngineArgs.max_parallel_loading_workers,
446
            help='Load model sequentially in multiple batches, '
447
            'to avoid RAM OOM when using tensor '
448
            'parallel and large models.')
449
450
451
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
452
            help='If specified, use nsight to profile Ray workers.')
453
        # KV cache arguments
454
455
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
456
                            default=EngineArgs.block_size,
457
                            choices=[8, 16, 32, 64, 128],
458
                            help='Token block size for contiguous chunks of '
459
                            'tokens. This is ignored on neuron devices and '
460
                            'set to ``--max-model-len``. On CUDA devices, '
461
462
                            'only block sizes up to 32 are supported. '
                            'On HPU devices, block size defaults to 128.')
463

464
465
466
467
468
        parser.add_argument(
            "--enable-prefix-caching",
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.enable_prefix_caching,
            help="Enables automatic prefix caching. "
469
            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
470
        )
471
472
473
474
475
476
477
        parser.add_argument(
            "--prefix-caching-hash-algo",
            type=str,
            choices=["builtin", "sha256"],
            default=EngineArgs.prefix_caching_hash_algo,
            help="Set the hash algorithm for prefix caching. "
            "Options are 'builtin' (Python's built-in hash) or 'sha256' "
478
            "(collision resistant but with certain overheads).",
479
        )
480
481
482
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
483
                            'capping to sliding window size.')
484
485
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
486
                            default=True,
487
488
489
490
491
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
492
493
494
495
496
497
498
499
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
500

501
502
503
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
504
                            help='Random seed for operations.')
505
        parser.add_argument('--swap-space',
506
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
507
                            default=EngineArgs.swap_space,
508
                            help='CPU swap space size (GiB) per GPU.')
509
510
511
512
513
514
515
516
517
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
518
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
519
            'which requires at least 26GB GPU memory. Note that this '
520
            'requires fast CPU-GPU interconnect, as part of the model is '
521
522
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
523
524
525
526
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
527
528
529
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
530
531
532
533
534
535
            'will use the default value of 0.9. This is a per-instance '
            'limit, and only applies to the current vLLM instance.'
            'It does not matter if you have another vLLM instance running '
            'on the same GPU. For example, if you have two vLLM instances '
            'running on the same GPU, you can set the GPU memory utilization '
            'to 0.5 for each instance.')
536
        parser.add_argument(
537
            '--num-gpu-blocks-override',
538
539
540
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
541
            ' of GPU blocks. Used for testing preemption.')
542
543
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
544
                            default=EngineArgs.max_num_batched_tokens,
545
546
                            help='Maximum number of batched tokens per '
                            'iteration.')
547
548
549
550
551
        parser.add_argument(
            "--max-num-partial-prefills",
            type=int,
            default=EngineArgs.max_num_partial_prefills,
            help="For chunked prefill, the max number of concurrent \
552
            partial prefills.")
553
554
555
556
557
558
559
560
        parser.add_argument(
            "--max-long-partial-prefills",
            type=int,
            default=EngineArgs.max_long_partial_prefills,
            help="For chunked prefill, the maximum number of prompts longer "
            "than --long-prefill-token-threshold that will be prefilled "
            "concurrently. Setting this less than --max-num-partial-prefills "
            "will allow shorter prompts to jump the queue in front of longer "
561
            "prompts in some cases, improving latency.")
562
563
564
565
566
        parser.add_argument(
            "--long-prefill-token-threshold",
            type=float,
            default=EngineArgs.long_prefill_token_threshold,
            help="For chunked prefill, a request is considered long if the "
567
            "prompt is longer than this number of tokens.")
568
569
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
570
                            default=EngineArgs.max_num_seqs,
571
                            help='Maximum number of sequences per iteration.')
572
573
574
575
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
576
577
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
578
579
        parser.add_argument('--disable-log-stats',
                            action='store_true',
580
                            help='Disable logging statistics.')
581
582
583
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
584
                            type=nullable_str,
585
                            choices=[*QUANTIZATION_METHODS, None],
586
                            default=EngineArgs.quantization,
587
588
589
590
591
592
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
593
594
595
596
597
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
598
            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
599
600
601
602
603
604
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
605
606
607
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
608
                            help='Extra arguments for the HuggingFace config. '
609
610
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
611
612
613
614
615
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
616
        parser.add_argument('--max-seq-len-to-capture',
617
618
619
620
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
621
622
623
624
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
625
626
627
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
628
                            help='See ParallelConfig.')
629
630
631
632
633
634
635
636
637
638
639
640
641
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
642
                            type=nullable_str,
643
644
645
646
647
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
648
649
650
651
652
653
654

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
655
            # MultiModalConfig.get_limit_per_prompt
656
657
658
659
660
661
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
662
663
664
665
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
666
            help=('Overrides for the multimodal input mapping/processing, '
667
                  'e.g., image processor. For example: ``{"num_crops": 4}``.'))
668
        parser.add_argument(
669
            '--disable-mm-preprocessor-cache',
670
            action='store_true',
671
672
            help='If true, then disables caching of the multi-modal '
            'preprocessor/mapper. (not recommended)')
673

674
675
676
677
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
678
679
680
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
700
            choices=['auto', 'float16', 'bfloat16'],
701
702
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
703
704
705
706
707
708
709
710
711
712
713
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
714
715
716
717
718
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
719
                  'Must be >= than max_loras.'))
720
721
722
723
724
725
726
727
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
728
729
730
731
732
733
734
735
736
737
738
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
739
740
741
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
742
                            choices=DEVICE_OPTIONS,
743
                            help='Device type for vLLM execution.')
744
745
746
747
748
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
749
750
751
752
753
754
755
756
        parser.add_argument(
            '--use-tqdm-on-load',
            dest='use_tqdm_on_load',
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.use_tqdm_on_load,
            help='Whether to enable/disable progress bar '
            'when loading model weights.',
        )
757

758
759
        parser.add_argument(
            '--multi-step-stream-outputs',
760
761
762
763
764
765
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
766
767
768
769
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
770
            help='Apply a delay (of delay factor multiplied by previous '
771
            'prompt latency) before scheduling next prompt.')
772
773
        parser.add_argument(
            '--enable-chunked-prefill',
774
775
776
777
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
778
            help='If set, the prefill requests can be chunked based on the '
779
            'max_num_batched_tokens.')
780
        parser.add_argument('--speculative-config',
781
                            type=json.loads,
782
783
784
                            default=None,
                            help='The configurations for speculative decoding.'
                            ' Should be a JSON string.')
785

786
        parser.add_argument('--model-loader-extra-config',
787
                            type=nullable_str,
788
789
790
791
792
793
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
794
795
796
797
798
799
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
800
            "Default to `original/**/*` to avoid repeated loading of llama's "
801
            "checkpoints.")
802
        parser.add_argument(
803
            '--preemption-mode',
804
805
            type=str,
            default=None,
806
807
808
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
809

810
811
812
813
814
815
816
817
818
819
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
820
            "same as the ``--model`` argument. Noted that this name(s) "
821
            "will also be used in `model_name` tag content of "
822
            "prometheus metrics, if multiple names provided, metrics "
823
            "tag will take the first one.")
824
825
826
827
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
828

829
830
831
832
833
834
835
836
837
838
839
840
        parser.add_argument('--show-hidden-metrics-for-version',
                            type=str,
                            default=None,
                            help='Enable deprecated Prometheus metrics that '
                            'have been hidden since the specified version. '
                            'For example, if a previously deprecated metric '
                            'has been hidden since the v0.7.0 release, you '
                            'use --show-hidden-metrics-for-version=0.7 as a '
                            'temporary escape hatch while you migrate to new '
                            'metrics. The metric is likely to be removed '
                            'completely in an upcoming release.')

841
842
843
844
845
        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
846
847
848
849
850
851
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
852
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
853
854
855
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
856

857
858
859
860
861
862
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
863

864
865
866
867
868
869
870
871
872
873
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

874
875
876
877
878
879
880
        parser.add_argument(
            '--scheduler-cls',
            default=EngineArgs.scheduler_cls,
            help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
            'is the default scheduler. Can be a class directly or the path to '
            'a class of form "mod.custom_class".')

881
        parser.add_argument(
882
883
            '--override-neuron-config',
            type=json.loads,
884
            default=None,
885
            help="Override or set neuron device configuration. "
886
            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
887
        parser.add_argument(
888
889
            '--override-pooler-config',
            type=PoolerConfig.from_json,
890
            default=None,
891
            help="Override or set the pooling method for pooling models. "
892
            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
893

894
895
896
897
898
899
900
901
902
903
904
905
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
906
907
908
909
                            'use a JSON string.\n'
                            'Following the convention of traditional '
                            'compilers, using -O without space is also '
                            'supported. -O3 is equivalent to -O 3.')
910

911
912
913
914
915
916
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

917
918
919
920
921
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')
922
923
924
925
926
927
928
        parser.add_argument(
            '--worker-extension-cls',
            type=str,
            default="",
            help='The worker extension class on top of the worker cls, '
            'it is useful if you just want to add new functions to the worker '
            'class without changing the existing functions.')
929
930
931
        parser.add_argument(
            "--generation-config",
            type=nullable_str,
932
            default="auto",
933
            help="The folder path to the generation config. "
934
935
936
937
938
            "Defaults to 'auto', the generation config will be loaded from "
            "model path. If set to 'vllm', no generation config is loaded, "
            "vLLM defaults will be used. If set to a folder path, the "
            "generation config will be loaded from the specified folder path. "
            "If `max_new_tokens` is specified in generation config, then "
939
940
941
942
943
944
945
946
947
948
949
950
            "it sets a server-wide limit on the number of output tokens "
            "for all requests.")

        parser.add_argument(
            "--override-generation-config",
            type=json.loads,
            default=None,
            help="Overrides or sets generation config in JSON format. "
            "e.g. ``{\"temperature\": 0.5}``. If used with "
            "--generation-config=auto, the override parameters will be merged "
            "with the default config from the model. If generation-config is "
            "None, only the override parameters are used.")
951

952
953
954
955
956
957
        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
                            default=False,
                            help="Enable sleep mode for the engine. "
                            "(only cuda platform is supported)")

958
959
960
961
962
963
964
965
966
        parser.add_argument(
            '--calculate-kv-scales',
            action='store_true',
            help='This enables dynamic calculation of '
            'k_scale and v_scale when kv-cache-dtype is fp8. '
            'If calculate-kv-scales is false, the scales will '
            'be loaded from the model checkpoint if available. '
            'Otherwise, the scales will default to 1.0.')

967
968
969
970
971
972
973
974
        parser.add_argument(
            "--additional-config",
            type=json.loads,
            default=None,
            help="Additional config for specified platform in JSON format. "
            "Different platforms may support different configs. Make sure the "
            "configs are valid for the platform you are using. The input format"
            " is like '{\"config_key\":\"config_value\"}'")
975
976
977
978
979
980
981
982
983
984
985
986

        parser.add_argument(
            "--enable-reasoning",
            action="store_true",
            default=False,
            help="Whether to enable reasoning_content for the model. "
            "If enabled, the model will be able to generate reasoning content."
        )

        parser.add_argument(
            "--reasoning-parser",
            type=str,
987
            choices=list(ReasoningParserManager.reasoning_parsers),
988
989
990
991
992
993
            default=None,
            help=
            "Select the reasoning parser depending on the model that you're "
            "using. This is used to parse the reasoning content into OpenAI "
            "API format. Required for ``--enable-reasoning``.")

994
995
996
997
998
999
1000
1001
1002
1003
        parser.add_argument(
            "--disable-cascade-attn",
            action="store_true",
            default=False,
            help="Disable cascade attention for V1. While cascade attention "
            "does not change the mathematical correctness, disabling it "
            "could be useful for preventing potential numerical issues. "
            "Note that even if this is set to False, cascade attention will be "
            "only used when the heuristic tells that it's beneficial.")

1004
        return parser
1005
1006

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of this dataclass from parsed CLI arguments.

        Every dataclass field of ``cls`` is looked up on ``args`` under the
        same name; attributes on ``args`` that are not fields are ignored.

        Args:
            args: Namespace produced by the argument parser. It must carry an
                attribute for every dataclass field of ``cls``.

        Returns:
            A new ``cls`` instance populated from ``args``.

        Raises:
            AttributeError: If ``args`` lacks an attribute for some field.
        """
        # Only forward attributes that correspond to dataclass fields, so
        # extra options present on the namespace never reach ``__init__``.
        field_values = {
            field.name: getattr(args, field.name)
            for field in dataclasses.fields(cls)
        }
        return cls(**field_values)

    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` described by these engine arguments.

        Side effects: before the config is built, this may rewrite
        ``self.quantization``, ``self.load_format`` and ``self.model``:

        * If ``check_gguf_file(self.model)`` is true, both ``quantization``
          and ``load_format`` are set to ``"gguf"``.
        * When ``envs.VLLM_CI_USE_S3`` is set, and these are not
          ``AsyncEngineArgs``, and the model is listed in ``MODELS_ON_S3``,
          and the load format is ``LoadFormat.AUTO``: the model path is
          prefixed with ``MODEL_WEIGHTS_S3_BUCKET``, and the load format
          switches to ``LoadFormat.RUNAI_STREAMER``.

        Returns:
            A ``ModelConfig`` populated from the (possibly updated) fields.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
                and self.model in MODELS_ON_S3
                and self.load_format == LoadFormat.AUTO):  # noqa: E501
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
            self.load_format = LoadFormat.RUNAI_STREAMER

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            task=self.task,
            # We know this is not None because we set it in __post_init__
            tokenizer=cast(str, self.tokenizer),
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            # The CLI exposes a "disable" flag; ModelConfig takes the inverse.
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
        )

    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` used when loading model weights.

        Side effect: when ``self.quantization`` is ``"bitsandbytes"``,
        ``self.load_format`` is set to ``"bitsandbytes"`` as well.

        Returns:
            A ``LoadConfig`` populated from the current field values.

        Raises:
            ValueError: If a (non-empty) QLoRA adapter is configured while
                ``quantization`` is anything other than ``"bitsandbytes"``.
        """
        # A None or empty adapter path means "no adapter". This matches how
        # create_engine_config decides whether to forward the adapter path;
        # previously an empty string raised here even though it was ignored
        # everywhere else.
        if (self.qlora_adapter_name_or_path
                and self.quantization != "bitsandbytes"):
            raise ValueError(
                "QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        # bitsandbytes quantization also selects the bitsandbytes load format.
        if self.quantization == "bitsandbytes":
            self.load_format = "bitsandbytes"

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
        )

    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        This function utilizes `speculative_config` to create a
        SpeculativeConfig object. The `speculative_config` can either be
        provided as a JSON string input via CLI arguments or directly as a
        dictionary from the engine.

        The engine-derived settings (target model config, target parallel
        config, chunked-prefill flag and log-stats flag) are merged into a
        copy of ``self.speculative_config``; the stored dict is left
        unmodified.

        Args:
            target_model_config: Model config of the target (main) model.
            target_parallel_config: Parallel config of the target model.
            enable_chunked_prefill: Whether chunked prefill is enabled.
            disable_log_stats: Whether stats logging is disabled.

        Returns:
            ``None`` if no speculative config was given, otherwise the result
            of ``SpeculativeConfig.from_dict`` on the merged settings.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # Merge into a fresh dict rather than calling ``update`` on
        # ``self.speculative_config``. That dict may be owned by the caller
        # (e.g. passed programmatically), and it should not silently pick up
        # engine objects such as the target ModelConfig.
        speculative_args = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        }
        return SpeculativeConfig.from_dict(speculative_args)

    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1126

1127
1128
1129
1130
1131
1132
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1133
1134
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()
1135

1136
        device_config = DeviceConfig(device=self.device)
1137
1138
        model_config = self.create_model_config()

1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
            self._set_default_args_v1(usage_context)
        else:
            self._set_default_args_v0(model_config)
1161

1162
1163
        assert self.enable_chunked_prefill is not None

1164
        cache_config = CacheConfig(
1165
            block_size=self.block_size,
1166
1167
1168
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1169
            is_attention_free=model_config.is_attention_free,
1170
1171
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1172
            enable_prefix_caching=self.enable_prefix_caching,
1173
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1174
            cpu_offload_gb=self.cpu_offload_gb,
1175
            calculate_kv_scales=self.calculate_kv_scales,
1176
        )
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1189
        parallel_config = ParallelConfig(
1190
1191
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1192
            data_parallel_size=self.data_parallel_size,
1193
            enable_expert_parallel=self.enable_expert_parallel,
1194
1195
1196
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
1197
1198
1199
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
1200
            ),
1201
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1202
            placement_group=placement_group,
1203
1204
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1205
            worker_extension_cls=self.worker_extension_cls,
1206
        )
1207

1208
        speculative_config = self.create_speculative_config(
1209
1210
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1211
            enable_chunked_prefill=self.enable_chunked_prefill,
1212
            disable_log_stats=self.disable_log_stats,
1213
1214
        )

1215
        # Reminder: Please update docs/source/features/compatibility_matrix.md
1216
        # If the feature combo become valid
1217
1218
1219
1220
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1221
1222
1223
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1224
1225
1226
1227
1228
1229
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1230
1231
1232
1233
1234
1235
1236
1237
1238

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1239
        scheduler_config = SchedulerConfig(
1240
            runner_type=model_config.runner_type,
1241
1242
1243
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1244
            num_lookahead_slots=num_lookahead_slots,
1245
1246
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1247
            is_multimodal_model=model_config.is_multimodal_model,
1248
            preemption_mode=self.preemption_mode,
1249
            num_scheduler_steps=self.num_scheduler_steps,
1250
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1251
1252
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1253
            policy=self.scheduling_policy,
1254
            scheduler_cls=self.scheduler_cls,
1255
1256
1257
1258
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
        )
1259

1260
        lora_config = LoRAConfig(
1261
            bias_enabled=self.enable_lora_bias,
1262
1263
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1264
            fully_sharded_loras=self.fully_sharded_loras,
1265
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1266
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1267
1268
1269
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
1270

1271
1272
1273
1274
1275
1276
1277
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

1278
        load_config = self.create_load_config()
1279

1280
1281
1282
1283
1284
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1285
        decoding_config = DecodingConfig(
1286
1287
1288
1289
            guided_decoding_backend=self.guided_decoding_backend,
            reasoning_backend=self.reasoning_parser
            if self.enable_reasoning else None,
        )
1290

1291
1292
1293
1294
1295
        show_hidden_metrics = False
        if self.show_hidden_metrics_for_version is not None:
            show_hidden_metrics = version._prev_minor_version_was(
                self.show_hidden_metrics_for_version)

1296
1297
1298
1299
1300
1301
1302
1303
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
1304
        observability_config = ObservabilityConfig(
1305
            show_hidden_metrics=show_hidden_metrics,
1306
1307
1308
1309
1310
1311
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )
1312

1313
        config = VllmConfig(
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1324
            prompt_adapter_config=prompt_adapter_config,
1325
            compilation_config=self.compilation_config,
1326
            kv_transfer_config=self.kv_transfer_config,
1327
            additional_config=self.additional_config,
1328
        )
1329

1330
1331
        return config

1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        Walks, in order, every engine argument and model property that the
        V1 engine does not handle yet. On the first unsupported one,
        ``_raise_or_fallback`` is called (it raises ``NotImplementedError``
        when ``VLLM_USE_V1=1`` was set explicitly, otherwise it logs a
        warning) and this method returns ``False``. Experimental features are
        then gated through ``_warn_or_fallback``, which only lets them through
        when ``VLLM_USE_V1=1`` was set explicitly.

        Args:
            model_config: The resolved model configuration.

        Returns:
            ``True`` if the V1 engine can be used, ``False`` to fall back to
            the V0 engine.

        Raises:
            NotImplementedError: Via ``_raise_or_fallback`` when V1 was forced
                with ``VLLM_USE_V1=1`` but a feature is unsupported.
        """

        #############################################################
        # Unsupported Feature Flags on V1.

        if self.load_format in (LoadFormat.TENSORIZER.value,
                                LoadFormat.SHARDED_STATE.value):
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

        if self.preemption_mode != EngineArgs.preemption_mode:
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

        if self.scheduling_policy != EngineArgs.scheduling_policy:
            _raise_or_fallback(feature_name="--scheduling-policy",
                               recommend_to_remove=False)
            return False

        if self.num_scheduler_steps != EngineArgs.num_scheduler_steps:
            _raise_or_fallback(feature_name="--num-scheduler-steps",
                               recommend_to_remove=True)
            return False

        if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor:
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

        if self.additional_config != EngineArgs.additional_config:
            _raise_or_fallback(feature_name="--additional-config",
                               recommend_to_remove=False)
            return False

        # Xgrammar and Guidance are supported.
        SUPPORTED_GUIDED_DECODING = [
            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
            "guidance:disable-any-whitespace", "auto"
        ]
        if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
            _raise_or_fallback(feature_name="--guided-decoding-backend",
                               recommend_to_remove=False)
            return False

        # Need at least Ampere for now (FA support required).
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
        from vllm.platforms import current_platform
        if current_platform.is_cuda():
            # Query once; a falsy result means the capability is unknown.
            device_capability = current_platform.get_device_capability()
            if device_capability and device_capability.major < 8:
                _raise_or_fallback(feature_name="Compute Capability < 8.0",
                                   recommend_to_remove=False)
                return False

        # A non-"auto" KV cache dtype is only accepted when it is an FP8
        # dtype, FlashAttention is expected to be used (CUDA with no backend
        # override, or FLASH_ATTN_VLLM_V1 selected explicitly), and the
        # FlashAttention build reports FP8 support.
        if self.kv_cache_dtype != "auto":
            fp8_attention = self.kv_cache_dtype.startswith("fp8")
            will_use_fa = (
                current_platform.is_cuda()
                and not envs.is_set("VLLM_ATTENTION_BACKEND")
            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
            supported = False
            if fp8_attention and will_use_fa:
                from vllm.vllm_flash_attn.fa_utils import (
                    flash_attn_supports_fp8)
                supported = flash_attn_supports_fp8()
            if not supported:
                _raise_or_fallback(feature_name="--kv-cache-dtype",
                                   recommend_to_remove=False)
                return False

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

        # Only Fp16 and Bf16 dtypes since we only support FA.
        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
        if model_config.dtype not in V1_SUPPORTED_DTYPES:
            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
                               recommend_to_remove=False)
            return False

        # Some quantization is not compatible with torch.compile.
        V1_UNSUPPORTED_QUANT = ["gguf"]
        if model_config.quantization in V1_UNSUPPORTED_QUANT:
            _raise_or_fallback(
                feature_name=f"--quantization {model_config.quantization}",
                recommend_to_remove=False)
            return False

        # No Embedding Models so far.
        if model_config.task not in ["generate"]:
            _raise_or_fallback(feature_name=f"--task {model_config.task}",
                               recommend_to_remove=False)
            return False

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            _raise_or_fallback(feature_name=model_config.architectures,
                               recommend_to_remove=False)
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
                != EngineArgs.max_num_partial_prefills
                or self.max_long_partial_prefills
                != EngineArgs.max_long_partial_prefills):
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # No OTLP observability so far. Report the flag that actually
        # triggered the fallback rather than always blaming the endpoint.
        if self.otlp_traces_endpoint or self.collect_detailed_traces:
            otlp_flag = ("--otlp-traces-endpoint"
                         if self.otlp_traces_endpoint else
                         "--collect-detailed-traces")
            _raise_or_fallback(feature_name=otlp_flag,
                               recommend_to_remove=False)
            return False

        # Only ngram and EAGLE speculative decoding so far; both are still
        # experimental and gated again further below.
        is_ngram_enabled = False
        is_eagle_enabled = False
        if self.speculative_config is not None:
            # An explicit "method" takes precedence; otherwise the legacy
            # "model" key may name the ngram proposer.
            speculative_method = self.speculative_config.get("method")
            if speculative_method:
                if speculative_method in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
                elif speculative_method == "eagle":
                    is_eagle_enabled = True
            else:
                speculative_model = self.speculative_config.get("model")
                if speculative_model in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
            if not (is_ngram_enabled or is_eagle_enabled):
                # Other speculative decoding methods are not supported yet.
                _raise_or_fallback(feature_name="Speculative Decoding",
                                   recommend_to_remove=False)
                return False

        # No Disaggregated Prefill so far.
        if self.kv_transfer_config != EngineArgs.kv_transfer_config:
            _raise_or_fallback(feature_name="--kv-transfer-config",
                               recommend_to_remove=False)
            return False

        # No FlashInfer or XFormers so far.
        V1_BACKENDS = [
            "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        # Platforms must decide if they can support v1 for this model
        if not current_platform.supports_v1(model_config=model_config):
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

        # Signal Handlers requires running in main thread.
        if (threading.current_thread() != threading.main_thread()
                and _warn_or_fallback("Engine in background thread")):
            return False

        # PP is supported on V1 with Ray distributed executor,
        # but off for MP distributed executor for now.
        if (self.pipeline_parallel_size > 1
                and self.distributed_executor_backend != "ray"):
            name = "Pipeline Parallelism without Ray distributed executor"
            _raise_or_fallback(feature_name=name, recommend_to_remove=False)
            return False

        # ngram is supported on V1, but off by default for now.
        if is_ngram_enabled and _warn_or_fallback("ngram"):
            return False

        # Eagle is under development, so we don't support it yet.
        if is_eagle_enabled and _warn_or_fallback("Eagle"):
            return False

        # Non-CUDA is supported on V1, but off by default for now.
        not_cuda = not current_platform.is_cuda()
        if not_cuda and _warn_or_fallback(  # noqa: SIM103
                current_platform.device_name):
            return False
        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        Resolves ``enable_chunked_prefill`` when it was left unset. It then
        warns about long-context models running without chunked prefill and
        rejects chunked prefill for pooling models. Next it normalizes the
        prefix-caching settings for V0. Finally it defaults ``max_num_seqs``
        to 256.

        Args:
            model_config: The resolved model configuration.

        Raises:
            ValueError: If chunked prefill is enabled for a pooling model, or
                if prefix caching is enabled with the ``sha256`` hash
                algorithm (V0 only supports ``builtin``).
        """

        max_model_len = model_config.max_model_len
        # "Long context" means strictly more than 32K tokens.
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                from vllm.platforms import current_platform
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_config is not None

                # Only auto-enable when none of the features that are
                # excluded here is in play.
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # Nothing above opted in, so the V0 default is "off".
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            # Note the trailing space after "cause": the two literals are
            # concatenated and previously rendered as "causeOOM".
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # if using prefix caching, we must set a hash algo
        if self.enable_prefix_caching:
            # Disable prefix caching for multimodal models for VLLM_V0.
            if model_config.is_multimodal_model:
                logger.warning(
                    "--enable-prefix-caching is not supported for multimodal "
                    "models in V0 and has been disabled.")
                self.enable_prefix_caching = False

            # VLLM_V0 only supports builtin hash algo for prefix caching.
            if self.prefix_caching_hash_algo is None:
                self.prefix_caching_hash_algo = "builtin"
            elif self.prefix_caching_hash_algo == "sha256":
                raise ValueError(
                    "sha256 is not supported for prefix caching in V0 engine. "
                    "Please use 'builtin'.")

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

    def _set_default_args_v1(self, usage_context: UsageContext) -> None:
        """Set Default Arguments for V1 Engine.

        Always turns chunked prefill on. Prefix caching is turned on unless
        the user already chose a value, and it gets the ``builtin`` hash
        algorithm when none was given. The V1 scheduler replaces the V0
        default scheduler class. ``max_num_batched_tokens`` and
        ``max_num_seqs`` are filled in when they were left unset.

        Args:
            usage_context: Where the engine is being created from; it selects
                the default for ``max_num_batched_tokens``.
        """
        # Chunked prefill is unconditional in V1.
        self.enable_chunked_prefill = True

        # Prefix caching defaults to on; an explicit user choice is kept.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = True

        # With prefix caching on, a hash algorithm must be chosen.
        if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
            self.prefix_caching_hash_algo = "builtin"

        # Only swap the scheduler if it still holds the original V0 default;
        # a user-provided scheduler class is left alone.
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # Detect the device so that larger token budgets can be used on
        # H100/H200. The query can fail when the importing process has no
        # GPU (e.g. scaling vLLM with Ray), in which case the generic
        # defaults below are used.
        try:
            from vllm.platforms import current_platform
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # Only consulted for the max_num_batched_tokens defaults.
            device_name = "no-device"

        use_large_defaults = any(tag in device_name
                                 for tag in ("h100", "h200"))
        if use_large_defaults:
            batched_token_defaults = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            batched_token_defaults = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }

        context_label = usage_context.value if usage_context else None

        if self.max_num_batched_tokens is None:
            fallback_tokens = batched_token_defaults.get(usage_context)
            if fallback_tokens is not None:
                self.max_num_batched_tokens = fallback_tokens
                logger.debug(
                    "Setting max_num_batched_tokens to %d for %s usage context.",
                    self.max_num_batched_tokens, context_label)

        if self.max_num_seqs is None:
            self.max_num_seqs = 1024
            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, context_label)
1675

1676

1677
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends ``EngineArgs`` with options that only apply to the async engine.
    """
    # Suppress per-request logging (exposed as --disable-log-requests).
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI flags on ``parser`` and return it.

        Args:
            parser: The parser to extend.
            async_args_only: When ``True``, skip the ``EngineArgs`` flags and
                register only the async-specific ones.

        Returns:
            The parser with the flags added.
        """
        # General plugins are loaded first so they can extend the parser,
        # e.g. by adding a new quantization method to --quantization or a
        # new device to --device.
        load_general_plugins()

        # Start from the full EngineArgs flag set unless only the
        # async-specific flags were requested.
        full_parser = (parser if async_args_only else
                       EngineArgs.add_cli_args(parser))

        full_parser.add_argument('--disable-log-requests',
                                 action='store_true',
                                 help='Disable logging requests.')

        # Hand the parser to the current platform's
        # pre_register_and_update hook before returning it.
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update(full_parser)
        return full_parser
1697
1698


1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Report that ``feature_name`` is unsupported by the V1 engine.

    If ``VLLM_USE_V1`` is set and truthy, the user explicitly asked for V1,
    so ``NotImplementedError`` is raised. Otherwise a warning is logged that
    the engine falls back to V0. The warning optionally recommends dropping
    the feature.

    Args:
        feature_name: Human-readable name of the offending flag or feature.
        recommend_to_remove: Whether to append the removal recommendation.

    Raises:
        NotImplementedError: When V1 was requested explicitly.
    """
    v1_forced = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_forced:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")

    pieces = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        pieces += [
            f"We recommend to remove {feature_name} from your config ",
            "in favor of the V1 Engine.",
        ]
    logger.warning("".join(pieces))


def _warn_or_fallback(feature_name: str) -> bool:
    """Decide whether an experimental feature forces a fallback to V0.

    Args:
        feature_name: Human-readable name of the experimental feature.

    Returns:
        ``False`` when ``VLLM_USE_V1`` is set and truthy. In that case the
        feature is allowed on V1 and a warning is logged. ``True`` otherwise,
        meaning the caller should fall back to V0; an info message is logged.
    """
    v1_forced = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if not v1_forced:
        logger.info(
            "%s is experimental on VLLM_USE_V1=1. "
            "Falling back to V0 Engine.", feature_name)
        return True

    logger.warning(
        "Detected VLLM_USE_V1=1 with %s. Usage should "
        "be considered experimental. Please report any "
        "issues on Github.", feature_name)
    return False


1726
1727
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with every ``EngineArgs`` CLI flag."""
    fresh_parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(fresh_parser)
1729
1730
1731


def _async_engine_args_parser():
    """Return a fresh parser with the async-engine flags registered.

    The base ``EngineArgs`` flags are skipped (``async_args_only=True``).
    """
    fresh_parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(fresh_parser, async_args_only=True)