arg_utils.py 61.9 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
6
                    Tuple, Type, Union, cast, get_args)
7

8
9
import torch

10
import vllm.envs as envs
11
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
12
13
14
15
                         DecodingConfig, DeviceConfig, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PoolerConfig, PromptAdapterConfig, SchedulerConfig,
16
17
                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
                         VllmConfig)
18
from vllm.executor.executor_base import ExecutorBase
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
21
from vllm.transformers_utils.utils import check_gguf_file
22
from vllm.usage.usage_lib import UsageContext
23
from vllm.utils import FlexibleArgumentParser, StoreBoolean
24

25
if TYPE_CHECKING:
26
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
27

28
29
# Module-level logger, named after this module.
logger = init_logger(__name__)
# Module names that may be requested for detailed tracing.
# NOTE(review): presumably the accepted values of ``--collect-detailed-traces``
# (see ``EngineArgs.collect_detailed_traces``); the consuming parser code is
# outside this view — confirm.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# Device types the engine can be asked to target; "auto" is the default of
# ``EngineArgs.device``.
# NOTE(review): presumably used as the ``choices`` of ``--device``; the parser
# code that consumes this list is outside this view — confirm.
DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "openvino",
    "tpu",
    "xpu",
    "hpu",
]

43

44
45
46
47
48
49
def nullable_str(val: str):
    """Convert a CLI string to ``None`` when it means "unset".

    Used as an argparse ``type=`` converter. An empty (or otherwise falsy)
    value, or the literal string ``"None"``, maps to ``None``; anything else
    is returned unchanged.
    """
    if val and val != "None":
        return val
    return None


50
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dictionary.

    Keys and values are lower-cased and stripped of surrounding whitespace;
    values must parse as integers. Used as an argparse ``type=`` converter,
    e.g. ``--limit-mm-per-prompt image=16,video=2``.

    Args:
        val: String value to be parsed.

    Returns:
        Dictionary mapping each key to its integer value, or ``None`` if
        ``val`` is empty.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, has an empty key, has a value that is not an
            integer, or if the same key is given two different values.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        # An empty key (e.g. "=5") is just as malformed as a missing "=".
        if len(kv_parts) != 2 or not kv_parts[0]:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key is tolerated only when it repeats the same value.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


85
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
86
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
87
    """Arguments for vLLM engine."""
88
    model: str = 'facebook/opt-125m'
89
    served_model_name: Optional[Union[str, List[str]]] = None
90
    tokenizer: Optional[str] = None
91
    task: TaskOption = "auto"
92
    skip_tokenizer_init: bool = False
93
    tokenizer_mode: str = 'auto'
94
    trust_remote_code: bool = False
95
    allowed_local_media_path: str = ""
96
    download_dir: Optional[str] = None
97
    load_format: str = 'auto'
98
    config_format: ConfigFormat = ConfigFormat.AUTO
99
    dtype: str = 'auto'
100
    kv_cache_dtype: str = 'auto'
101
    seed: int = 0
102
    max_model_len: Optional[int] = None
103
104
105
106
107
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
108
    # number of P/D disaggregation (or other disaggregation) workers
109
110
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
111
    max_parallel_loading_workers: Optional[int] = None
112
    block_size: Optional[int] = None
113
    enable_prefix_caching: Optional[bool] = None
114
    disable_sliding_window: bool = False
115
    use_v2_block_manager: bool = True
116
117
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
118
    gpu_memory_utilization: float = 0.90
119
    max_num_batched_tokens: Optional[int] = None
120
    max_num_seqs: Optional[int] = None
121
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
122
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
123
    revision: Optional[str] = None
124
    code_revision: Optional[str] = None
125
    rope_scaling: Optional[Dict[str, Any]] = None
126
    rope_theta: Optional[float] = None
127
    hf_overrides: Optional[HfOverrides] = None
128
    tokenizer_revision: Optional[str] = None
129
    quantization: Optional[str] = None
130
    enforce_eager: Optional[bool] = None
131
    max_seq_len_to_capture: int = 8192
132
    disable_custom_all_reduce: bool = False
133
    tokenizer_pool_size: int = 0
