arg_utils.py 60.8 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
6
                    Tuple, Type, Union, cast, get_args)
7

8
9
import torch

10
import vllm.envs as envs
11
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
12
13
14
15
                         DecodingConfig, DeviceConfig, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PoolerConfig, PromptAdapterConfig, SchedulerConfig,
16
17
                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
                         VllmConfig)
18
from vllm.executor.executor_base import ExecutorBase
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
21
from vllm.transformers_utils.utils import check_gguf_file
22
from vllm.usage.usage_lib import UsageContext
23
from vllm.utils import FlexibleArgumentParser, StoreBoolean
24

25
if TYPE_CHECKING:
26
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
27

28
29
logger = init_logger(__name__)

# Module names allowed for detailed tracing. Presumably validated against the
# `collect_detailed_traces` engine argument later in this file — confirm.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# Accepted values for the ``device`` engine argument (whose default is
# 'auto'); presumably used as the ``choices`` of ``--device`` later in this
# file — confirm.
DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "openvino",
    "tpu",
    "xpu",
    "hpu",
]


def nullable_str(val: str):
    """argparse ``type=`` converter that treats blank or "None" as no value.

    Args:
        val: Raw command-line string.

    Returns:
        ``None`` when ``val`` is falsy (e.g. the empty string) or is exactly
        the string ``"None"``; otherwise ``val`` unchanged.
    """
    is_null = (not val) or (val == "None")
    return None if is_null else val


def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dict.

    Intended as an argparse ``type=`` converter. Each key is lower-cased and
    stripped of surrounding whitespace, and each value must parse as an
    ``int``. For example, ``"image=16,video=2"`` yields
    ``{"image": 16, "video": 2}``.

    Args:
        val: String value to be parsed.

    Returns:
        Dictionary mapping each key to its integer value, or ``None`` if
        ``val`` is empty (or ``None``).

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, a value is not an integer, or the same key is
            given more than once with different values.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # A repeated key is tolerated only when it agrees with the value
        # given earlier.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
86
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
87
    """Arguments for vLLM engine."""
88
    model: str = 'facebook/opt-125m'
89
    served_model_name: Optional[Union[str, List[str]]] = None
90
    tokenizer: Optional[str] = None
91
    task: TaskOption = "auto"
92
    skip_tokenizer_init: bool = False
93
    tokenizer_mode: str = 'auto'
94
    trust_remote_code: bool = False
95
    allowed_local_media_path: str = ""
96
    download_dir: Optional[str] = None
97
    load_format: str = 'auto'
98
    config_format: ConfigFormat = ConfigFormat.AUTO
99
    dtype: str = 'auto'
100
    kv_cache_dtype: str = 'auto'
101
    seed: int = 0
102
    max_model_len: Optional[int] = None
103
    worker_use_ray: bool = False
104
105
106
107
108
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
109
    # number of P/D disaggregation (or other disaggregation) workers
110
111
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
112
    max_parallel_loading_workers: Optional[int] = None
113
    block_size: Optional[int] = None
114
    enable_prefix_caching: Optional[bool] = None
115
    disable_sliding_window: bool = False
116
    use_v2_block_manager: bool = True
117
118
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
119
    gpu_memory_utilization: float = 0.90
120
    max_num_batched_tokens: Optional[int] = None
121
    max_num_seqs: Optional[int] = None
122
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
123
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
124
    revision: Optional[str] = None
125
    code_revision: Optional[str] = None
126
    rope_scaling: Optional[Dict[str, Any]] = None
127
    rope_theta: Optional[float] = None
128
    hf_overrides: Optional[HfOverrides] = None
129
    tokenizer_revision: Optional[str] = None
130
    quantization: Optional[str] = None
131
    enforce_eager: Optional[bool] = None
132
    max_seq_len_to_capture: int = 8192
133
    disable_custom_all_reduce: bool = False
134
    tokenizer_pool_size: int = 0
135
136
137
138
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
139
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
140
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
141
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
142
    disable_mm_preprocessor_cache: bool = False
