arg_utils.py 61.2 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
6
                    Tuple, Type, Union, cast, get_args)
7

8
9
import torch

10
import vllm.envs as envs
11
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
12
13
14
15
                         DecodingConfig, DeviceConfig, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PoolerConfig, PromptAdapterConfig, SchedulerConfig,
16
17
                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
                         VllmConfig)
18
from vllm.executor.executor_base import ExecutorBase
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
21
from vllm.transformers_utils.utils import check_gguf_file
22
from vllm.usage.usage_lib import UsageContext
23
from vllm.utils import FlexibleArgumentParser, StoreBoolean
24

25
if TYPE_CHECKING:
26
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
27

28
29
logger = init_logger(__name__)

30
31
# NOTE(review): presumably the module names accepted by the
# `collect_detailed_traces` engine argument; the code that consumes this
# list is not visible here -- confirm.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

32
33
34
35
36
37
38
39
# Known device type names, with "auto" listed first.
# NOTE(review): presumably used as the `choices` of the device CLI option;
# the consuming code is not visible here -- confirm.
DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "openvino",
    "tpu",
    "xpu",
    "hpu",
]

43

44
45
46
47
48
49
def nullable_str(val: str):
    """Map an empty value or the literal text ``"None"`` to ``None``.

    Any other value is returned unchanged.
    """
    return val if val and val != "None" else None


50
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dictionary.

    Each key and value is lower-cased and stripped of surrounding
    whitespace before use; every value must parse as an integer.

    Args:
        val: String value to be parsed, e.g. ``"image=16,video=2"``.

    Returns:
        Dictionary mapping each key to its integer value, or ``None`` when
        ``val`` is empty.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, if a value is not an integer, or if the same
            key is given two different values.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key is tolerated only when its value is unchanged.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


85
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
86
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
87
    """Arguments for vLLM engine."""
88
    model: str = 'facebook/opt-125m'
89
    served_model_name: Optional[Union[str, List[str]]] = None
90
    tokenizer: Optional[str] = None
91
    task: TaskOption = "auto"
92
    skip_tokenizer_init: bool = False
93
    tokenizer_mode: str = 'auto'
94
    trust_remote_code: bool = False
95
    allowed_local_media_path: str = ""
96
    download_dir: Optional[str] = None
97
    load_format: str = 'auto'
98
    config_format: ConfigFormat = ConfigFormat.AUTO
99
    dtype: str = 'auto'
100
    kv_cache_dtype: str = 'auto'
101
    seed: int = 0
102
    max_model_len: Optional[int] = None
103
104
105
106
107
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
108
    # number of P/D disaggregation (or other disaggregation) workers
109
110
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
111
    max_parallel_loading_workers: Optional[int] = None
112
    block_size: Optional[int] = None
113
    enable_prefix_caching: Optional[bool] = None
114
    disable_sliding_window: bool = False
115
    use_v2_block_manager: bool = True
116
117
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
118
    gpu_memory_utilization: float = 0.90
119
    max_num_batched_tokens: Optional[int] = None
120
    max_num_seqs: Optional[int] = None
121
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
122
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
123
    revision: Optional[str] = None
124
    code_revision: Optional[str] = None
125
    rope_scaling: Optional[Dict[str, Any]] = None
126
    rope_theta: Optional[float] = None
127
    hf_overrides: Optional[HfOverrides] = None
128
    tokenizer_revision: Optional[str] = None
129
    quantization: Optional[str] = None
130
    enforce_eager: Optional[bool] = None
131
    max_seq_len_to_capture: int = 8192
132
    disable_custom_all_reduce: bool = False
133
    tokenizer_pool_size: int = 0
134
135
136
137
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
138
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
139
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
140
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
141
    disable_mm_preprocessor_cache: bool = False
142
    enable_lora: bool = False
143
    enable_lora_bias: bool = False
144
145
    max_loras: int = 1
    max_lora_rank: int = 16
146
147
148
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
149
    fully_sharded_loras: bool = False
150
    lora_extra_vocab_size: int = 256