134
135
136
137
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
138
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
139
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
140
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
141
    disable_mm_preprocessor_cache: bool = False
142
    enable_lora: bool = False
143
    enable_lora_bias: bool = False
144
145
    max_loras: int = 1
    max_lora_rank: int = 16
146
147
148
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
149
    fully_sharded_loras: bool = False
150
    lora_extra_vocab_size: int = 256
151
    long_lora_scaling_factors: Optional[Tuple[float]] = None
152
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
153
    max_cpu_loras: Optional[int] = None
154
    device: str = 'auto'
155
    num_scheduler_steps: int = 1
156
    multi_step_stream_outputs: bool = True
157
    ray_workers_use_nsight: bool = False
158
    num_gpu_blocks_override: Optional[int] = None
159
    num_lookahead_slots: int = 0
160
    model_loader_extra_config: Optional[dict] = None
161
    ignore_patterns: Optional[Union[str, List[str]]] = None
162
    preemption_mode: Optional[str] = None
163

164
    scheduler_delay_factor: float = 0.0
165
    enable_chunked_prefill: Optional[bool] = None
166

167
    guided_decoding_backend: str = 'xgrammar'
168
    logits_processor_pattern: Optional[str] = None
169
170
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
171
    speculative_model_quantization: Optional[str] = None
172
    speculative_draft_tensor_parallel_size: Optional[int] = None
173
    num_speculative_tokens: Optional[int] = None
174
    speculative_disable_mqa_scorer: Optional[bool] = False
175
    speculative_max_model_len: Optional[int] = None
176
    speculative_disable_by_batch_size: Optional[int] = None
177
178
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
179
180
181
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
182
    qlora_adapter_name_or_path: Optional[str] = None
183
    disable_logprobs_during_spec_decoding: Optional[bool] = None
184

185
    otlp_traces_endpoint: Optional[str] = None
186
    collect_detailed_traces: Optional[str] = None
187
    disable_async_output_proc: bool = False
188
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
189

190
191
    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
192
    compilation_config: Optional[CompilationConfig] = None
193
    worker_cls: str = "auto"
194

195
196
    kv_transfer_config: Optional[KVTransferConfig] = None

197
    generation_config: Optional[str] = None
198
    override_generation_config: Optional[Dict[str, Any]] = None
199
    enable_sleep_mode: bool = False
200

201
202
    calculate_kv_scales: Optional[bool] = None

203
    def __post_init__(self):
        """Fill in defaults that depend on other fields or the environment.

        Runs automatically after the dataclass-generated ``__init__``. It:

        * falls back to ``model`` when ``tokenizer`` is unset or empty;
        * sets ``enable_prefix_caching`` to ``bool(envs.VLLM_USE_V1)`` when
          the user left it as ``None``;
        * sets ``max_num_seqs`` to 256 (or 1024 under ``VLLM_USE_V1``) when
          the user left it as ``None``;
        * converts an ``int`` or ``dict`` ``compilation_config`` into a
          ``CompilationConfig`` via ``CompilationConfig.from_cli``;
        * loads the general plugins.
        """
        if not self.tokenizer:
            self.tokenizer = self.model

        # Override the default value of enable_prefix_caching if it's not set
        # by user.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = bool(envs.VLLM_USE_V1)

        # Override max_num_seqs if it's not set by user.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024

        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        # NOTE(review): a dict is round-tripped through ``str()``, which yields
        # a Python repr (single-quoted keys), not JSON; this relies on
        # ``CompilationConfig.from_cli`` accepting that form — confirm.
        if isinstance(self.compilation_config, (int, dict)):
            self.compilation_config = CompilationConfig.from_cli(
                str(self.compilation_config))

        # Setup plugins
        # Imported here rather than at module top — presumably to avoid an
        # import cycle or import-time cost; confirm.
        from vllm.plugins import load_general_plugins
        load_general_plugins()

    @staticmethod
228
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
229
        """Shared CLI arguments for vLLM engine."""
230

231
        # Model arguments
232
233
234
        parser.add_argument(
            '--model',
            type=str,
235
            default=EngineArgs.model,
236
            help='Name or path of the huggingface model to use.')
237
238
239
240
241
242
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
243
            'multiple tasks. When the model only supports one task, ``"auto"`` '