143
    enable_lora: bool = False
144
    enable_lora_bias: bool = False
145
146
    max_loras: int = 1
    max_lora_rank: int = 16
147
148
149
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
150
    fully_sharded_loras: bool = False
151
    lora_extra_vocab_size: int = 256
152
    long_lora_scaling_factors: Optional[Tuple[float]] = None
153
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
154
    max_cpu_loras: Optional[int] = None
155
    device: str = 'auto'
156
    num_scheduler_steps: int = 1
157
    multi_step_stream_outputs: bool = True
158
    ray_workers_use_nsight: bool = False
159
    num_gpu_blocks_override: Optional[int] = None
160
    num_lookahead_slots: int = 0
161
    model_loader_extra_config: Optional[dict] = None
162
    ignore_patterns: Optional[Union[str, List[str]]] = None
163
    preemption_mode: Optional[str] = None
164

165
    scheduler_delay_factor: float = 0.0
166
    enable_chunked_prefill: Optional[bool] = None
167

168
    guided_decoding_backend: str = 'xgrammar'
169
    logits_processor_pattern: Optional[str] = None
170
171
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
172
    speculative_model_quantization: Optional[str] = None
173
    speculative_draft_tensor_parallel_size: Optional[int] = None
174
    num_speculative_tokens: Optional[int] = None
175
    speculative_disable_mqa_scorer: Optional[bool] = False
176
    speculative_max_model_len: Optional[int] = None
177
    speculative_disable_by_batch_size: Optional[int] = None
178
179
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
180
181
182
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
183
    qlora_adapter_name_or_path: Optional[str] = None
184
    disable_logprobs_during_spec_decoding: Optional[bool] = None
185

186
    otlp_traces_endpoint: Optional[str] = None
187
    collect_detailed_traces: Optional[str] = None
188
    disable_async_output_proc: bool = False
189
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
190

191
192
    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
193
    compilation_config: Optional[CompilationConfig] = None
194
    worker_cls: str = "auto"
195

196
197
    kv_transfer_config: Optional[KVTransferConfig] = None

198
    generation_config: Optional[str] = None
199
    enable_sleep_mode: bool = False
200

201
202
    calculate_kv_scales: Optional[bool] = None

203
    def __post_init__(self):
        """Resolve defaults that depend on other fields or the environment.

        Invoked by the ``@dataclass``-generated ``__init__``. Fills in
        ``tokenizer``, ``enable_prefix_caching`` and ``max_num_seqs`` when
        the caller left them unset. Converts an ``int``/``dict``
        ``compilation_config`` into a ``CompilationConfig``. Finally loads
        the general vLLM plugins.
        """
        # Fall back to the model name/path when no tokenizer was given.
        if not self.tokenizer:
            self.tokenizer = self.model

        # Override the default value of enable_prefix_caching if it's not set
        # by user.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = bool(envs.VLLM_USE_V1)

        # Override max_num_seqs if it's not set by user.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024

        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        # NOTE(review): str() of a dict yields a Python repr (single quotes,
        # True/None), not JSON — confirm CompilationConfig.from_cli accepts
        # that form.
        if isinstance(self.compilation_config, (int, dict)):
            self.compilation_config = CompilationConfig.from_cli(
                str(self.compilation_config))

        # Setup plugins (imported lazily here rather than at module level).
        from vllm.plugins import load_general_plugins
        load_general_plugins()

    @staticmethod
228
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
229
        """Shared CLI arguments for vLLM engine."""
230

231
        # Model arguments
232
233
234
        parser.add_argument(
            '--model',
            type=str,
235
            default=EngineArgs.model,
236
            help='Name or path of the huggingface model to use.')
237
238
239
240
241
242
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
243
            'multiple tasks. When the model only supports one task, ``"auto"`` '
244
245
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
246
247
        parser.add_argument(
            '--tokenizer',
248
            type=nullable_str,
249
            default=EngineArgs.tokenizer,
250
251
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
252
253
254
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
255
            help='Skip initialization of tokenizer and detokenizer.')
