# SPDX-License-Identifier: Apache-2.0

3
import argparse
4
import dataclasses
5
import json
6
import threading
7
from dataclasses import dataclass
8
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
9
                    Tuple, Type, Union, cast, get_args)
10

11
12
import torch

13
import vllm.envs as envs
14
from vllm import version
15
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
16
17
                         DecodingConfig, DeviceConfig, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
18
19
20
21
                         ModelConfig, ModelImpl, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
                         SchedulerConfig, SpeculativeConfig, TaskOption,
                         TokenizerPoolConfig, VllmConfig)
22
from vllm.executor.executor_base import ExecutorBase
23
from vllm.logger import init_logger
24
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
25
from vllm.plugins import load_general_plugins
26
from vllm.reasoning import ReasoningParserManager
27
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
28
from vllm.transformers_utils.utils import check_gguf_file
29
from vllm.usage.usage_lib import UsageContext
30
from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
31

32
if TYPE_CHECKING:
33
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
34

35
36
logger = init_logger(__name__)

37
38
# Module names accepted when requesting detailed traces.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# Device type names recognised by this module.
DEVICE_OPTIONS = ["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]

49

50
51
52
53
54
55
def nullable_str(val: str):
    """Map an empty value or the literal string ``"None"`` to ``None``.

    Args:
        val: Raw string value, e.g. as received from the command line.

    Returns:
        ``None`` when ``val`` is falsy or equal to ``"None"``; otherwise
        ``val`` unchanged.
    """
    return val if val and val != "None" else None


56
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dictionary.

    Keys and values are lower-cased and stripped of surrounding whitespace
    before use; every value must parse as an integer. A key may be repeated
    only if it carries the same value each time.

    Args:
        val: String value to be parsed, e.g. ``"image=16,video=2"``.

    Returns:
        ``None`` for an empty string, otherwise a dictionary mapping each
        key to its integer value.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, a value is not an integer, or a key is given
            two different values.
    """
    if len(val) == 0:
        return None

    parsed: Dict[str, int] = {}
    for entry in val.split(","):
        pieces = entry.split("=")
        # Exactly one "=" is required; "a", "a=b=c" and "" are all rejected.
        if len(pieces) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        name, raw = (piece.lower().strip() for piece in pieces)

        try:
            number = int(raw)
        except ValueError as exc:
            raise argparse.ArgumentTypeError(
                f"Failed to parse value of item {name}={raw}") from exc

        # setdefault stores the first value seen for a key and returns the
        # stored one afterwards, so any mismatch is a conflicting repeat.
        if parsed.setdefault(name, number) != number:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {name}")

    return parsed


91
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
92
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
93
    """Arguments for vLLM engine."""
94
    model: str = 'facebook/opt-125m'
95
    served_model_name: Optional[Union[str, List[str]]] = None
96
    tokenizer: Optional[str] = None
97
    hf_config_path: Optional[str] = None
98
    task: TaskOption = "auto"
99
    skip_tokenizer_init: bool = False
100
    tokenizer_mode: str = 'auto'
101
    trust_remote_code: bool = False
102
    allowed_local_media_path: str = ""
103
    download_dir: Optional[str] = None
104
    load_format: str = 'auto'
105
    config_format: ConfigFormat = ConfigFormat.AUTO
106
    dtype: str = 'auto'
107
    kv_cache_dtype: str = 'auto'
108
    seed: Optional[int] = None
109
    max_model_len: Optional[int] = None
110
111
112
113
114
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
115
    # number of P/D disaggregation (or other disaggregation) workers
116
117
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
118
    data_parallel_size: int = 1
119
    enable_expert_parallel: bool = False
120
    max_parallel_loading_workers: Optional[int] = None
121
    block_size: Optional[int] = None
122
    enable_prefix_caching: Optional[bool] = None
123
    prefix_caching_hash_algo: str = "builtin"
124
    disable_sliding_window: bool = False
125
    disable_cascade_attn: bool = False
126
    use_v2_block_manager: bool = True
127
128
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
129
    gpu_memory_utilization: float = 0.90
130
    max_num_batched_tokens: Optional[int] = None
131
132
133
    max_num_partial_prefills: Optional[int] = 1
    max_long_partial_prefills: Optional[int] = 1
    long_prefill_token_threshold: Optional[int] = 0
134
    max_num_seqs: Optional[int] = None
135
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
136
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
137
    revision: Optional[str] = None
138
    code_revision: Optional[str] = None
139
    rope_scaling: Optional[Dict[str, Any]] = None
140
    rope_theta: Optional[float] = None
141
    hf_token: Optional[Union[bool, str]] = None
142
    hf_overrides: Optional[HfOverrides] = None
143
    tokenizer_revision: Optional[str] = None
144
    quantization: Optional[str] = None
145
    enforce_eager: Optional[bool] = None
146
    max_seq_len_to_capture: int = 8192
147
    disable_custom_all_reduce: bool = False
148
    tokenizer_pool_size: int = 0
149
150
151
152
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
153
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
154
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
155
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
156
    disable_mm_preprocessor_cache: bool = False
157
    enable_lora: bool = False
158
    enable_lora_bias: bool = False
159
160
    max_loras: int = 1
    max_lora_rank: int = 16
161
162
163
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
164
    fully_sharded_loras: bool = False
165
    lora_extra_vocab_size: int = 256
166
    long_lora_scaling_factors: Optional[Tuple[float]] = None
167
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
168
    max_cpu_loras: Optional[int] = None
169
    device: str = 'auto'
170
    num_scheduler_steps: int = 1
171
    multi_step_stream_outputs: bool = True
172
    ray_workers_use_nsight: bool = False
173
    num_gpu_blocks_override: Optional[int] = None
174
    num_lookahead_slots: int = 0
175
    model_loader_extra_config: Optional[dict] = None
176
    ignore_patterns: Optional[Union[str, List[str]]] = None
177
    preemption_mode: Optional[str] = None
178

179
    scheduler_delay_factor: float = 0.0
180
    enable_chunked_prefill: Optional[bool] = None
181

182
    guided_decoding_backend: str = 'xgrammar'
183
    logits_processor_pattern: Optional[str] = None
184

185
    speculative_config: Optional[Dict[str, Any]] = None
186

187
    qlora_adapter_name_or_path: Optional[str] = None
188
    show_hidden_metrics_for_version: Optional[str] = None
189
    otlp_traces_endpoint: Optional[str] = None
190
    collect_detailed_traces: Optional[str] = None
191
    disable_async_output_proc: bool = False
192
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
193
    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
194

195
196
    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
197
    compilation_config: Optional[CompilationConfig] = None
198
    worker_cls: str = "auto"
199
    worker_extension_cls: str = ""
200

201
202
    kv_transfer_config: Optional[KVTransferConfig] = None

203
    generation_config: Optional[str] = "auto"
204
    override_generation_config: Optional[Dict[str, Any]] = None
205
    enable_sleep_mode: bool = False
206
    model_impl: str = "auto"
207

208
209
    calculate_kv_scales: Optional[bool] = None

210
    additional_config: Optional[Dict[str, Any]] = None
211
212
    enable_reasoning: Optional[bool] = None
    reasoning_parser: Optional[str] = None
213
    use_tqdm_on_load: bool = True
214

215
    def __post_init__(self):
        """Normalize fields once the dataclass has been constructed.

        Uses ``model`` as the tokenizer when no tokenizer was given, converts
        an ``int`` or ``dict`` ``compilation_config`` through
        ``CompilationConfig.from_cli``, and calls ``load_general_plugins()``.
        """
        # A falsy tokenizer (None or "") falls back to the model name/path.
        self.tokenizer = self.tokenizer or self.model

        # support `EngineArgs(compilation_config={...})` without having to
        # manually construct a CompilationConfig object: the raw value is
        # stringified and handed to the CLI parser.
        raw_compilation = self.compilation_config
        if isinstance(raw_compilation, (int, dict)):
            self.compilation_config = CompilationConfig.from_cli(
                str(raw_compilation))

        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
229
230

    @staticmethod
231
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
232
        """Shared CLI arguments for vLLM engine."""
233
        # Model arguments
234
235
236
        parser.add_argument(
            '--model',
            type=str,
237
            default=EngineArgs.model,
238
            help='Name or path of the huggingface model to use.')
239
240
241
242
243
244
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
245
            'multiple tasks. When the model only supports one task, ``"auto"`` '
246
247
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
248
249
        parser.add_argument(
            '--tokenizer',
250
            type=nullable_str,
251
            default=EngineArgs.tokenizer,
252
253
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
254
255
256
257
258
259
        parser.add_argument(
            "--hf-config-path",
            type=nullable_str,
            default=EngineArgs.hf_config_path,
            help='Name or path of the huggingface config to use. '
            'If unspecified, model name or path will be used.')
260
261
262
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
263
264
265
            help='Skip initialization of tokenizer and detokenizer. '
            'Expects valid prompt_token_ids and None for prompt from '
            'the input. The generated output will contain token ids.')
Jasmond L's avatar
Jasmond L committed
266
267
        parser.add_argument(
            '--revision',
268
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
269
            default=None,
270
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
271
272
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
273
274
        parser.add_argument(
            '--code-revision',
275
            type=nullable_str,
276
            default=None,
277
            help='The specific revision to use for the model code on '
278
279
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
280
281
        parser.add_argument(
            '--tokenizer-revision',
282
            type=nullable_str,
283
            default=None,
284
285
286
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
287
288
289
290
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
291
            choices=['auto', 'slow', 'mistral', 'custom'],
292
293
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
294
            'always use the slow tokenizer. \n* '
295
296
297
            '"mistral" will always use the `mistral_common` tokenizer. \n* '
            '"custom" will use --tokenizer to select the '
            'preregistered tokenizer.')
298
299
        parser.add_argument('--trust-remote-code',
                            action='store_true',
300
                            help='Trust remote code from huggingface.')
301
302
303
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
304
305
306
307
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
308
        parser.add_argument('--download-dir',
309
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
310
                            default=EngineArgs.download_dir,
311
                            help='Directory to download and load the weights.')
312
313
314
315
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
316
            choices=[f.value for f in LoadFormat],
317
318
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
319
            'and fall back to the pytorch bin format if safetensors format '
320
321
322
323
324
325
326
327
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
328
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
329
            'section for more information.\n'
330
            '* "runai_streamer" will load the Safetensors weights using Run:ai'
331
            'Model Streamer.\n'
332
            '* "bitsandbytes" will load the weights using bitsandbytes '
333
334
335
336
337
338
339
            'quantization.\n'
            '* "sharded_state" will load weights from pre-sharded checkpoint '
            'files, supporting efficient loading of tensor-parallel models\n'
            '* "gguf" will load weights from GGUF format files (details '
            'specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n'
            '* "mistral" will load weights from consolidated safetensors files '
            'used by Mistral models.\n')
340
341
342
343
344
345
346
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
347
348
349
350
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
351
352
353
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
354
355
356
357
358
359
360
361
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
362
363
364
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
365
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
366
            default=EngineArgs.kv_cache_dtype,
367
            help='Data type for kv cache storage. If "auto", will use model '
368
369
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
370
371
        parser.add_argument('--max-model-len',
                            type=int,
372
                            default=EngineArgs.max_model_len,
373
374
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
375
376
377
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
378
            default='xgrammar',
379
            help='Which engine will be used for guided decoding'
380
            ' (JSON schema / regex etc) by default. Currently support '
381
382
383
384
385
            'https://github.com/mlc-ai/xgrammar and '
            'https://github.com/guidance-ai/llguidance.'
            'Valid backend values are "xgrammar", "guidance", and "auto". '
            'With "auto", we will make opinionated choices based on request'
            'contents and what the backend libraries currently support, so '
386
            'the behavior is subject to change in each release.')
387
388
389
390
391
392
393
394
        parser.add_argument(
            '--logits-processor-pattern',
            type=nullable_str,
            default=None,
            help='Optional regex pattern specifying valid logits processor '
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
395
396
397
398
399
400
401
402
403
404
405
406
        parser.add_argument(
            '--model-impl',
            type=str,
            default=EngineArgs.model_impl,
            choices=[f.value for f in ModelImpl],
            help='Which implementation of the model to use.\n\n'
            '* "auto" will try to use the vLLM implementation if it exists '
            'and fall back to the Transformers implementation if no vLLM '
            'implementation is available.\n'
            '* "vllm" will use the vLLM model implementation.\n'
            '* "transformers" will use the Transformers model '
            'implementation.\n')
407
        # Parallel arguments
408
409
        parser.add_argument(
            '--distributed-executor-backend',
410
            choices=['ray', 'mp', 'uni', 'external_launcher'],
411
            default=EngineArgs.distributed_executor_backend,
412
413
414
415
416
417
            help='Backend to use for distributed model '
            'workers, either "ray" or "mp" (multiprocessing). If the product '
            'of pipeline_parallel_size and tensor_parallel_size is less than '
            'or equal to the number of GPUs available, "mp" will be used to '
            'keep processing on a single host. Otherwise, this will default '
            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
418
            'only supports Ray for distributed inference.')
419

420
421
422
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
423
                            default=EngineArgs.pipeline_parallel_size,
424
                            help='Number of pipeline stages.')
425
426
427
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
428
                            default=EngineArgs.tensor_parallel_size,
429
                            help='Number of tensor parallel replicas.')
430
431
432
433
434
435
436
437
        parser.add_argument('--data-parallel-size',
                            '-dp',
                            type=int,
                            default=EngineArgs.data_parallel_size,
                            help='Number of data parallel replicas. '
                            'MoE layers will be sharded according to the '
                            'product of the tensor-parallel-size and '
                            'data-parallel-size.')
438
439
440
441
442
        parser.add_argument(
            '--enable-expert-parallel',
            action='store_true',
            help='Use expert parallelism instead of tensor parallelism '
            'for MoE layers.')
443
444
445
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
446
            default=EngineArgs.max_parallel_loading_workers,
447
            help='Load model sequentially in multiple batches, '
448
            'to avoid RAM OOM when using tensor '
449
            'parallel and large models.')
450
451
452
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
453
            help='If specified, use nsight to profile Ray workers.')
454
        # KV cache arguments
455
456
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
457
                            default=EngineArgs.block_size,
458
                            choices=[8, 16, 32, 64, 128],
459
                            help='Token block size for contiguous chunks of '
460
                            'tokens. This is ignored on neuron devices and '
461
                            'set to ``--max-model-len``. On CUDA devices, '
462
463
                            'only block sizes up to 32 are supported. '
                            'On HPU devices, block size defaults to 128.')
464

465
466
467
468
469
        parser.add_argument(
            "--enable-prefix-caching",
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.enable_prefix_caching,
            help="Enables automatic prefix caching. "
470
            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
471
        )
472
473
474
475
476
477
478
        parser.add_argument(
            "--prefix-caching-hash-algo",
            type=str,
            choices=["builtin", "sha256"],
            default=EngineArgs.prefix_caching_hash_algo,
            help="Set the hash algorithm for prefix caching. "
            "Options are 'builtin' (Python's built-in hash) or 'sha256' "
479
            "(collision resistant but with certain overheads).",
480
        )
481
482
483
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
484
                            'capping to sliding window size.')
485
486
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
487
                            default=True,
488
489
490
491
492
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
493
494
495
496
497
498
499
500
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
501

502
503
504
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
505
                            help='Random seed for operations.')
506
        parser.add_argument('--swap-space',
507
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
508
                            default=EngineArgs.swap_space,
509
                            help='CPU swap space size (GiB) per GPU.')
510
511
512
513
514
515
516
517
518
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
519
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
520
            'which requires at least 26GB GPU memory. Note that this '
521
            'requires fast CPU-GPU interconnect, as part of the model is '
522
523
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
524
525
526
527
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
528
529
530
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
531
532
533
534
535
536
            'will use the default value of 0.9. This is a per-instance '
            'limit, and only applies to the current vLLM instance.'
            'It does not matter if you have another vLLM instance running '
            'on the same GPU. For example, if you have two vLLM instances '
            'running on the same GPU, you can set the GPU memory utilization '
            'to 0.5 for each instance.')
537
        parser.add_argument(
538
            '--num-gpu-blocks-override',
539
540
541
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
542
            ' of GPU blocks. Used for testing preemption.')
543
544
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
545
                            default=EngineArgs.max_num_batched_tokens,
546
547
                            help='Maximum number of batched tokens per '
                            'iteration.')
548
549
550
551
552
        parser.add_argument(
            "--max-num-partial-prefills",
            type=int,
            default=EngineArgs.max_num_partial_prefills,
            help="For chunked prefill, the max number of concurrent \
553
            partial prefills.")
554
555
556
557
558
559
560
561
        parser.add_argument(
            "--max-long-partial-prefills",
            type=int,
            default=EngineArgs.max_long_partial_prefills,
            help="For chunked prefill, the maximum number of prompts longer "
            "than --long-prefill-token-threshold that will be prefilled "
            "concurrently. Setting this less than --max-num-partial-prefills "
            "will allow shorter prompts to jump the queue in front of longer "
562
            "prompts in some cases, improving latency.")
563
564
565
566
567
        parser.add_argument(
            "--long-prefill-token-threshold",
            type=float,
            default=EngineArgs.long_prefill_token_threshold,
            help="For chunked prefill, a request is considered long if the "
568
            "prompt is longer than this number of tokens.")
569
570
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
571
                            default=EngineArgs.max_num_seqs,
572
                            help='Maximum number of sequences per iteration.')
573
574
575
576
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
577
578
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
579
580
        parser.add_argument('--disable-log-stats',
                            action='store_true',
581
                            help='Disable logging statistics.')
582
583
584
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
585
                            type=nullable_str,
586
                            choices=[*QUANTIZATION_METHODS, None],
587
                            default=EngineArgs.quantization,
588
589
590
591
592
593
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
594
595
596
597
598
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
599
            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
600
601
602
603
604
605
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
606
607
608
609
610
611
612
613
614
615
        parser.add_argument(
            '--hf-token',
            type=str,
            nargs='?',
            const=True,
            default=None,
            help='The token to use as HTTP bearer authorization'
            ' for remote files. If `True`, will use the token '
            'generated when running `huggingface-cli login` '
            '(stored in `~/.huggingface`).')
616
617
618
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
619
                            help='Extra arguments for the HuggingFace config. '
620
621
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
622
623
624
625
626
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
627
        parser.add_argument('--max-seq-len-to-capture',
628
629
630
631
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
632
633
634
635
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
636
637
638
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
639
                            help='See ParallelConfig.')
640
641
642
643
644
645
646
647
648
649
650
651
652
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
653
                            type=nullable_str,
654
655
656
657
658
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
659
660
661
662
663
664
665

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
666
            # MultiModalConfig.get_limit_per_prompt
667
668
669
670
671
672
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
673
674
675
676
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
677
            help=('Overrides for the multimodal input mapping/processing, '
678
                  'e.g., image processor. For example: ``{"num_crops": 4}``.'))
679
        parser.add_argument(
680
            '--disable-mm-preprocessor-cache',
681
            action='store_true',
682
683
            help='If true, then disables caching of the multi-modal '
            'preprocessor/mapper. (not recommended)')
684

685
686
687
688
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
689
690
691
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
711
            choices=['auto', 'float16', 'bfloat16'],
712
713
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
714
715
716
717
718
719
720
721
722
723
724
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
725
726
727
728
729
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
730
                  'Must be >= than max_loras.'))
731
732
733
734
735
736
737
738
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
739
740
741
742
743
744
745
746
747
748
749
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
750
751
752
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
753
                            choices=DEVICE_OPTIONS,
754
                            help='Device type for vLLM execution.')
755
756
757
758
759
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
760
761
762
763
764
765
766
767
        parser.add_argument(
            '--use-tqdm-on-load',
            dest='use_tqdm_on_load',
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.use_tqdm_on_load,
            help='Whether to enable/disable progress bar '
            'when loading model weights.',
        )
768

769
770
        parser.add_argument(
            '--multi-step-stream-outputs',
771
772
773
774
775
776
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
777
778
779
780
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
781
            help='Apply a delay (of delay factor multiplied by previous '
782
            'prompt latency) before scheduling next prompt.')
783
784
        parser.add_argument(
            '--enable-chunked-prefill',
785
786
787
788
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
789
            help='If set, the prefill requests can be chunked based on the '
790
            'max_num_batched_tokens.')
791
        parser.add_argument('--speculative-config',
792
                            type=json.loads,
793
794
795
                            default=None,
                            help='The configurations for speculative decoding.'
                            ' Should be a JSON string.')
796

797
        parser.add_argument('--model-loader-extra-config',
798
                            type=nullable_str,
799
800
801
802
803
804
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
805
806
807
808
809
810
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
811
            "Default to `original/**/*` to avoid repeated loading of llama's "
812
            "checkpoints.")
813
        parser.add_argument(
814
            '--preemption-mode',
815
816
            type=str,
            default=None,
817
818
819
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
820

821
822
823
824
825
826
827
828
829
830
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
831
            "same as the ``--model`` argument. Noted that this name(s) "
832
            "will also be used in `model_name` tag content of "
833
            "prometheus metrics, if multiple names provided, metrics "
834
            "tag will take the first one.")
835
836
837
838
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
839

840
841
842
843
844
845
846
847
848
849
850
851
        parser.add_argument('--show-hidden-metrics-for-version',
                            type=str,
                            default=None,
                            help='Enable deprecated Prometheus metrics that '
                            'have been hidden since the specified version. '
                            'For example, if a previously deprecated metric '
                            'has been hidden since the v0.7.0 release, you '
                            'use --show-hidden-metrics-for-version=0.7 as a '
                            'temporary escape hatch while you migrate to new '
                            'metrics. The metric is likely to be removed '
                            'completely in an upcoming release.')

852
853
854
855
856
        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
857
858
859
860
861
862
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
863
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
864
865
866
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
867

868
869
870
871
872
873
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
874

875
876
877
878
879
880
881
882
883
884
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

885
886
887
888
889
890
891
        parser.add_argument(
            '--scheduler-cls',
            default=EngineArgs.scheduler_cls,
            help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
            'is the default scheduler. Can be a class directly or the path to '
            'a class of form "mod.custom_class".')

892
        parser.add_argument(
893
894
            '--override-neuron-config',
            type=json.loads,
895
            default=None,
896
            help="Override or set neuron device configuration. "
897
            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
898
        parser.add_argument(
899
900
            '--override-pooler-config',
            type=PoolerConfig.from_json,
901
            default=None,
902
            help="Override or set the pooling method for pooling models. "
903
            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
904

905
906
907
908
909
910
911
912
913
914
915
916
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
917
918
919
920
                            'use a JSON string.\n'
                            'Following the convention of traditional '
                            'compilers, using -O without space is also '
                            'supported. -O3 is equivalent to -O 3.')
921

922
923
924
925
926
927
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

928
929
930
931
932
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')
933
934
935
936
937
938
939
        parser.add_argument(
            '--worker-extension-cls',
            type=str,
            default="",
            help='The worker extension class on top of the worker cls, '
            'it is useful if you just want to add new functions to the worker '
            'class without changing the existing functions.')
940
941
942
        parser.add_argument(
            "--generation-config",
            type=nullable_str,
943
            default="auto",
944
            help="The folder path to the generation config. "
945
946
947
948
949
            "Defaults to 'auto', the generation config will be loaded from "
            "model path. If set to 'vllm', no generation config is loaded, "
            "vLLM defaults will be used. If set to a folder path, the "
            "generation config will be loaded from the specified folder path. "
            "If `max_new_tokens` is specified in generation config, then "
950
951
952
953
954
955
956
957
958
959
960
961
            "it sets a server-wide limit on the number of output tokens "
            "for all requests.")

        parser.add_argument(
            "--override-generation-config",
            type=json.loads,
            default=None,
            help="Overrides or sets generation config in JSON format. "
            "e.g. ``{\"temperature\": 0.5}``. If used with "
            "--generation-config=auto, the override parameters will be merged "
            "with the default config from the model. If generation-config is "
            "None, only the override parameters are used.")
962

963
964
965
966
967
968
        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
                            default=False,
                            help="Enable sleep mode for the engine. "
                            "(only cuda platform is supported)")

969
970
971
972
973
974
975
976
977
        parser.add_argument(
            '--calculate-kv-scales',
            action='store_true',
            help='This enables dynamic calculation of '
            'k_scale and v_scale when kv-cache-dtype is fp8. '
            'If calculate-kv-scales is false, the scales will '
            'be loaded from the model checkpoint if available. '
            'Otherwise, the scales will default to 1.0.')

978
979
980
981
982
983
984
985
        parser.add_argument(
            "--additional-config",
            type=json.loads,
            default=None,
            help="Additional config for specified platform in JSON format. "
            "Different platforms may support different configs. Make sure the "
            "configs are valid for the platform you are using. The input format"
            " is like '{\"config_key\":\"config_value\"}'")
986
987
988
989
990
991
992
993
994
995
996
997

        parser.add_argument(
            "--enable-reasoning",
            action="store_true",
            default=False,
            help="Whether to enable reasoning_content for the model. "
            "If enabled, the model will be able to generate reasoning content."
        )

        parser.add_argument(
            "--reasoning-parser",
            type=str,
998
            choices=list(ReasoningParserManager.reasoning_parsers),
999
1000
1001
1002
1003
1004
            default=None,
            help=
            "Select the reasoning parser depending on the model that you're "
            "using. This is used to parse the reasoning content into OpenAI "
            "API format. Required for ``--enable-reasoning``.")

1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
        parser.add_argument(
            "--disable-cascade-attn",
            action="store_true",
            default=False,
            help="Disable cascade attention for V1. While cascade attention "
            "does not change the mathematical correctness, disabling it "
            "could be useful for preventing potential numerical issues. "
            "Note that even if this is set to False, cascade attention will be "
            "only used when the heuristic tells that it's beneficial.")

1015
        return parser
1016
1017

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of ``cls`` from a parsed argparse namespace.

        Every dataclass field of ``cls`` that is present on ``args`` is
        forwarded to the constructor under the same name. Fields that are
        missing from ``args`` are left out, so the dataclass default applies;
        this allows namespaces produced by parsers that only register a
        subset of the engine flags. Attributes on ``args`` that are not
        dataclass fields are ignored.

        Args:
            args: Namespace returned by ``ArgumentParser.parse_args``.

        Returns:
            A new instance of ``cls``.

        Raises:
            TypeError: If a field without a default is missing from ``args``.
        """
        # Names of all fields declared on this dataclass.
        field_names = (field.name for field in dataclasses.fields(cls))
        # Forward only what the namespace actually carries.
        kwargs = {
            name: getattr(args, name)
            for name in field_names if hasattr(args, name)
        }
        return cls(**kwargs)
1024

1025
    def create_model_config(self) -> ModelConfig:
        """Create the ``ModelConfig`` from the current engine arguments.

        Before the config is built, two adjustments may be written back to
        ``self``:

        * if ``self.model`` is a GGUF file, ``quantization`` and
          ``load_format`` are both set to ``"gguf"``;
        * in CI with S3 enabled, a model listed in ``MODELS_ON_S3`` that uses
          the automatic load format is redirected to the S3 bucket and loaded
          with the Run:ai streamer format.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        is_gguf = check_gguf_file(self.model)
        if is_gguf:
            self.quantization = self.load_format = "gguf"

        # NOTE: This is to allow model loading from S3 in CI
        load_from_s3 = (not isinstance(self, AsyncEngineArgs)
                        and envs.VLLM_CI_USE_S3
                        and self.model in MODELS_ON_S3
                        and self.load_format == LoadFormat.AUTO)
        if load_from_s3:
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
            self.load_format = LoadFormat.RUNAI_STREAMER

        # We know this is not None because we set it in __post_init__
        tokenizer_name = cast(str, self.tokenizer)
        # The config takes the positive form of the "disable" flag.
        async_output_proc = not self.disable_async_output_proc

        model_config = ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,
            task=self.task,
            tokenizer=tokenizer_name,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            use_async_output_proc=async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
        )
        return model_config
1077

1078
1079
    def create_load_config(self) -> LoadConfig:
        """Create the ``LoadConfig`` from the current engine arguments.

        When the quantization method is ``"bitsandbytes"``, ``load_format``
        on ``self`` is overwritten with ``"bitsandbytes"`` before the config
        is built.

        Raises:
            ValueError: If a QLoRA adapter is given while the quantization
                method is not ``"bitsandbytes"``.
        """
        uses_bitsandbytes = self.quantization == "bitsandbytes"
        has_qlora_adapter = self.qlora_adapter_name_or_path is not None

        # A QLoRA adapter is only accepted together with bitsandbytes.
        if has_qlora_adapter and not uses_bitsandbytes:
            raise ValueError(
                "QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        # bitsandbytes quantization forces the load format of the same name.
        if uses_bitsandbytes:
            self.load_format = "bitsandbytes"

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
            use_tqdm_on_load=self.use_tqdm_on_load,
        )
        return load_config
1095

1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Initializes and returns a SpeculativeConfig object based on
        `speculative_config`.

        This function utilizes `speculative_config` to create a
        SpeculativeConfig object. The `speculative_config` can either be
        provided as a JSON string input via CLI arguments or directly as a
        dictionary from the engine.

        The dictionary stored on `self.speculative_config` is left untouched:
        the engine-supplied entries are merged into a fresh dictionary, so a
        dictionary passed in by the caller does not end up holding engine
        config objects.

        Args:
            target_model_config: Model config of the target model.
            target_parallel_config: Parallel config of the target model.
            enable_chunked_prefill: Forwarded under the key of the same name.
            disable_log_stats: Forwarded under the key of the same name.

        Returns:
            The SpeculativeConfig built by `SpeculativeConfig.from_dict`, or
            None when `speculative_config` is None.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config. Merge into a copy rather than updating in place so the
        # user-provided dictionary is not mutated.
        spec_kwargs = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        }
        return SpeculativeConfig.from_dict(spec_kwargs)

1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1138

1139
1140
1141
1142
1143
1144
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1145
1146
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()
1147

1148
        device_config = DeviceConfig(device=self.device)
1149
1150
        model_config = self.create_model_config()

1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
            self._set_default_args_v1(usage_context)
        else:
            self._set_default_args_v0(model_config)
1173

1174
1175
        assert self.enable_chunked_prefill is not None

1176
        cache_config = CacheConfig(
1177
            block_size=self.block_size,
1178
1179
1180
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1181
            is_attention_free=model_config.is_attention_free,
1182
1183
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1184
            enable_prefix_caching=self.enable_prefix_caching,
1185
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1186
            cpu_offload_gb=self.cpu_offload_gb,
1187
            calculate_kv_scales=self.calculate_kv_scales,
1188
        )
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1201
        parallel_config = ParallelConfig(
1202
1203
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1204
            data_parallel_size=self.data_parallel_size,
1205
            enable_expert_parallel=self.enable_expert_parallel,
1206
1207
1208
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
1209
1210
1211
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
1212
            ),
1213
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1214
            placement_group=placement_group,
1215
1216
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1217
            worker_extension_cls=self.worker_extension_cls,
1218
        )
1219

1220
        speculative_config = self.create_speculative_config(
1221
1222
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1223
            enable_chunked_prefill=self.enable_chunked_prefill,
1224
            disable_log_stats=self.disable_log_stats,
1225
1226
        )

1227
        # Reminder: Please update docs/source/features/compatibility_matrix.md
1228
        # If the feature combo become valid
1229
1230
1231
1232
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1233
1234
1235
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1236
1237
1238
1239
1240
1241
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1242
1243
1244
1245
1246
1247
1248
1249
1250

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1251
        scheduler_config = SchedulerConfig(
1252
            runner_type=model_config.runner_type,
1253
1254
1255
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1256
            num_lookahead_slots=num_lookahead_slots,
1257
1258
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1259
            is_multimodal_model=model_config.is_multimodal_model,
1260
            preemption_mode=self.preemption_mode,
1261
            num_scheduler_steps=self.num_scheduler_steps,
1262
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1263
1264
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1265
            policy=self.scheduling_policy,
1266
            scheduler_cls=self.scheduler_cls,
1267
1268
1269
1270
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
        )
1271

1272
        lora_config = LoRAConfig(
1273
            bias_enabled=self.enable_lora_bias,
1274
1275
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1276
            fully_sharded_loras=self.fully_sharded_loras,
1277
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1278
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1279
1280
1281
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
1282

1283
1284
1285
1286
1287
1288
1289
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

1290
1291
1292
1293
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1294
        load_config = self.create_load_config()
1295

1296
1297
1298
1299
1300
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1301
        decoding_config = DecodingConfig(
1302
1303
1304
1305
            guided_decoding_backend=self.guided_decoding_backend,
            reasoning_backend=self.reasoning_parser
            if self.enable_reasoning else None,
        )
1306

1307
1308
1309
1310
1311
        show_hidden_metrics = False
        if self.show_hidden_metrics_for_version is not None:
            show_hidden_metrics = version._prev_minor_version_was(
                self.show_hidden_metrics_for_version)

1312
1313
1314
1315
1316
1317
1318
1319
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
1320
        observability_config = ObservabilityConfig(
1321
            show_hidden_metrics=show_hidden_metrics,
1322
1323
1324
1325
1326
1327
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )
1328

1329
        config = VllmConfig(
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1340
            prompt_adapter_config=prompt_adapter_config,
1341
            compilation_config=self.compilation_config,
1342
            kv_transfer_config=self.kv_transfer_config,
1343
            additional_config=self.additional_config,
1344
        )
1345

1346
1347
        return config

1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        The checks come in two groups:

        * Unsupported features are reported through ``_raise_or_fallback``,
          which raises when ``VLLM_USE_V1=1`` was set explicitly and
          otherwise logs a warning before this method returns ``False``.
        * Experimental features are reported through ``_warn_or_fallback``,
          which keeps V1 when ``VLLM_USE_V1=1`` was set explicitly and
          otherwise requests a fallback to V0.

        Args:
            model_config: Config of the target model; its dtype,
                quantization, task and architectures are inspected.

        Returns:
            ``True`` if the V1 engine can be used with the current
            arguments, ``False`` if the V0 engine should be used instead.

        Raises:
            NotImplementedError: Raised by ``_raise_or_fallback`` when an
                unsupported feature is combined with an explicit
                ``VLLM_USE_V1=1``.
        """

        #############################################################
        # Unsupported Feature Flags on V1.

        if (self.load_format == LoadFormat.TENSORIZER.value
                or self.load_format == LoadFormat.SHARDED_STATE.value):
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        # The checks below compare against the class-level defaults of
        # EngineArgs: any non-default value means the user set the flag.
        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

        if self.preemption_mode != EngineArgs.preemption_mode:
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

        if self.scheduling_policy != EngineArgs.scheduling_policy:
            _raise_or_fallback(feature_name="--scheduling-policy",
                               recommend_to_remove=False)
            return False

        if self.num_scheduler_steps != EngineArgs.num_scheduler_steps:
            _raise_or_fallback(feature_name="--num-scheduler-steps",
                               recommend_to_remove=True)
            return False

        if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor:
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

        if self.additional_config != EngineArgs.additional_config:
            _raise_or_fallback(feature_name="--additional-config",
                               recommend_to_remove=False)
            return False

        # Xgrammar and Guidance are supported.
        SUPPORTED_GUIDED_DECODING = [
            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
            "guidance:disable-any-whitespace", "auto"
        ]
        if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
            _raise_or_fallback(feature_name="--guided-decoding-backend",
                               recommend_to_remove=False)
            return False

        # Need at least Ampere for now (FA support required).
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
        from vllm.platforms import current_platform
        if (current_platform.is_cuda()
                and current_platform.get_device_capability()
                and current_platform.get_device_capability().major < 8):
            _raise_or_fallback(feature_name="Compute Capability < 8.0",
                               recommend_to_remove=False)
            return False

        # A non-default KV cache dtype is only accepted when it is an fp8
        # variant, FlashAttention will be used, and FlashAttention reports
        # fp8 support.
        if self.kv_cache_dtype != "auto":
            fp8_attention = self.kv_cache_dtype.startswith("fp8")
            will_use_fa = (
                current_platform.is_cuda()
                and not envs.is_set("VLLM_ATTENTION_BACKEND")
            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
            supported = False
            if fp8_attention and will_use_fa:
                from vllm.vllm_flash_attn.fa_utils import (
                    flash_attn_supports_fp8)
                supported = flash_attn_supports_fp8()
            if not supported:
                _raise_or_fallback(feature_name="--kv-cache-dtype",
                                   recommend_to_remove=False)
                return False

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

        # Only Fp16 and Bf16 dtypes since we only support FA.
        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
        if model_config.dtype not in V1_SUPPORTED_DTYPES:
            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
                               recommend_to_remove=False)
            return False

        # Some quantization is not compatible with torch.compile.
        V1_UNSUPPORTED_QUANT = ["gguf"]
        if model_config.quantization in V1_UNSUPPORTED_QUANT:
            _raise_or_fallback(
                feature_name=f"--quantization {model_config.quantization}",
                recommend_to_remove=False)
            return False

        # No Embedding Models so far.
        if model_config.task not in ["generate"]:
            _raise_or_fallback(feature_name=f"--task {model_config.task}",
                               recommend_to_remove=False)
            return False

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            # feature_name is declared as str; convert explicitly (the
            # rendered message is identical to formatting the raw value).
            _raise_or_fallback(
                feature_name=str(model_config.architectures),
                recommend_to_remove=False)
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
                != EngineArgs.max_num_partial_prefills
                or self.max_long_partial_prefills
                != EngineArgs.max_long_partial_prefills):
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # No OTLP observability so far (this also covers
        # collect_detailed_traces, reported under the same feature name).
        if (self.otlp_traces_endpoint or self.collect_detailed_traces):
            _raise_or_fallback(feature_name="--otlp-traces-endpoint",
                               recommend_to_remove=False)
            return False

        # Only ngram and EAGLE speculative decoding are recognised so far;
        # both are treated as experimental (handled below).
        is_ngram_enabled = False
        is_eagle_enabled = False
        if self.speculative_config is not None:
            speculative_method = self.speculative_config.get("method")
            if speculative_method:
                if speculative_method in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
                elif speculative_method == "eagle":
                    is_eagle_enabled = True
            else:
                # Without an explicit method, ngram may be selected
                # through the "model" entry instead.
                speculative_model = self.speculative_config.get("model")
                if speculative_model in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
            if not (is_ngram_enabled or is_eagle_enabled):
                # Other speculative decoding methods are not supported yet.
                _raise_or_fallback(feature_name="Speculative Decoding",
                                   recommend_to_remove=False)
                return False

        # No Disaggregated Prefill so far.
        if self.kv_transfer_config != EngineArgs.kv_transfer_config:
            _raise_or_fallback(feature_name="--kv-transfer-config",
                               recommend_to_remove=False)
            return False

        # No FlashInfer or XFormers so far.
        V1_BACKENDS = [
            "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        # Platforms must decide if they can support v1 for this model
        if not current_platform.supports_v1(model_config=model_config):
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

        # Signal Handlers requires running in main thread.
        if (threading.current_thread() != threading.main_thread()
                and _warn_or_fallback("Engine in background thread")):
            return False

        # PP is supported on V1 with Ray distributed executor,
        # but off for MP distributed executor for now.
        if (self.pipeline_parallel_size > 1
                and self.distributed_executor_backend != "ray"):
            name = "Pipeline Parallelism without Ray distributed executor"
            _raise_or_fallback(feature_name=name, recommend_to_remove=False)
            return False

        # ngram is supported on V1, but off by default for now.
        if is_ngram_enabled and _warn_or_fallback("ngram"):
            return False

        # Eagle is under development: it is only used on V1 when the user
        # opted in explicitly, otherwise we fall back to V0.
        if is_eagle_enabled and _warn_or_fallback("Eagle"):
            return False

        # Non-CUDA is supported on V1, but off by default for now.
        not_cuda = not current_platform.is_cuda()
        if not_cuda and _warn_or_fallback(  # noqa: SIM103
                current_platform.device_name):
            return False
        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        Fills in ``enable_chunked_prefill``, ``prefix_caching_hash_algo``
        and ``max_num_seqs`` when the user left them unset, and validates
        combinations that V0 cannot serve.

        Args:
            model_config: Config of the target model; its context length,
                multimodal/MLA flags, sliding window and runner type
                drive the defaults.

        Raises:
            ValueError: If chunked prefill is enabled for a pooling model,
                or if ``sha256`` is requested as the prefix caching hash
                algorithm.
        """

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                from vllm.platforms import current_platform
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_config is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # Nothing above decided: default to disabled.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            # NOTE: the trailing space after "cause" matters; the literals
            # are concatenated implicitly.
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # if using prefix caching, we must set a hash algo
        if self.enable_prefix_caching:
            # Disable prefix caching for multimodal models for VLLM_V0.
            if model_config.is_multimodal_model:
                logger.warning(
                    "--enable-prefix-caching is not supported for multimodal "
                    "models in V0 and has been disabled.")
                self.enable_prefix_caching = False

            # VLLM_V0 only supports builtin hash algo for prefix caching.
            if self.prefix_caching_hash_algo is None:
                self.prefix_caching_hash_algo = "builtin"
            elif self.prefix_caching_hash_algo == "sha256":
                raise ValueError(
                    "sha256 is not supported for prefix caching in V0 engine. "
                    "Please use 'builtin'.")

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

    def _set_default_args_v1(self, usage_context: UsageContext) -> None:
        """Set Default Arguments for V1 Engine.

        Forces chunked prefill on, and fills in prefix caching, the
        scheduler class, ``max_num_batched_tokens`` and ``max_num_seqs``
        when the user did not override them.

        Args:
            usage_context: How the engine is being used; selects the
                default ``max_num_batched_tokens``. A falsy value or a
                context without a table entry leaves that setting
                untouched.
        """

        # V1 always uses chunked prefills.
        self.enable_chunked_prefill = True

        # V1 enables prefix caching by default.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = True

        # if using prefix caching, we must set a hash algo
        if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
            self.prefix_caching_hash_algo = "builtin"

        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # When no user override, set the default values based on the usage
        # context.
        # Use different default values for different hardware.

        # Try to query the device name on the current platform. If it fails,
        # it may be because the platform that imports vLLM is not the same
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        try:
            from vllm.platforms import current_platform
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # Deliberately broad: any failure to query the device falls
            # back to a placeholder name.
            # This is only used to set default_max_num_batched_tokens
            device_name = "no-device"

        if "h100" in device_name or "h200" in device_name:
            # For H100 and H200, we use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }

        # Only used for the debug messages below.
        use_context_value = usage_context.value if usage_context else None
        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
            self.max_num_batched_tokens = default_max_num_batched_tokens[
                usage_context]
            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, use_context_value)

        default_max_num_seqs = 1024
        if self.max_num_seqs is None:
            self.max_num_seqs = default_max_num_seqs

            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, use_context_value)
1691

1692

1693
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # Set by the --disable-log-requests CLI flag.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async engine CLI arguments on ``parser``.

        Args:
            parser: Parser to extend.
            async_args_only: If True, skip the base ``EngineArgs``
                arguments and register only the async-specific ones.

        Returns:
            The parser with the arguments registered.
        """
        # Initialize plugin to update the parser, for example, The plugin may
        # adding a new kind of quantization method to --quantization argument or
        # a new device to --device argument.
        load_general_plugins()
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        # Imported locally, as elsewhere in this file.
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update(parser)
        return parser
1713
1714


1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Reject an unsupported V1 feature, or announce the fallback to V0.

    Args:
        feature_name: Human-readable name of the unsupported feature,
            interpolated into the error or warning text.
        recommend_to_remove: If True, the warning also advises dropping
            the feature from the config.

    Raises:
        NotImplementedError: If ``VLLM_USE_V1`` is explicitly set to a
            truthy value.
    """
    v1_forced = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_forced:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")

    # Assemble the warning from its fragments, then log it once.
    fragments = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        fragments.append(
            f"We recommend to remove {feature_name} from your config ")
        fragments.append("in favor of the V1 Engine.")
    logger.warning("".join(fragments))


def _warn_or_fallback(feature_name: str) -> bool:
    """Decide whether an experimental V1 feature forces a fallback to V0.

    Args:
        feature_name: Human-readable name of the experimental feature,
            used in the log message.

    Returns:
        ``False`` when ``VLLM_USE_V1`` is explicitly set to a truthy value
        (a warning is logged and the caller may stay on V1); ``True``
        otherwise, meaning the caller should fall back to V0.
    """
    user_opted_in = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if user_opted_in:
        logger.warning(
            "Detected VLLM_USE_V1=1 with %s. Usage should "
            "be considered experimental. Please report any "
            "issues on Github.", feature_name)
        return False

    logger.info(
        "%s is experimental on VLLM_USE_V1=1. "
        "Falling back to V0 Engine.", feature_name)
    return True


1742
1743
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with the ``EngineArgs`` CLI args."""
    return EngineArgs.add_cli_args(FlexibleArgumentParser())
1745
1746
1747


def _async_engine_args_parser():
    """Return a fresh parser with only the async-specific CLI args."""
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                        async_args_only=True)