244
245
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
246
247
        parser.add_argument(
            '--tokenizer',
248
            type=nullable_str,
249
            default=EngineArgs.tokenizer,
250
251
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
252
253
254
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
255
            help='Skip initialization of tokenizer and detokenizer.')
Jasmond L's avatar
Jasmond L committed
256
257
        parser.add_argument(
            '--revision',
258
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
259
            default=None,
260
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
261
262
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
263
264
        parser.add_argument(
            '--code-revision',
265
            type=nullable_str,
266
            default=None,
267
            help='The specific revision to use for the model code on '
268
269
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
270
271
        parser.add_argument(
            '--tokenizer-revision',
272
            type=nullable_str,
273
            default=None,
274
275
276
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
277
278
279
280
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
281
            choices=['auto', 'slow', 'mistral'],
282
283
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
284
285
            'always use the slow tokenizer. \n* '
            '"mistral" will always use the `mistral_common` tokenizer.')
286
287
        parser.add_argument('--trust-remote-code',
                            action='store_true',
288
                            help='Trust remote code from huggingface.')
289
290
291
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
292
293
294
295
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
296
        parser.add_argument('--download-dir',
297
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
298
                            default=EngineArgs.download_dir,
299
                            help='Directory to download and load the weights, '
300
                            'default to the default cache dir of '
301
                            'huggingface.')
302
303
304
305
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
306
            choices=[f.value for f in LoadFormat],
307
308
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
309
            'and fall back to the pytorch bin format if safetensors format '
310
311
312
313
314
315
316
317
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
318
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
319
            'section for more information.\n'
320
321
            '* "runai_streamer" will load the Safetensors weights using Run:ai'
            'Model Streamer \n'
322
323
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
324
325
326
327
328
329
330
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
331
332
333
334
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
335
336
337
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
338
339
340
341
342
343
344
345
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
346
347
348
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
349
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
350
            default=EngineArgs.kv_cache_dtype,
351
            help='Data type for kv cache storage. If "auto", will use model '
352
353
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
354
355
        parser.add_argument('--max-model-len',
                            type=int,
356
                            default=EngineArgs.max_model_len,
357
358
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
359
360
361
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
362
363
            default='xgrammar',
            choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
364
            help='Which engine will be used for guided decoding'
365
            ' (JSON schema / regex etc) by default. Currently support '
366
            'https://github.com/outlines-dev/outlines, '
367
            'https://github.com/mlc-ai/xgrammar, and '
368
369
370
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
371
372
373
374
375
376
377
378
        parser.add_argument(
            '--logits-processor-pattern',
            type=nullable_str,
            default=None,
            help='Optional regex pattern specifying valid logits processor '
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
379
        # Parallel arguments
380
381
        parser.add_argument(
            '--distributed-executor-backend',
382
            choices=['ray', 'mp', 'uni', 'external_launcher'],
383
            default=EngineArgs.distributed_executor_backend,
384
385
386
387
388
389
            help='Backend to use for distributed model '
            'workers, either "ray" or "mp" (multiprocessing). If the product '
            'of pipeline_parallel_size and tensor_parallel_size is less than '
            'or equal to the number of GPUs available, "mp" will be used to '
            'keep processing on a single host. Otherwise, this will default '
            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
390
            'only supports Ray for distributed inference.')
391

392
393
394
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
395
                            default=EngineArgs.pipeline_parallel_size,
396
                            help='Number of pipeline stages.')
397
398
399
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
400
                            default=EngineArgs.tensor_parallel_size,
401
                            help='Number of tensor parallel replicas.')
402
403
404
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
405
            default=EngineArgs.max_parallel_loading_workers,
406
            help='Load model sequentially in multiple batches, '
407
            'to avoid RAM OOM when using tensor '
408
            'parallel and large models.')
409
410
411
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
412
            help='If specified, use nsight to profile Ray workers.')
413
        # KV cache arguments
414
415
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
416
                            default=EngineArgs.block_size,
417
                            choices=[8, 16, 32, 64, 128],
418
                            help='Token block size for contiguous chunks of '
419
                            'tokens. This is ignored on neuron devices and '