Jasmond L's avatar
Jasmond L committed
256
257
        parser.add_argument(
            '--revision',
258
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
259
            default=None,
260
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
261
262
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
263
264
        parser.add_argument(
            '--code-revision',
265
            type=nullable_str,
266
            default=None,
267
            help='The specific revision to use for the model code on '
268
269
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
270
271
        parser.add_argument(
            '--tokenizer-revision',
272
            type=nullable_str,
273
            default=None,
274
275
276
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
277
278
279
280
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
281
            choices=['auto', 'slow', 'mistral'],
282
283
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
284
285
            'always use the slow tokenizer. \n* '
            '"mistral" will always use the `mistral_common` tokenizer.')
286
287
        parser.add_argument('--trust-remote-code',
                            action='store_true',
288
                            help='Trust remote code from huggingface.')
289
290
291
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
292
293
294
295
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
296
        parser.add_argument('--download-dir',
297
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
298
                            default=EngineArgs.download_dir,
299
                            help='Directory to download and load the weights, '
300
                            'default to the default cache dir of '
301
                            'huggingface.')
302
303
304
305
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
306
            choices=[f.value for f in LoadFormat],
307
308
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
309
            'and fall back to the pytorch bin format if safetensors format '
310
311
312
313
314
315
316
317
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
318
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
319
            'section for more information.\n'
320
321
            '* "runai_streamer" will load the Safetensors weights using Run:ai'
            'Model Streamer \n'
322
323
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
324
325
326
327
328
329
330
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
331
332
333
334
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
335
336
337
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
338
339
340
341
342
343
344
345
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
346
347
348
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
349
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
350
            default=EngineArgs.kv_cache_dtype,
351
            help='Data type for kv cache storage. If "auto", will use model '
352
353
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
354
355
        parser.add_argument('--max-model-len',
                            type=int,
356
                            default=EngineArgs.max_model_len,
357
358
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
359
360
361
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
362
363
            default='xgrammar',
            choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
364
            help='Which engine will be used for guided decoding'
365
            ' (JSON schema / regex etc) by default. Currently support '
366
            'https://github.com/outlines-dev/outlines, '
367
            'https://github.com/mlc-ai/xgrammar, and '
368
369
370
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
371
372
373
374
375
376
377
378
        parser.add_argument(
            '--logits-processor-pattern',
            type=nullable_str,
            default=None,
            help='Optional regex pattern specifying valid logits processor '
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
379
        # Parallel arguments
380
381
        parser.add_argument(
            '--distributed-executor-backend',
382
            choices=['ray', 'mp', 'uni', 'external_launcher'],
383
            default=EngineArgs.distributed_executor_backend,
384
385
386
387
388
389
            help='Backend to use for distributed model '
            'workers, either "ray" or "mp" (multiprocessing). If the product '
            'of pipeline_parallel_size and tensor_parallel_size is less than '
            'or equal to the number of GPUs available, "mp" will be used to '
            'keep processing on a single host. Otherwise, this will default '
            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
390
            'only supports Ray for distributed inference.')
391

392
393
394
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
395
            help='Deprecated, use ``--distributed-executor-backend=ray``.')
396
397
398
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
399
                            default=EngineArgs.pipeline_parallel_size,
400
                            help='Number of pipeline stages.')
401
402
403
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
404
                            default=EngineArgs.tensor_parallel_size,
405
                            help='Number of tensor parallel replicas.')
406
407
408
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
409
            default=EngineArgs.max_parallel_loading_workers,
410
            help='Load model sequentially in multiple batches, '
411
            'to avoid RAM OOM when using tensor '
412
            'parallel and large models.')
413
414
415
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
416
            help='If specified, use nsight to profile Ray workers.')
417
        # KV cache arguments
418
419
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
420
                            default=EngineArgs.block_size,
421
                            choices=[8, 16, 32, 64, 128],
422
                            help='Token block size for contiguous chunks of '
423
                            'tokens. This is ignored on neuron devices and '
424
                            'set to ``--max-model-len``. On CUDA devices, '