151
    long_lora_scaling_factors: Optional[Tuple[float]] = None
152
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
153
    max_cpu_loras: Optional[int] = None
154
    device: str = 'auto'
155
    num_scheduler_steps: int = 1
156
    multi_step_stream_outputs: bool = True
157
    ray_workers_use_nsight: bool = False
158
    num_gpu_blocks_override: Optional[int] = None
159
    num_lookahead_slots: int = 0
160
    model_loader_extra_config: Optional[dict] = None
161
    ignore_patterns: Optional[Union[str, List[str]]] = None
162
    preemption_mode: Optional[str] = None
163

164
    scheduler_delay_factor: float = 0.0
165
    enable_chunked_prefill: Optional[bool] = None
166

167
    guided_decoding_backend: str = 'xgrammar'
168
    logits_processor_pattern: Optional[str] = None
169
170
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
171
    speculative_model_quantization: Optional[str] = None
172
    speculative_draft_tensor_parallel_size: Optional[int] = None
173
    num_speculative_tokens: Optional[int] = None
174
    speculative_disable_mqa_scorer: Optional[bool] = False
175
    speculative_max_model_len: Optional[int] = None
176
    speculative_disable_by_batch_size: Optional[int] = None
177
178
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
179
180
181
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
182
    qlora_adapter_name_or_path: Optional[str] = None
183
    disable_logprobs_during_spec_decoding: Optional[bool] = None
184

185
    otlp_traces_endpoint: Optional[str] = None
186
    collect_detailed_traces: Optional[str] = None
187
    disable_async_output_proc: bool = False
188
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
189

190
191
    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
192
    compilation_config: Optional[CompilationConfig] = None
193
    worker_cls: str = "auto"
194

195
196
    kv_transfer_config: Optional[KVTransferConfig] = None

197
    generation_config: Optional[str] = None
198
    enable_sleep_mode: bool = False
199

200
201
    calculate_kv_scales: Optional[bool] = None

202
    def __post_init__(self):
        """Resolve fields left unset and load general plugins.

        * ``tokenizer`` falls back to ``model`` when it is empty.
        * ``enable_prefix_caching`` and ``max_num_seqs`` take a default
          that depends on ``envs.VLLM_USE_V1`` when they are ``None``.
        * An ``int`` or ``dict`` ``compilation_config`` is converted with
          ``CompilationConfig.from_cli``.
        """
        if not self.tokenizer:
            self.tokenizer = self.model

        # Override the default value of enable_prefix_caching if it's not set
        # by user.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = bool(envs.VLLM_USE_V1)

        # Override max_num_seqs if it's not set by user.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024

        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        if isinstance(self.compilation_config, (int, dict)):
            self.compilation_config = CompilationConfig.from_cli(
                str(self.compilation_config))

        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
225
226

    @staticmethod
227
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
228
        """Shared CLI arguments for vLLM engine."""
229

230
        # Model arguments
231
232
233
        parser.add_argument(
            '--model',
            type=str,
234
            default=EngineArgs.model,
235
            help='Name or path of the huggingface model to use.')
236
237
238
239
240
241
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
242
            'multiple tasks. When the model only supports one task, ``"auto"`` '
243
244
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
245
246
        parser.add_argument(
            '--tokenizer',
247
            type=nullable_str,
248
            default=EngineArgs.tokenizer,
249
250
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
251
252
253
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
254
            help='Skip initialization of tokenizer and detokenizer.')