420
                            'set to ``--max-model-len``. On CUDA devices, '
421
422
                            'only block sizes up to 32 are supported. '
                            'On HPU devices, block size defaults to 128.')
423

424
425
426
427
428
        parser.add_argument(
            "--enable-prefix-caching",
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.enable_prefix_caching,
            help="Enables automatic prefix caching. "
429
            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
430
        )
431
432
433
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
434
                            'capping to sliding window size.')
435
436
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
437
                            default=True,
438
439
440
441
442
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
443
444
445
446
447
448
449
450
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
451

452
453
454
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
455
                            help='Random seed for operations.')
456
        parser.add_argument('--swap-space',
457
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
458
                            default=EngineArgs.swap_space,
459
                            help='CPU swap space size (GiB) per GPU.')
460
461
462
463
464
465
466
467
468
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
469
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
470
            'which requires at least 26GB GPU memory. Note that this '
471
            'requires fast CPU-GPU interconnect, as part of the model is '
472
473
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
474
475
476
477
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
478
479
480
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
481
482
483
484
485
486
            'will use the default value of 0.9. This is a per-instance '
            'limit, and only applies to the current vLLM instance.'
            'It does not matter if you have another vLLM instance running '
            'on the same GPU. For example, if you have two vLLM instances '
            'running on the same GPU, you can set the GPU memory utilization '
            'to 0.5 for each instance.')
487
        parser.add_argument(
488
            '--num-gpu-blocks-override',
489
490
491
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
492
            ' of GPU blocks. Used for testing preemption.')
493
494
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
495
                            default=EngineArgs.max_num_batched_tokens,
496
497
                            help='Maximum number of batched tokens per '
                            'iteration.')
498
499
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
500
                            default=EngineArgs.max_num_seqs,
501
                            help='Maximum number of sequences per iteration.')
502
503
504
505
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
506
507
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
508
509
        parser.add_argument('--disable-log-stats',
                            action='store_true',
510
                            help='Disable logging statistics.')
511
512
513
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
514
                            type=nullable_str,
515
                            choices=[*QUANTIZATION_METHODS, None],
516
                            default=EngineArgs.quantization,
517
518
519
520
521
522
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
523
524
525
526
527
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
528
            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
529
530
531
532
533
534
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
535
536
537
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
538
                            help='Extra arguments for the HuggingFace config. '
539
540
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
541
542
543
544
545
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
546
        parser.add_argument('--max-seq-len-to-capture',
547
548
549
550
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
551
552
553
554
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
555
556
557
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
558
                            help='See ParallelConfig.')
559
560
561
562
563
564
565
566
567
568
569
570
571
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
572
                            type=nullable_str,
573
574
575
576
577
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
578
579
580
581
582
583
584
585
586
587
588
589
590
591

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
592
593
594
595
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
596
            help=('Overrides for the multimodal input mapping/processing, '
597
                  'e.g., image processor. For example: ``{"num_crops": 4}``.'))
598
        parser.add_argument(
599
            '--disable-mm-preprocessor-cache',
600
            action='store_true',
601
602
            help='If true, then disables caching of the multi-modal '
            'preprocessor/mapper. (not recommended)')
603

604
605
606
607
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
608
609
610
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
630
            choices=['auto', 'float16', 'bfloat16'],
631
632
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
633
634
635
636
637
638
639
640
641
642
643
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
644
645
646
647
648
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
649
650
                  'Must be >= than max_loras. '
                  'Defaults to max_loras.'))
651
652
653
654
655
656
657
658
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
659
660
661
662
663
664
665
666
667
668
669
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
670
671
672
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
673
                            choices=DEVICE_OPTIONS,
674
                            help='Device type for vLLM execution.')
675
676
677
678
679
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
680

681
682
        parser.add_argument(
            '--multi-step-stream-outputs',
683
684
685
686
687
688
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
689
690
691
692
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
693
            help='Apply a delay (of delay factor multiplied by previous '
694
            'prompt latency) before scheduling next prompt.')
695
696
        parser.add_argument(
            '--enable-chunked-prefill',
697
698
699
700
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
701
            help='If set, the prefill requests can be chunked based on the '
702
            'max_num_batched_tokens.')