425
426
                            'only block sizes up to 32 are supported. '
                            'On HPU devices, block size defaults to 128.')
427

428
429
430
431
432
        parser.add_argument(
            "--enable-prefix-caching",
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.enable_prefix_caching,
            help="Enables automatic prefix caching. "
433
            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
434
        )
435
436
437
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
438
                            'capping to sliding window size.')
439
440
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
441
                            default=True,
442
443
444
445
446
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
447
448
449
450
451
452
453
454
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
455

456
457
458
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
459
                            help='Random seed for operations.')
460
        parser.add_argument('--swap-space',
461
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
462
                            default=EngineArgs.swap_space,
463
                            help='CPU swap space size (GiB) per GPU.')
464
465
466
467
468
469
470
471
472
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
473
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
474
            'which requires at least 26GB GPU memory. Note that this '
475
            'requires fast CPU-GPU interconnect, as part of the model is '
476
477
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
478
479
480
481
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
482
483
484
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
485
486
487
488
489
490
            'will use the default value of 0.9. This is a per-instance '
            'limit, and only applies to the current vLLM instance.'
            'It does not matter if you have another vLLM instance running '
            'on the same GPU. For example, if you have two vLLM instances '
            'running on the same GPU, you can set the GPU memory utilization '
            'to 0.5 for each instance.')
491
        parser.add_argument(
492
            '--num-gpu-blocks-override',
493
494
495
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
496
            ' of GPU blocks. Used for testing preemption.')
497
498
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
499
                            default=EngineArgs.max_num_batched_tokens,
500
501
                            help='Maximum number of batched tokens per '
                            'iteration.')
502
503
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
504
                            default=EngineArgs.max_num_seqs,
505
                            help='Maximum number of sequences per iteration.')
506
507
508
509
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
510
511
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
512
513
        parser.add_argument('--disable-log-stats',
                            action='store_true',
514
                            help='Disable logging statistics.')
515
516
517
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
518
                            type=nullable_str,
519
                            choices=[*QUANTIZATION_METHODS, None],
520
                            default=EngineArgs.quantization,
521
522
523
524
525
526
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
527
528
529
530
531
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
532
            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
533
534
535
536
537
538
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
539
540
541
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
542
                            help='Extra arguments for the HuggingFace config. '
543
544
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
545
546
547
548
549
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
550
        parser.add_argument('--max-seq-len-to-capture',
551
552
553
554
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
555
556
557
558
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
559
560
561
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
562
                            help='See ParallelConfig.')
563
564
565
566
567
568
569
570
571
572
573
574
575
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
576
                            type=nullable_str,
577
578
579
580
581
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
582
583
584
585
586
587
588
589
590
591
592
593
594
595

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
596
597
598
599
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
600
            help=('Overrides for the multimodal input mapping/processing, '
601
                  'e.g., image processor. For example: ``{"num_crops": 4}``.'))
602
        parser.add_argument(
603
            '--disable-mm-preprocessor-cache',
604
            action='store_true',
605
606
            help='If true, then disables caching of the multi-modal '
            'preprocessor/mapper. (not recommended)')
607

608
609
610
611
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
612
613
614
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
634
            choices=['auto', 'float16', 'bfloat16'],
635
636
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
637
638
639
640
641
642
643
644
645
646
647
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
648
649
650
651
652
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
653
654
                  'Must be >= than max_loras. '
                  'Defaults to max_loras.'))
655
656
657
658
659
660
661
662
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
663
664
665
666
667
668
669
670
671
672
673
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
674
675
676
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
677
                            choices=DEVICE_OPTIONS,
678
                            help='Device type for vLLM execution.')
679
680
681
682
683
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
684

685
686
        parser.add_argument(
            '--multi-step-stream-outputs',
687
688
689
690
691
692
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
693
694
695
696
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
697
            help='Apply a delay (of delay factor multiplied by previous '
698
            'prompt latency) before scheduling next prompt.')