Jasmond L's avatar
Jasmond L committed
255
256
        parser.add_argument(
            '--revision',
257
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
258
            default=None,
259
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
260
261
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
262
263
        parser.add_argument(
            '--code-revision',
264
            type=nullable_str,
265
            default=None,
266
            help='The specific revision to use for the model code on '
267
268
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
269
270
        parser.add_argument(
            '--tokenizer-revision',
271
            type=nullable_str,
272
            default=None,
273
274
275
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
276
277
278
279
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
280
            choices=['auto', 'slow', 'mistral'],
281
282
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
283
284
            'always use the slow tokenizer. \n* '
            '"mistral" will always use the `mistral_common` tokenizer.')
285
286
        parser.add_argument('--trust-remote-code',
                            action='store_true',
287
                            help='Trust remote code from huggingface.')
288
289
290
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
291
292
293
294
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
295
        parser.add_argument('--download-dir',
296
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
297
                            default=EngineArgs.download_dir,
298
                            help='Directory to download and load the weights, '
299
                            'default to the default cache dir of '
300
                            'huggingface.')
301
302
303
304
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
305
            choices=[f.value for f in LoadFormat],
306
307
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
308
            'and fall back to the pytorch bin format if safetensors format '
309
310
311
312
313
314
315
316
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
317
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
318
            'section for more information.\n'
319
320
            '* "runai_streamer" will load the Safetensors weights using Run:ai'
            'Model Streamer \n'
321
322
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
323
324
325
326
327
328
329
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
330
331
332
333
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
334
335
336
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
337
338
339
340
341
342
343
344
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
345
346
347
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
348
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
349
            default=EngineArgs.kv_cache_dtype,
350
            help='Data type for kv cache storage. If "auto", will use model '
351
352
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
353
354
        parser.add_argument('--max-model-len',
                            type=int,
355
                            default=EngineArgs.max_model_len,
356
357
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
358
359
360
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
361
362
            default='xgrammar',
            choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
363
            help='Which engine will be used for guided decoding'
364
            ' (JSON schema / regex etc) by default. Currently support '
365
            'https://github.com/outlines-dev/outlines, '
366
            'https://github.com/mlc-ai/xgrammar, and '
367
368
369
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
370
371
372
373
374
375
376
377
        parser.add_argument(
            '--logits-processor-pattern',
            type=nullable_str,
            default=None,
            help='Optional regex pattern specifying valid logits processor '
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
378
        # Parallel arguments
379
380
        parser.add_argument(
            '--distributed-executor-backend',
381
            choices=['ray', 'mp', 'uni', 'external_launcher'],
382
            default=EngineArgs.distributed_executor_backend,
383
384
385
386
387
388
            help='Backend to use for distributed model '
            'workers, either "ray" or "mp" (multiprocessing). If the product '
            'of pipeline_parallel_size and tensor_parallel_size is less than '
            'or equal to the number of GPUs available, "mp" will be used to '
            'keep processing on a single host. Otherwise, this will default '
            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
389
            'only supports Ray for distributed inference.')
390

391
392
393
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
394
                            default=EngineArgs.pipeline_parallel_size,
395
                            help='Number of pipeline stages.')
396
397
398
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
399
                            default=EngineArgs.tensor_parallel_size,
400
                            help='Number of tensor parallel replicas.')
401
402
403
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
404
            default=EngineArgs.max_parallel_loading_workers,
405
            help='Load model sequentially in multiple batches, '
406
            'to avoid RAM OOM when using tensor '
407
            'parallel and large models.')
408
409
410
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
411
            help='If specified, use nsight to profile Ray workers.')
412
        # KV cache arguments
413
414
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
415
                            default=EngineArgs.block_size,
416
                            choices=[8, 16, 32, 64, 128],
417
                            help='Token block size for contiguous chunks of '
418
                            'tokens. This is ignored on neuron devices and '
419
                            'set to ``--max-model-len``. On CUDA devices, '
420
421
                            'only block sizes up to 32 are supported. '
                            'On HPU devices, block size defaults to 128.')
422

423
424
425
426
427
        parser.add_argument(
            "--enable-prefix-caching",
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.enable_prefix_caching,
            help="Enables automatic prefix caching. "
428
            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
429
        )
430
431
432
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
433
                            'capping to sliding window size.')
434
435
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
436
                            default=True,
437
438
439
440
441
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
442
443
444
445
446
447
448
449
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
450

451
452
453
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
454
                            help='Random seed for operations.')
455
        parser.add_argument('--swap-space',
456
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
457
                            default=EngineArgs.swap_space,
458
                            help='CPU swap space size (GiB) per GPU.')