703
704
705

        parser.add_argument(
            '--speculative-model',
706
            type=nullable_str,
707
            default=EngineArgs.speculative_model,
708
709
            help=
            'The name of the draft model to be used in speculative decoding.')
710
711
712
713
714
715
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
716
            help='Method used to quantize the weights of speculative model. '
717
718
719
720
721
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
722
723
724
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
725
            default=EngineArgs.num_speculative_tokens,
726
            help='The number of speculative tokens to sample from '
727
            'the draft model in speculative decoding.')
728
729
730
731
732
733
        parser.add_argument(
            '--speculative-disable-mqa-scorer',
            action='store_true',
            help=
            'If set to True, the MQA scorer will be disabled in speculative '
            ' and fall back to batch expansion')
734
735
736
737
738
739
740
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
741

742
743
        parser.add_argument(
            '--speculative-max-model-len',
744
            type=int,
745
746
747
748
749
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

750
751
752
753
754
755
756
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

757
758
759
760
761
762
763
764
765
766
767
768
769
770
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

803
804
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
805
            action=StoreBoolean,
806
            default=EngineArgs.disable_logprobs_during_spec_decoding,
807
808
            nargs="?",
            const="True",
809
810
811
812
813
814
815
816
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

817
        parser.add_argument('--model-loader-extra-config',
818
                            type=nullable_str,
819
820
821
822
823
824
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
825
826
827
828
829
830
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
831
            "Default to `original/**/*` to avoid repeated loading of llama's "
832
            "checkpoints.")
833
        parser.add_argument(
834
            '--preemption-mode',
835
836
            type=str,
            default=None,
837
838
839
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
840

841
842
843
844
845
846
847
848
849
850
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
851
            "same as the ``--model`` argument. Noted that this name(s) "
852
            "will also be used in `model_name` tag content of "
853
            "prometheus metrics, if multiple names provided, metrics "
854
            "tag will take the first one.")
855
856
857
858
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
859
860
861
862
863
864

        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
865
866
867
868
869
870
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
871
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
872
873
874
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
875

876
877
878
879
880
881
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
882

883
884
885
886
887
888
889
890
891
892
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

893
        parser.add_argument(
894
895
            '--override-neuron-config',
            type=json.loads,
896
            default=None,
897
            help="Override or set neuron device configuration. "
898
            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
899
        parser.add_argument(
900
901
            '--override-pooler-config',
            type=PoolerConfig.from_json,
902
            default=None,
903
            help="Override or set the pooling method for pooling models. "
904
            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
905

906
907
908
909
910
911
912
913
914
915
916
917
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
918
919
920
921
                            'use a JSON string.\n'
                            'Following the convention of traditional '
                            'compilers, using -O without space is also '
                            'supported. -O3 is equivalent to -O 3.')
922

923
924
925
926
927
928
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

929
930
931
932
933
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')
934
935
936
937
938
        parser.add_argument(
            "--generation-config",
            type=nullable_str,
            default=None,
            help="The folder path to the generation config. "
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
            "Defaults to None, no generation config is loaded, vLLM defaults "
            "will be used. If set to 'auto', the generation config will be "
            "loaded from model path. If set to a folder path, the generation "
            "config will be loaded from the specified folder path. If "
            "`max_new_tokens` is specified in generation config, then "
            "it sets a server-wide limit on the number of output tokens "
            "for all requests.")

        parser.add_argument(
            "--override-generation-config",
            type=json.loads,
            default=None,
            help="Overrides or sets generation config in JSON format. "
            "e.g. ``{\"temperature\": 0.5}``. If used with "
            "--generation-config=auto, the override parameters will be merged "
            "with the default config from the model. If generation-config is "
            "None, only the override parameters are used.")
956

957
958
959
960
961
962
        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
                            default=False,
                            help="Enable sleep mode for the engine. "
                            "(only cuda platform is supported)")

963
964
965
966
967
968
969
970
971
        parser.add_argument(
            '--calculate-kv-scales',
            action='store_true',
            help='This enables dynamic calculation of '
            'k_scale and v_scale when kv-cache-dtype is fp8. '
            'If calculate-kv-scales is false, the scales will '
            'be loaded from the model checkpoint if available. '
            'Otherwise, the scales will default to 1.0.')

972
        return parser
973
974

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance from a parsed argparse namespace.

        Only attributes whose names match a dataclass field of ``cls`` are
        read from ``args``. Any other attributes present on the namespace
        are ignored.

        Args:
            args: Namespace produced by ``parser.parse_args()``. It must
                provide an attribute for every dataclass field of ``cls``;
                a missing one raises ``AttributeError`` from ``getattr``.

        Returns:
            A new instance of ``cls`` populated from ``args``.
        """
        # Pull exactly one value per dataclass field; extra CLI attributes
        # (e.g. server-only flags) are deliberately not forwarded.
        field_values = {
            fld.name: getattr(args, fld.name)
            for fld in dataclasses.fields(cls)
        }
        return cls(**field_values)
981

982
983
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` from this instance's engine arguments.

        Nearly every field is forwarded unchanged from the attribute of the
        same name. The only transformations are:

        * ``tokenizer`` is narrowed to ``str`` with ``cast`` (type-checker
          only; no runtime conversion happens), and
        * ``use_async_output_proc`` is the negation of
          ``disable_async_output_proc``.

        Returns:
            A freshly constructed ``ModelConfig``.
        """
        return ModelConfig(
            model=self.model,
            task=self.task,
            # We know this is not None because we set it in __post_init__
            tokenizer=cast(str, self.tokenizer),
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            # The CLI exposes a negative flag (--disable-async-output-proc);
            # ModelConfig takes the positive form.
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
        )
1019

1020
1021
1022
1023
1024
1025
1026
1027
    def create_load_config(self) -> LoadConfig:
        """Assemble the ``LoadConfig`` that governs how weights are loaded.

        The four loading-related engine arguments are copied verbatim from
        this instance into the new config object.
        """
        loader_options = dict(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
        )
        return LoadConfig(**loader_options)

1028
1029
1030
1031
1032
1033
    def create_engine_config(self,
                             usage_context: Optional[UsageContext] = None
                             ) -> VllmConfig:
        if envs.VLLM_USE_V1:
            self._override_v1_engine_args(usage_context)

1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
           self.qlora_adapter_name_or_path is not None) and \
           self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes" or
            self.qlora_adapter_name_or_path is not None) and \
            self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