699
700
        parser.add_argument(
            '--enable-chunked-prefill',
701
702
703
704
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
705
            help='If set, the prefill requests can be chunked based on the '
706
            'max_num_batched_tokens.')
707
708
709

        parser.add_argument(
            '--speculative-model',
710
            type=nullable_str,
711
            default=EngineArgs.speculative_model,
712
713
            help=
            'The name of the draft model to be used in speculative decoding.')
714
715
716
717
718
719
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
720
            help='Method used to quantize the weights of speculative model. '
721
722
723
724
725
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
726
727
728
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
729
            default=EngineArgs.num_speculative_tokens,
730
            help='The number of speculative tokens to sample from '
731
            'the draft model in speculative decoding.')
732
733
734
735
736
737
        parser.add_argument(
            '--speculative-disable-mqa-scorer',
            action='store_true',
            help=
            'If set to True, the MQA scorer will be disabled in speculative '
            ' and fall back to batch expansion')
738
739
740
741
742
743
744
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
745

746
747
        parser.add_argument(
            '--speculative-max-model-len',
748
            type=int,
749
750
751
752
753
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

754
755
756
757
758
759
760
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

761
762
763
764
765
766
767
768
769
770
771
772
773
774
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

807
808
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
809
            action=StoreBoolean,
810
            default=EngineArgs.disable_logprobs_during_spec_decoding,
811
812
            nargs="?",
            const="True",
813
814
815
816
817
818
819
820
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

821
        parser.add_argument('--model-loader-extra-config',
822
                            type=nullable_str,
823
824
825
826
827
828
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
829
830
831
832
833
834
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
835
            "Default to `original/**/*` to avoid repeated loading of llama's "
836
            "checkpoints.")
837
        parser.add_argument(
838
            '--preemption-mode',
839
840
            type=str,
            default=None,
841
842
843
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
844

845
846
847
848
849
850
851
852
853
854
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
855
            "same as the ``--model`` argument. Noted that this name(s) "
856
            "will also be used in `model_name` tag content of "
857
            "prometheus metrics, if multiple names provided, metrics "
858
            "tag will take the first one.")
859
860
861
862
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
863
864
865
866
867
868

        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
869
870
871
872
873
874
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
875
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
876
877
878
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
879

880
881
882
883
884
885
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
886

887
888
889
890
891
892
893
894
895
896
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

897
        parser.add_argument(
898
899
            '--override-neuron-config',
            type=json.loads,
900
            default=None,
901
            help="Override or set neuron device configuration. "
902
            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
903
        parser.add_argument(
904
905
            '--override-pooler-config',
            type=PoolerConfig.from_json,
906
            default=None,
907
            help="Override or set the pooling method for pooling models. "
908
            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
909

910
911
912
913
914
915
916
917
918
919
920
921
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
922
923
924
925
                            'use a JSON string.\n'
                            'Following the convention of traditional '
                            'compilers, using -O without space is also '
                            'supported. -O3 is equivalent to -O 3.')
926

927
928
929
930
931
932
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

933
934
935
936
937
938
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')

939
940
941
942
943
944
945
946
947
948
        parser.add_argument(
            "--generation-config",
            type=nullable_str,
            default=None,
            help="The folder path to the generation config. "
            "Defaults to None, will use the default generation config in vLLM. "
            "If set to 'auto', the generation config will be automatically "
            "loaded from model. If set to a folder path, the generation config "
            "will be loaded from the specified folder path.")

949
950
951
952
953
954
        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
                            default=False,
                            help="Enable sleep mode for the engine. "
                            "(only cuda platform is supported)")

955
956
957
958
959
960
961
962
963
        parser.add_argument(
            '--calculate-kv-scales',
            action='store_true',
            help='This enables dynamic calculation of '
            'k_scale and v_scale when kv-cache-dtype is fp8. '
            'If calculate-kv-scales is false, the scales will '
            'be loaded from the model checkpoint if available. '
            'Otherwise, the scales will default to 1.0.')

964
        return parser
965
966

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Create an instance of ``cls`` from a parsed argparse namespace.

        Every dataclass field of ``cls`` is read by name from ``args`` and
        passed to the constructor as a keyword argument. Attributes of
        ``args`` that do not correspond to a field are ignored.

        Args:
            args: Namespace produced by a parser populated through
                ``add_cli_args``.

        Returns:
            A new ``cls`` instance.

        Raises:
            AttributeError: If ``args`` has no attribute for some field.
        """
        # Copy exactly the dataclass fields of this class from the namespace,
        # in field-declaration order.
        field_values = {
            fld.name: getattr(args, fld.name)
            for fld in dataclasses.fields(cls)
        }
        return cls(**field_values)
973

974
975
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` for this engine from the model fields.

        Each keyword is forwarded from the like-named attribute of ``self``,
        with two exceptions: ``tokenizer`` is cast to ``str``, and
        ``use_async_output_proc`` is the negation of
        ``disable_async_output_proc``.

        Returns:
            A new ``ModelConfig``. ``self`` is only read, never modified.
        """
        return ModelConfig(
            model=self.model,
            task=self.task,
            # We know this is not None because we set it in __post_init__
            tokenizer=cast(str, self.tokenizer),
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            # The CLI exposes a "disable" flag; ModelConfig wants "use".
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
        )
1010

1011
1012
1013
1014
1015
1016
1017
1018
    def create_load_config(self) -> LoadConfig:
        """Return the ``LoadConfig`` describing how model weights are loaded.

        The loader-related fields ``load_format``, ``download_dir``,
        ``model_loader_extra_config`` and ``ignore_patterns`` are passed
        through unchanged.
        """
        load_config = LoadConfig(
            ignore_patterns=self.ignore_patterns,
            model_loader_extra_config=self.model_loader_extra_config,
            download_dir=self.download_dir,
            load_format=self.load_format,
        )
        return load_config

1019
1020
1021
1022
1023
1024
    def create_engine_config(self,
                             usage_context: Optional[UsageContext] = None
                             ) -> VllmConfig:
        """Validate these engine args and assemble the full ``VllmConfig``.

        Besides building the individual sub-configs, this method normalizes
        some fields of ``self`` in place. Examples are ``quantization`` and
        ``load_format`` for GGUF models, ``enable_prefix_caching``,
        ``enable_chunked_prefill``, ``num_scheduler_steps`` and
        ``model_loader_extra_config``. The instance therefore reflects the
        values that were actually used.

        Args:
            usage_context: Where the engine is being created from. It is only
                consulted when V1 is enabled, to pick V1 defaults.

        Returns:
            The assembled ``VllmConfig``.

        Raises:
            ValueError: If any of the following holds:
                - bitsandbytes/QLoRA settings are inconsistent;
                - chunked prefill is requested for a pooling model;
                - multi-step scheduling is combined with speculative
                  decoding, or with chunked prefill and pipeline
                  parallelism;
                - ``collect_detailed_traces`` names an unknown module;
                - ``model_loader_extra_config`` is a string that is not
                  valid JSON while a QLoRA adapter is set.
            AssertionError: If ``cpu_offload_gb`` is negative (only when
                assertions are enabled).
        """
        if envs.VLLM_USE_V1:
            self._override_v1_engine_args(usage_context)

        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if ((self.quantization == "bitsandbytes"
             or self.qlora_adapter_name_or_path is not None)
                and self.load_format != "bitsandbytes"):
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if ((self.load_format == "bitsandbytes"
             or self.qlora_adapter_name_or_path is not None)
                and self.quantization != "bitsandbytes"):
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

        if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
                and self.enable_prefix_caching):
            logger.warning("--enable-prefix-caching is currently not "
                           "supported for multimodal models in v0 and "
                           "has been disabled.")
            self.enable_prefix_caching = False

        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            is_attention_free=model_config.is_attention_free,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
            calculate_kv_scales=self.calculate_kv_scales,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
        )

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.

            # For multimodal models, chunked prefill is disabled by default in
            # V0, but enabled by design in V1
            if model_config.is_multimodal_model:
                self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)

            elif use_long_context:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                from vllm.platforms import current_platform
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"
                        and not current_platform.is_rocm()):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            # None of the branches above opted in: fall back to disabled.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization=self.
            speculative_model_quantization,
            speculative_draft_tensor_parallel_size=self.
            speculative_draft_tensor_parallel_size,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=self.
            spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # if the feature combo becomes valid
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

        if not self.use_v2_block_manager:
            logger.warning(
                "[DEPRECATED] Block manager v1 has been removed, "
                "and setting --use-v2-block-manager to True or False has "
                "no effect on vLLM behavior. Please remove "
                "--use-v2-block-manager in your engine argument. "
                "If your use case is not supported by "
                "SelfAttnBlockSpaceManager (i.e. block manager v2),"
                " please file an issue with detailed information.")

        scheduler_config = SchedulerConfig(
            runner_type=model_config.runner_type,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            is_multimodal_model=model_config.is_multimodal_model,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
            multi_step_stream_outputs=self.multi_step_stream_outputs,
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
            policy=self.scheduling_policy)
        lora_config = LoRAConfig(
            bias_enabled=self.enable_lora_bias,
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            extra_config = self.model_loader_extra_config
            # --model-loader-extra-config is parsed as a plain string holding
            # JSON, and a str does not support item assignment. Decode it
            # before merging in the adapter path.
            if isinstance(extra_config, str):
                extra_config = json.loads(extra_config)
            if extra_config is None:
                extra_config = {}
            extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
            self.model_loader_extra_config = extra_config

        load_config = self.create_load_config()

        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        # Validate the comma-separated module list before using it.
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )

        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
            compilation_config=self.compilation_config,
            kv_transfer_config=self.kv_transfer_config,
        )

        if envs.VLLM_USE_V1:
            self._override_v1_engine_config(config)
        return config

    def _override_v1_engine_args(
            self, usage_context: Optional[UsageContext]) -> None:
        """
        Override the EngineArgs's args based on the usage context for V1.

        The method always turns chunked prefill on. It then looks at
        ``max_num_batched_tokens``. If the user left it unset and
        ``usage_context`` has a known default (LLM class: 8192, OpenAI API
        server: 2048), that default is applied and a warning is logged.

        Args:
            usage_context: Origin of the engine creation. May be ``None``,
                which is ``create_engine_config``'s default; in that case no
                batching default is applied.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"

        # V1 always uses chunked prefills.
        self.enable_chunked_prefill = True
        # When no user override, set the default values based on the usage
        # context.
        # TODO(woosuk): Tune the default values for different hardware.
        default_max_num_batched_tokens = {
            UsageContext.LLM_CLASS: 8192,
            UsageContext.OPENAI_API_SERVER: 2048,
        }
        # The explicit None check narrows the Optional type before `.value`
        # is accessed; None is never a key of the mapping above anyway.
        if (self.max_num_batched_tokens is None and usage_context is not None
                and usage_context in default_max_num_batched_tokens):
            self.max_num_batched_tokens = default_max_num_batched_tokens[
                usage_context]
            logger.warning(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, usage_context.value)
1294
1295
1296
1297
1298
1299
1300

    def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
        """Hook for V1-only adjustments to an already-built engine config.

        No overrides are applied at present, and ``engine_config`` is left
        untouched. The method only asserts that V1 is enabled.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"

1301

1302
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends ``EngineArgs`` with options used only by the async engine.
    """

    # Set by --disable-log-requests ("Disable logging requests.").
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI options on ``parser``.

        Args:
            parser: Parser to extend.
            async_args_only: If True, only the async-specific flags are
                added. Otherwise all ``EngineArgs`` flags are registered
                first.

        Returns:
            The parser with the options added.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return parser


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a parser populated with every ``EngineArgs`` CLI option."""
    return EngineArgs.add_cli_args(FlexibleArgumentParser())


def _async_engine_args_parser():
    """Return a parser with only the AsyncEngineArgs-specific options."""
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                        async_args_only=True)