459
460
461
462
463
464
465
466
467
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
468
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
469
            'which requires at least 26GB GPU memory. Note that this '
470
            'requires fast CPU-GPU interconnect, as part of the model is '
471
472
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
473
474
475
476
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
477
478
479
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
480
481
482
483
484
485
            'will use the default value of 0.9. This is a per-instance '
            'limit, and only applies to the current vLLM instance.'
            'It does not matter if you have another vLLM instance running '
            'on the same GPU. For example, if you have two vLLM instances '
            'running on the same GPU, you can set the GPU memory utilization '
            'to 0.5 for each instance.')
486
        parser.add_argument(
487
            '--num-gpu-blocks-override',
488
489
490
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
491
            ' of GPU blocks. Used for testing preemption.')
492
493
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
494
                            default=EngineArgs.max_num_batched_tokens,
495
496
                            help='Maximum number of batched tokens per '
                            'iteration.')
497
498
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
499
                            default=EngineArgs.max_num_seqs,
500
                            help='Maximum number of sequences per iteration.')
501
502
503
504
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
505
506
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
507
508
        parser.add_argument('--disable-log-stats',
                            action='store_true',
509
                            help='Disable logging statistics.')
510
511
512
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
513
                            type=nullable_str,
514
                            choices=[*QUANTIZATION_METHODS, None],
515
                            default=EngineArgs.quantization,
516
517
518
519
520
521
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
522
523
524
525
526
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
527
            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
528
529
530
531
532
533
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
534
535
536
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
537
                            help='Extra arguments for the HuggingFace config. '
538
539
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
540
541
542
543
544
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
545
        parser.add_argument('--max-seq-len-to-capture',
546
547
548
549
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
550
551
552
553
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
554
555
556
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
557
                            help='See ParallelConfig.')
558
559
560
561
562
563
564
565
566
567
568
569
570
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
571
                            type=nullable_str,
572
573
574
575
576
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
577
578
579
580
581
582
583
584
585
586
587
588
589
590

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
591
592
593
594
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
595
            help=('Overrides for the multimodal input mapping/processing, '
596
                  'e.g., image processor. For example: ``{"num_crops": 4}``.'))
597
        parser.add_argument(
598
            '--disable-mm-preprocessor-cache',
599
            action='store_true',
600
601
            help='If true, then disables caching of the multi-modal '
            'preprocessor/mapper. (not recommended)')
602

603
604
605
606
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
607
608
609
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
629
            choices=['auto', 'float16', 'bfloat16'],
630
631
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
632
633
634
635
636
637
638
639
640
641
642
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
643
644
645
646
647
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
648
649
                  'Must be >= than max_loras. '
                  'Defaults to max_loras.'))
650
651
652
653
654
655
656
657
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
658
659
660
661
662
663
664
665
666
667
668
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
669
670
671
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
672
                            choices=DEVICE_OPTIONS,
673
                            help='Device type for vLLM execution.')
674
675
676
677
678
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
679

680
681
        parser.add_argument(
            '--multi-step-stream-outputs',
682
683
684
685
686
687
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
688
689
690
691
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
692
            help='Apply a delay (of delay factor multiplied by previous '
693
            'prompt latency) before scheduling next prompt.')
694
695
        parser.add_argument(
            '--enable-chunked-prefill',
696
697
698
699
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
700
            help='If set, the prefill requests can be chunked based on the '
701
            'max_num_batched_tokens.')
702
703
704

        parser.add_argument(
            '--speculative-model',
705
            type=nullable_str,
706
            default=EngineArgs.speculative_model,
707
708
            help=
            'The name of the draft model to be used in speculative decoding.')
709
710
711
712
713
714
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
715
            help='Method used to quantize the weights of speculative model. '
716
717
718
719
720
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
721
722
723
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
724
            default=EngineArgs.num_speculative_tokens,
725
            help='The number of speculative tokens to sample from '
726
            'the draft model in speculative decoding.')