1061
1062
1063
1064
1065
        if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
                and self.enable_prefix_caching):
            logger.warning("--enable-prefix-caching is currently not "
                           "supported for multimodal models in v0 and "
                           "has been disabled.")
1066
1067
            self.enable_prefix_caching = False

1068
        cache_config = CacheConfig(
1069
            block_size=self.block_size,
1070
1071
1072
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1073
            is_attention_free=model_config.is_attention_free,
1074
1075
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1076
1077
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
1078
            calculate_kv_scales=self.calculate_kv_scales,
1079
        )
1080
        parallel_config = ParallelConfig(
1081
1082
1083
1084
1085
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
1086
1087
1088
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
1089
            ),
1090
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1091
1092
1093
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
        )
1094

1095
1096
1097
1098
1099
1100
        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.
1101

1102
1103
1104
1105
1106
1107
            # For multimodal models, chunked prefill is disabled by default in
            # V0, but enabled by design in V1
            if model_config.is_multimodal_model:
                self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)

            elif use_long_context:
1108
1109
1110
1111
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
1112
                from vllm.platforms import current_platform
1113
1114
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
1115
                        and not self.enable_prompt_adapter
1116
1117
                        and model_config.runner_type != "pooling"
                        and not current_platform.is_rocm()):
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
1134
1135
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
1136
            msg = "Chunked prefill is not supported for pooling models"
1137
            raise ValueError(msg)
1138

1139

1140
1141
1142
1143
1144
        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
1145
1146
            speculative_model_quantization = \
                self.speculative_model_quantization,
1147
1148
            speculative_draft_tensor_parallel_size = \
                self.speculative_draft_tensor_parallel_size,
1149
            num_speculative_tokens=self.num_speculative_tokens,
1150
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
1151
1152
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
1153
1154
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
1155
            disable_log_stats=self.disable_log_stats,
1156
1157
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
1158
1159
1160
1161
1162
1163
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
1164
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
1165
1166
        )