727
728
729
730
731
732
        parser.add_argument(
            '--speculative-disable-mqa-scorer',
            action='store_true',
            help=
            'If set to True, the MQA scorer will be disabled in speculative '
            ' and fall back to batch expansion')
733
734
735
736
737
738
739
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
740

741
742
        parser.add_argument(
            '--speculative-max-model-len',
743
            type=int,
744
745
746
747
748
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

749
750
751
752
753
754
755
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

756
757
758
759
760
761
762
763
764
765
766
767
768
769
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

802
803
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
804
            action=StoreBoolean,
805
            default=EngineArgs.disable_logprobs_during_spec_decoding,
806
807
            nargs="?",
            const="True",
808
809
810
811
812
813
814
815
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

816
        parser.add_argument('--model-loader-extra-config',
817
                            type=nullable_str,
818
819
820
821
822
823
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
824
825
826
827
828
829
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
830
            "Default to `original/**/*` to avoid repeated loading of llama's "
831
            "checkpoints.")
832
        parser.add_argument(
833
            '--preemption-mode',
834
835
            type=str,
            default=None,
836
837
838
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
839

840
841
842
843
844
845
846
847
848
849
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
850
            "same as the ``--model`` argument. Noted that this name(s) "
851
            "will also be used in `model_name` tag content of "
852
            "prometheus metrics, if multiple names provided, metrics "
853
            "tag will take the first one.")
854
855
856
857
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
858
859
860
861
862
863

        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
864
865
866
867
868
869
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
870
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
871
872
873
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
874

875
876
877
878
879
880
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
881

882
883
884
885
886
887
888
889
890
891
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

892
        parser.add_argument(
893
894
            '--override-neuron-config',
            type=json.loads,
895
            default=None,
896
            help="Override or set neuron device configuration. "
897
            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
898
        parser.add_argument(
899
900
            '--override-pooler-config',
            type=PoolerConfig.from_json,
901
            default=None,
902
            help="Override or set the pooling method for pooling models. "
903
            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
904

905
906
907
908
909
910
911
912
913
914
915
916
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
917
918
919
920
                            'use a JSON string.\n'
                            'Following the convention of traditional '
                            'compilers, using -O without space is also '
                            'supported. -O3 is equivalent to -O 3.')
921

922
923
924
925
926
927
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

928
929
930
931
932
933
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')

934
935
936
937
938
939
940
941
        parser.add_argument(
            "--generation-config",
            type=nullable_str,
            default=None,
            help="The folder path to the generation config. "
            "Defaults to None, will use the default generation config in vLLM. "
            "If set to 'auto', the generation config will be automatically "
            "loaded from model. If set to a folder path, the generation config "
942
943
944
            "will be loaded from the specified folder path. If "
            "`max_new_tokens` is specified, then it sets a server-wide limit "
            "on the number of output tokens for all requests.")
945

946
947
948
949
950
951
        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
                            default=False,
                            help="Enable sleep mode for the engine. "
                            "(only cuda platform is supported)")

952
953
954
955
956
957
958
959
960
        parser.add_argument(
            '--calculate-kv-scales',
            action='store_true',
            help='This enables dynamic calculation of '
            'k_scale and v_scale when kv-cache-dtype is fp8. '
            'If calculate-kv-scales is false, the scales will '
            'be loaded from the model checkpoint if available. '
            'Otherwise, the scales will default to 1.0.')

961
        return parser
962
963

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Create an instance of ``cls`` from a parsed argparse namespace.

        Each dataclass field of ``cls`` is read from the attribute of the
        same name on ``args``. Attributes of ``args`` that do not correspond
        to a dataclass field are ignored.

        Args:
            args: Namespace holding one attribute per dataclass field.

        Returns:
            A new ``cls`` instance populated from ``args``.

        Raises:
            AttributeError: If ``args`` lacks an attribute for some field.
        """
        field_values = {}
        for field in dataclasses.fields(cls):
            # Field names double as the namespace attribute names.
            field_values[field.name] = getattr(args, field.name)
        return cls(**field_values)
970

971
972
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` from the model-related engine arguments.

        Returns:
            A ``ModelConfig`` whose fields are copied from the corresponding
            attributes of this object. ``use_async_output_proc`` is the
            negation of ``disable_async_output_proc``.
        """
        # We know this is not None because we set it in __post_init__
        tokenizer_name = cast(str, self.tokenizer)
        # The CLI flag is phrased as "disable", the config as "use".
        async_output_proc = not self.disable_async_output_proc

        model_config = ModelConfig(
            model=self.model,
            task=self.task,
            tokenizer=tokenizer_name,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            use_async_output_proc=async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
        )
        return model_config
1007

1008
1009
1010
1011
1012
1013
1014
1015
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` from the weight-loading engine arguments.

        Returns:
            A ``LoadConfig`` carrying ``load_format``, ``download_dir``,
            ``model_loader_extra_config`` and ``ignore_patterns`` as
            currently set on this object.
        """
        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
        )
        return load_config

1016
1017
1018
1019
1020
1021
    def create_engine_config(self,
                             usage_context: Optional[UsageContext] = None
                             ) -> VllmConfig:
        """Validate the engine arguments and assemble the ``VllmConfig``.

        Besides building the individual sub-configs, this method normalizes
        several attributes on ``self`` in place: ``quantization`` and
        ``load_format`` (for gguf files), ``enable_prefix_caching``,
        ``enable_chunked_prefill``, ``num_scheduler_steps`` and
        ``model_loader_extra_config``.

        Args:
            usage_context: Forwarded to ``_override_v1_engine_args`` when
                ``envs.VLLM_USE_V1`` is set; unused otherwise.

        Returns:
            The assembled ``VllmConfig``.

        Raises:
            ValueError: If the bitsandbytes quantization / load format /
                QLoRA adapter settings are inconsistent, if chunked prefill
                is enabled for a pooling model, if multi-step scheduling is
                combined with speculative decoding or with chunked prefill
                under pipeline parallelism, or if
                ``collect_detailed_traces`` names an unknown module.
            AssertionError: If ``cpu_offload_gb`` is negative.
        """
        if envs.VLLM_USE_V1:
            self._override_v1_engine_args(usage_context)

        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        uses_qlora_adapter = self.qlora_adapter_name_or_path is not None
        if ((self.quantization == "bitsandbytes" or uses_qlora_adapter)
                and self.load_format != "bitsandbytes"):
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if ((self.load_format == "bitsandbytes" or uses_qlora_adapter)
                and self.quantization != "bitsandbytes"):
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

        if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
                and self.enable_prefix_caching):
            logger.warning("--enable-prefix-caching is currently not "
                           "supported for multimodal models in v0 and "
                           "has been disabled.")
            self.enable_prefix_caching = False

        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            is_attention_free=model_config.is_attention_free,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
            calculate_kv_scales=self.calculate_kv_scales,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
        )

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.

            # For multimodal models, chunked prefill is disabled by default in
            # V0, but enabled by design in V1
            if model_config.is_multimodal_model:
                self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)

            elif use_long_context:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                from vllm.platforms import current_platform
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"
                        and not current_platform.is_rocm()):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            # None of the defaults above applied: fall back to disabled.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization=(
                self.speculative_model_quantization),
            speculative_draft_tensor_parallel_size=(
                self.speculative_draft_tensor_parallel_size),
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size=(
                self.speculative_disable_by_batch_size),
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=(
                self.spec_decoding_acceptance_method),
            typical_acceptance_sampler_posterior_threshold=(
                self.typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                self.typical_acceptance_sampler_posterior_alpha),
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # If the feature combo become valid
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = (num_lookahead_slots
                               if speculative_config is None else
                               speculative_config.num_lookahead_slots)

        if not self.use_v2_block_manager:
            logger.warning(
                "[DEPRECATED] Block manager v1 has been removed, "
                "and setting --use-v2-block-manager to True or False has "
                "no effect on vLLM behavior. Please remove "
                "--use-v2-block-manager in your engine argument. "
                "If your use case is not supported by "
                "SelfAttnBlockSpaceManager (i.e. block manager v2),"
                " please file an issue with detailed information.")

        scheduler_config = SchedulerConfig(
            runner_type=model_config.runner_type,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            is_multimodal_model=model_config.is_multimodal_model,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
            multi_step_stream_outputs=self.multi_step_stream_outputs,
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
            policy=self.scheduling_policy)

        # The LoRA config is only built when LoRA support was requested.
        lora_config: Optional[LoRAConfig] = None
        if self.enable_lora:
            # A missing or non-positive value is passed on as None.
            max_cpu_loras = (self.max_cpu_loras if self.max_cpu_loras
                             and self.max_cpu_loras > 0 else None)
            lora_config = LoRAConfig(
                bias_enabled=self.enable_lora_bias,
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                long_lora_scaling_factors=self.long_lora_scaling_factors,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=max_cpu_loras)

        if (self.qlora_adapter_name_or_path is not None
                and self.qlora_adapter_name_or_path != ""):
            extra_config = self.model_loader_extra_config
            if isinstance(extra_config, str):
                # --model-loader-extra-config arrives from the CLI as a JSON
                # string; decode it so the adapter path can be added as a key
                # instead of failing on string item assignment.
                extra_config = json.loads(extra_config) if extra_config else {}
            elif extra_config is None:
                extra_config = {}
            extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
            self.model_loader_extra_config = extra_config

        load_config = self.create_load_config()

        prompt_adapter_config = (PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token)
                                 if self.enable_prompt_adapter else None)

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
        # "all" switches on every detailed trace category.
        trace_all = "all" in detailed_trace_modules
        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time=("model" in detailed_trace_modules
                                        or trace_all),
            collect_model_execute_time=("worker" in detailed_trace_modules
                                        or trace_all),
        )

        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
            compilation_config=self.compilation_config,
            kv_transfer_config=self.kv_transfer_config,
        )

        if envs.VLLM_USE_V1:
            self._override_v1_engine_config(config)
        return config

    def _override_v1_engine_args(
            self, usage_context: Optional[UsageContext]) -> None:
        """
        Override the EngineArgs's args based on the usage context for V1.

        Forces ``enable_chunked_prefill`` to True and, when the user did not
        set ``max_num_batched_tokens``, fills in a default chosen by device
        name and usage context.

        Args:
            usage_context: How the engine is being used. ``None`` (or a
                context without a tuned default) leaves
                ``max_num_batched_tokens`` untouched.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"

        # V1 always uses chunked prefills.
        self.enable_chunked_prefill = True

        # Use different default values for different hardware.
        from vllm.platforms import current_platform
        device_name = current_platform.get_device_name().lower()
        if "h100" in device_name or "h200" in device_name:
            # For H100 and H200, we use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }

        # When no user override, set the default values based on the usage
        # context.
        if (self.max_num_batched_tokens is None
                and usage_context is not None
                and usage_context in default_max_num_batched_tokens):
            self.max_num_batched_tokens = default_max_num_batched_tokens[
                usage_context]
            logger.warning(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, usage_context.value)
1301
1302
1303
1304
1305
1306
1307

    def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
        """
        Hook for applying V1-specific overrides to the assembled config.

        At present this only asserts that V1 is enabled; ``engine_config``
        is left untouched.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"

1308

1309
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async engine's command-line flags.

        Args:
            parser: Parser to add the arguments to.
            async_args_only: When True, only ``--disable-log-requests`` is
                added; otherwise the ``EngineArgs`` arguments are added
                first.

        Returns:
            The parser holding the registered arguments.
        """
        # Unless only the async-specific flags are wanted, let the base
        # class register its arguments first.
        target = (parser
                  if async_args_only else EngineArgs.add_cli_args(parser))
        target.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return target
1323
1324
1325
1326


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with the ``EngineArgs`` CLI flags."""
    parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(parser)
1328
1329
1330


def _async_engine_args_parser():
    """Return a fresh parser holding only the async-specific CLI flags."""
    parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(parser, async_args_only=True)