1167
        # Reminder: Please update docs/source/features/compatibility_matrix.md
1168
        # If the feature combo become valid
1169
1170
1171
1172
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1173
1174
1175
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1176
1177
1178
1179
1180
1181
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1182
1183
1184
1185
1186
1187
1188
1189
1190

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
        if not self.use_v2_block_manager:
            logger.warning(
                "[DEPRECATED] Block manager v1 has been removed, "
                "and setting --use-v2-block-manager to True or False has "
                "no effect on vLLM behavior. Please remove "
                "--use-v2-block-manager in your engine argument. "
                "If your use case is not supported by "
                "SelfAttnBlockSpaceManager (i.e. block manager v2),"
                " please file an issue with detailed information.")

1201
        scheduler_config = SchedulerConfig(
1202
            runner_type=model_config.runner_type,
1203
1204
1205
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1206
            num_lookahead_slots=num_lookahead_slots,
1207
1208
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1209
            is_multimodal_model=model_config.is_multimodal_model,
1210
            preemption_mode=self.preemption_mode,
1211
            num_scheduler_steps=self.num_scheduler_steps,
1212
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1213
1214
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1215
            policy=self.scheduling_policy)
1216
        lora_config = LoRAConfig(
1217
            bias_enabled=self.enable_lora_bias,
1218
1219
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1220
            fully_sharded_loras=self.fully_sharded_loras,
1221
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1222
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1223
1224
1225
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
1226

1227
1228
1229
1230
1231
1232
1233
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

1234
        load_config = self.create_load_config()
1235

1236
1237
1238
1239
1240
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1241
1242
1243
        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

1244
1245
1246
1247
1248
1249
1250
1251
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
1252
        observability_config = ObservabilityConfig(
1253
1254
1255
1256
1257
1258
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )
1259

1260
        config = VllmConfig(
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1271
            prompt_adapter_config=prompt_adapter_config,
1272
            compilation_config=self.compilation_config,
1273
            kv_transfer_config=self.kv_transfer_config,
1274
        )
1275

1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
        if envs.VLLM_USE_V1:
            self._override_v1_engine_config(config)
        return config

    def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
        """
        Override the EngineArgs's args based on the usage context for V1.

        Always enables chunked prefill. If ``max_num_batched_tokens`` was not
        set by the user (i.e. it is still ``None``) and ``usage_context`` has
        an entry in the default table, it is set from that table. The table
        values depend on the current device name (larger defaults on
        H100/H200).

        Args:
            usage_context: The context the engine is created from. Contexts
                without a table entry leave ``max_num_batched_tokens`` as-is.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"

        # V1 always uses chunked prefills.
        self.enable_chunked_prefill = True

        # When no user override, set the default values based on the usage
        # context, using different default values for different hardware.
        # NOTE(review): imported here rather than at module level, presumably
        # to defer platform detection / avoid an import cycle -- confirm.
        from vllm.platforms import current_platform
        device_name = current_platform.get_device_name().lower()
        if "h100" in device_name or "h200" in device_name:
            # For H100 and H200, we use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }

        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
            self.max_num_batched_tokens = default_max_num_batched_tokens[
                usage_context]
            logger.warning(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, usage_context.value)

    def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
        """
        Override the EngineConfig's configs based on the usage context for V1.

        The current implementation only verifies that V1 is enabled; it does
        not modify ``engine_config``.

        Args:
            engine_config: The VllmConfig to override.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # When True, requests are not logged. Exposed on the command line as
    # --disable-log-requests (see add_cli_args below).
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Add the async-engine command-line arguments to ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: If True, add only the async-specific arguments;
                otherwise first add every ``EngineArgs`` argument as well.

        Returns:
            The parser with the arguments added.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return parser


# These functions are used by sphinx to build the documentation
def _engine_args_parser() -> FlexibleArgumentParser:
    """Return a fresh parser populated with all ``EngineArgs`` arguments."""
    return EngineArgs.add_cli_args(FlexibleArgumentParser())


def _async_engine_args_parser() -> FlexibleArgumentParser:
    """Return a fresh parser populated with only the async-specific args."""
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                        async_args_only=True)