arg_utils.py 48 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
6
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
                    Union)
7

8
9
import torch

10
import vllm.envs as envs
11
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
12
13
                         EngineConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
14
15
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig, TokenizerPoolConfig)
16
from vllm.executor.executor_base import ExecutorBase
17
from vllm.logger import init_logger
18
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
19
from vllm.utils import FlexibleArgumentParser
20

21
if TYPE_CHECKING:
22
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
23

24
25
# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)

26
27
# Accepted module names for detailed tracing. Presumably used to validate the
# `collect_detailed_traces` option (that validation is not visible here).
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

28

29
30
31
32
33
34
def nullable_str(val: str):
    """Argparse converter mapping "no value" CLI strings to ``None``.

    An empty (falsy) value or the exact literal ``"None"`` becomes ``None``;
    every other string is passed through unchanged.
    """
    is_null = (not val) or (val == "None")
    return None if is_null else val


35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated ``KEY=VALUE`` list into a ``{str: int}`` dict.

    Intended as an argparse ``type=`` converter, e.g. ``image=16,video=2``.

    Args:
        val: The raw command-line string.

    Returns:
        ``None`` if ``val`` is empty, otherwise a dict mapping each key
        (with surrounding whitespace stripped) to its integer value.

    Raises:
        ValueError: If an item is not exactly of the form ``KEY=VALUE`` with
            a non-empty key, or if its value cannot be parsed as an int.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        # Tuple-unpacking a split with the wrong number of parts raises
        # ValueError (never TypeError), so check the shape explicitly in
        # order to report a helpful message for inputs like "a" or "a=1=2".
        parts = item.split("=")
        if len(parts) != 2 or not parts[0].strip():
            raise ValueError("Each item should be in the form KEY=VALUE")
        raw_key, value = parts
        # Tolerate spacing such as "image=16, video=2".
        key = raw_key.strip()

        try:
            out_dict[key] = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise ValueError(msg) from exc

    return out_dict


56
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
57
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
58
    """Arguments for vLLM engine."""
59
    model: str = 'facebook/opt-125m'
60
    served_model_name: Optional[Union[str, List[str]]] = None
61
    tokenizer: Optional[str] = None
62
    skip_tokenizer_init: bool = False
63
    tokenizer_mode: str = 'auto'
64
    trust_remote_code: bool = False
65
    download_dir: Optional[str] = None
66
    load_format: str = 'auto'
67
    dtype: str = 'auto'
68
    kv_cache_dtype: str = 'auto'
69
    quantization_param_path: Optional[str] = None
70
    seed: int = 0
71
    max_model_len: Optional[int] = None
72
    worker_use_ray: bool = False
73
74
75
76
77
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
78
79
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
80
    max_parallel_loading_workers: Optional[int] = None
81
    block_size: int = 16
82
    enable_prefix_caching: bool = False
83
    disable_sliding_window: bool = False
84
    use_v2_block_manager: bool = False
85
86
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
87
    gpu_memory_utilization: float = 0.90
88
    max_num_batched_tokens: Optional[int] = None
89
    max_num_seqs: int = 256
90
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
91
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
92
    revision: Optional[str] = None
93
    code_revision: Optional[str] = None
94
    rope_scaling: Optional[dict] = None
95
    rope_theta: Optional[float] = None
96
    tokenizer_revision: Optional[str] = None
97
    quantization: Optional[str] = None
98
    enforce_eager: Optional[bool] = None
99
100
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
101
    disable_custom_all_reduce: bool = False
102
    tokenizer_pool_size: int = 0
103
104
105
106
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
107
    tokenizer_pool_extra_config: Optional[dict] = None
108
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
109
110
111
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
112
113
114
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
115
    fully_sharded_loras: bool = False
116
    lora_extra_vocab_size: int = 256
117
    long_lora_scaling_factors: Optional[Tuple[float]] = None
118
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
119
    max_cpu_loras: Optional[int] = None
120
    device: str = 'auto'
121
    num_scheduler_steps: int = 1
122
    ray_workers_use_nsight: bool = False
123
    num_gpu_blocks_override: Optional[int] = None
124
    num_lookahead_slots: int = 0
125
    model_loader_extra_config: Optional[dict] = None
126
    ignore_patterns: Optional[Union[str, List[str]]] = None
127
    preemption_mode: Optional[str] = None
128

129
    scheduler_delay_factor: float = 0.0
130
    enable_chunked_prefill: Optional[bool] = None
131

132
    guided_decoding_backend: str = 'outlines'
133
134
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
135
    speculative_model_quantization: Optional[str] = None
136
    speculative_draft_tensor_parallel_size: Optional[int] = None
137
    num_speculative_tokens: Optional[int] = None
138
    speculative_max_model_len: Optional[int] = None
139
    speculative_disable_by_batch_size: Optional[int] = None
140
141
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
142
143
144
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
145
    qlora_adapter_name_or_path: Optional[str] = None
146
    disable_logprobs_during_spec_decoding: Optional[bool] = None
147

148
    otlp_traces_endpoint: Optional[str] = None
149
    collect_detailed_traces: Optional[str] = None
150

151
    def __post_init__(self):
        """Dataclass hook that fills in derived defaults after field init.

        When no tokenizer was supplied, reuse ``model`` so the tokenizer is
        taken from the same name or path as the model.
        """
        # Keep an explicitly supplied tokenizer untouched.
        if self.tokenizer is not None:
            return
        self.tokenizer = self.model
154
155

    @staticmethod
156
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
157
        """Shared CLI arguments for vLLM engine."""
158

159
        # Model arguments
160
161
162
        parser.add_argument(
            '--model',
            type=str,
163
            default=EngineArgs.model,
164
            help='Name or path of the huggingface model to use.')
165
166
        parser.add_argument(
            '--tokenizer',
167
            type=nullable_str,
168
            default=EngineArgs.tokenizer,
169
170
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
171
172
173
174
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
Jasmond L's avatar
Jasmond L committed
175
176
        parser.add_argument(
            '--revision',
177
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
178
            default=None,
179
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
180
181
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
182
183
        parser.add_argument(
            '--code-revision',
184
            type=nullable_str,
185
            default=None,
186
            help='The specific revision to use for the model code on '
187
188
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
189
190
        parser.add_argument(
            '--tokenizer-revision',
191
            type=nullable_str,
192
            default=None,
193
194
195
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
196
197
198
199
200
201
202
203
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
204
205
        parser.add_argument('--trust-remote-code',
                            action='store_true',
206
                            help='Trust remote code from huggingface.')
207
        parser.add_argument('--download-dir',
208
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
209
                            default=EngineArgs.download_dir,
210
                            help='Directory to download and load the weights, '
211
                            'default to the default cache dir of '
212
                            'huggingface.')
213
214
215
216
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
217
            choices=[f.value for f in LoadFormat],
218
219
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
220
            'and fall back to the pytorch bin format if safetensors format '
221
222
223
224
225
226
227
228
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
229
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
230
231
232
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
233
234
235
236
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
237
238
239
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
240
241
242
243
244
245
246
247
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
248
249
250
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
251
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
252
            default=EngineArgs.kv_cache_dtype,
253
            help='Data type for kv cache storage. If "auto", will use model '
254
255
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
256
257
        parser.add_argument(
            '--quantization-param-path',
258
            type=nullable_str,
259
260
261
262
263
264
265
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version'
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
266
            'supported for common inference criteria.')
267
268
        parser.add_argument('--max-model-len',
                            type=int,
269
                            default=EngineArgs.max_model_len,
270
271
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
272
273
274
275
276
277
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
278
279
280
281
282
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
283
        # Parallel arguments
284
285
286
287
288
289
290
291
292
293
294
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
295
296
297
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
298
                            default=EngineArgs.pipeline_parallel_size,
299
                            help='Number of pipeline stages.')
300
301
302
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
303
                            default=EngineArgs.tensor_parallel_size,
304
                            help='Number of tensor parallel replicas.')
305
306
307
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
308
            default=EngineArgs.max_parallel_loading_workers,
309
            help='Load model sequentially in multiple batches, '
310
            'to avoid RAM OOM when using tensor '
311
            'parallel and large models.')
312
313
314
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
315
            help='If specified, use nsight to profile Ray workers.')
316
        # KV cache arguments
317
318
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
319
                            default=EngineArgs.block_size,
320
                            choices=[8, 16, 32],
321
                            help='Token block size for contiguous chunks of '
322
323
                            'tokens. This is ignored on neuron devices and '
                            'set to max-model-len')
324
325
326

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
327
                            help='Enables automatic prefix caching.')
328
329
330
331
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
332
333
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
334
                            help='Use BlockSpaceMangerV2.')
335
336
337
338
339
340
341
342
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
343

344
345
346
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
347
                            help='Random seed for operations.')
348
        parser.add_argument('--swap-space',
349
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
350
                            default=EngineArgs.swap_space,
351
                            help='CPU swap space size (GiB) per GPU.')
352
353
354
355
356
357
358
359
360
361
362
363
364
365
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight,'
            'which requires at least 26GB GPU memory. Note that this '
            'requires fast CPU-GPU interconnect, as part of the model is'
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
366
367
368
369
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
370
371
372
373
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
374
        parser.add_argument(
375
            '--num-gpu-blocks-override',
376
377
378
379
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
            'of GPU blocks. Used for testing preemption.')
380
381
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
382
                            default=EngineArgs.max_num_batched_tokens,
383
384
                            help='Maximum number of batched tokens per '
                            'iteration.')
385
386
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
387
                            default=EngineArgs.max_num_seqs,
388
                            help='Maximum number of sequences per iteration.')
389
390
391
392
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
393
394
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
395
396
        parser.add_argument('--disable-log-stats',
                            action='store_true',
397
                            help='Disable logging statistics.')
398
399
400
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
401
                            type=nullable_str,
402
                            choices=[*QUANTIZATION_METHODS, None],
403
                            default=EngineArgs.quantization,
404
405
406
407
408
409
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
410
411
412
413
414
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
415
416
417
418
419
420
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
421
422
423
424
425
426
427
428
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
429
                            help='Maximum context length covered by CUDA '
430
                            'graphs. When a sequence has context length '
431
                            'larger than this, we fall back to eager mode. '
432
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
433
                            ')')
434
        parser.add_argument('--max-seq-len-to-capture',
435
436
437
438
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
439
                            'larger than this, we fall back to eager mode.')
440
441
442
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
443
                            help='See ParallelConfig.')
444
445
446
447
448
449
450
451
452
453
454
455
456
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
457
                            type=nullable_str,
458
459
460
461
462
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))

478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
504
505
506
507
508
509
510
511
512
513
514
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
515
516
517
518
519
520
521
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
522
523
524
525
526
527
528
529
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
530
531
532
533
534
535
536
537
538
539
540
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
541
542
543
544
545
546
547
548
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=[
                                "auto", "cuda", "neuron", "cpu", "openvino",
                                "tpu", "xpu"
                            ],
                            help='Device type for vLLM execution.')
549
550
551
552
553
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
554

555
556
557
558
559
560
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous'
            'prompt latency) before scheduling next prompt.')
561
562
        parser.add_argument(
            '--enable-chunked-prefill',
563
564
565
566
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
567
            help='If set, the prefill requests can be chunked based on the '
568
            'max_num_batched_tokens.')
569
570
571

        parser.add_argument(
            '--speculative-model',
572
            type=nullable_str,
573
            default=EngineArgs.speculative_model,
574
575
            help=
            'The name of the draft model to be used in speculative decoding.')
576
577
578
579
580
581
582
583
584
585
586
587
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
            help='Method used to quantize the weights of speculative model.'
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
588
589
590
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
591
            default=EngineArgs.num_speculative_tokens,
592
            help='The number of speculative tokens to sample from '
593
            'the draft model in speculative decoding.')
594
595
596
597
598
599
600
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
601

602
603
        parser.add_argument(
            '--speculative-max-model-len',
604
            type=int,
605
606
607
608
609
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

610
611
612
613
614
615
616
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

617
618
619
620
621
622
623
624
625
626
627
628
629
630
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

663
664
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
665
            action=StoreBoolean,
666
            default=EngineArgs.disable_logprobs_during_spec_decoding,
667
668
            nargs="?",
            const="True",
669
670
671
672
673
674
675
676
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

677
        parser.add_argument('--model-loader-extra-config',
678
                            type=nullable_str,
679
680
681
682
683
684
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
685
686
687
688
689
690
691
692
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
            "Default to 'original/**/*' to avoid repeated loading of llama's "
            "checkpoints.")
693
        parser.add_argument(
694
            '--preemption-mode',
695
696
            type=str,
            default=None,
697
698
699
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
700

701
702
703
704
705
706
707
708
709
710
711
712
713
714
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Noted that this name(s)"
            "will also be used in `model_name` tag content of "
            "prometheus metrics, if multiple names provided, metrics"
            "tag will take the first one.")
715
716
717
718
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
719
720
721
722
723
724

        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
725
726
727
728
729
730
731
732
733
734
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
            ". It makes sense to set this only if --otlp-traces-endpoint is"
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
735

736
        return parser
737
738

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Create an instance of ``cls`` from an ``argparse`` namespace.

        Each dataclass field declared on ``cls`` is read from ``args`` under
        the same name; namespace entries that are not fields are ignored.
        A field that is absent from ``args`` raises ``AttributeError``.
        """
        field_values = {}
        for field in dataclasses.fields(cls):
            field_values[field.name] = getattr(args, field.name)
        return cls(**field_values)
745

746
    def create_engine_config(self) -> EngineConfig:
        """Validate these arguments and assemble a full ``EngineConfig``.

        Besides building the individual sub-configs, this method normalises
        some fields on ``self`` in place:

        * ``quantization`` and ``load_format`` are both set to ``"gguf"``
          when ``model`` ends with ``.gguf``;
        * ``enable_chunked_prefill`` is resolved from ``None`` to a bool
          (True only for eligible CUDA setups with max_model_len > 32768);
        * ``use_v2_block_manager`` is forced on when
          ``num_scheduler_steps > 1``;
        * when a QLoRA adapter is given, ``model_loader_extra_config`` is
          replaced by a dict (parsed from JSON if it was a string) that
          also carries ``qlora_adapter_name_or_path``.

        Returns:
            The assembled ``EngineConfig``.

        Raises:
            ValueError: on inconsistent bitsandbytes / QLoRA settings; when
                multi-step scheduling is combined with speculative decoding
                or chunked prefill; on an unknown module name in
                ``collect_detailed_traces``; or when sliding window and
                chunked prefill are both active without the v2 block
                manager. A string ``model_loader_extra_config`` that is not
                valid JSON also raises (``json.JSONDecodeError``).
            AssertionError: if ``cpu_offload_gb`` is negative.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if self.model.endswith(".gguf"):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes"
                or self.qlora_adapter_name_or_path is not None) \
                and self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes"
                or self.qlora_adapter_name_or_path is not None) \
                and self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
        )
        cache_config = CacheConfig(
            block_size=self.block_size if self.device != "neuron" else
            self.max_model_len,  # neuron needs block_size = max_model_len
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.
            if use_long_context:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                has_seqlen_agnostic_layers = (
                    model_config.contains_seqlen_agnostic_layers(
                        parallel_config))
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and not self.enable_prefix_caching
                        and not has_seqlen_agnostic_layers):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)

        if self.num_scheduler_steps > 1 and not self.use_v2_block_manager:
            self.use_v2_block_manager = True
            logger.warning(
                "Enabled BlockSpaceManagerV2 because it is "
                "required for multi-step (--num-scheduler-steps > 1)")

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization=(
                self.speculative_model_quantization),
            speculative_draft_tensor_parallel_size=(
                self.speculative_draft_tensor_parallel_size),
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=(
                self.speculative_disable_by_batch_size),
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=(
                self.spec_decoding_acceptance_method),
            typical_acceptance_sampler_posterior_threshold=(
                self.typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                self.typical_acceptance_sampler_posterior_alpha),
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill:
                raise ValueError("Chunked prefill is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")

        # Speculative decoding dictates its own lookahead; otherwise reserve
        # at least one slot per extra scheduler step for multi-step.
        if speculative_config is not None:
            num_lookahead_slots = speculative_config.num_lookahead_slots
        else:
            num_lookahead_slots = max(self.num_lookahead_slots,
                                      self.num_scheduler_steps - 1)

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            use_v2_block_manager=self.use_v2_block_manager,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
        )

        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        # A non-empty adapter path is forwarded to the model loader through
        # its extra config.
        if self.qlora_adapter_name_or_path:
            # --model-loader-extra-config is parsed with nullable_str, so
            # from the CLI it is a JSON string rather than a dict; indexing
            # into it directly would raise TypeError. Normalise to a fresh
            # dict (also avoiding mutation of a dict the caller passed in).
            extra_config = self.model_loader_extra_config or {}
            if isinstance(extra_config, str):
                extra_config = json.loads(extra_config)
            extra_config = dict(extra_config)
            extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
            self.model_loader_extra_config = extra_config

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
        )

        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        # collect_detailed_traces is a comma-separated list of module names.
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
        )
988
989


990
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # Set by --engine-use-ray (deprecated; see the flag's help text).
    engine_use_ray: bool = False
    # Set by --disable-log-requests.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine command-line options on ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: When False (the default), every ``EngineArgs``
                option is registered first via ``EngineArgs.add_cli_args``.
                When True, only the async-specific flags below are added.

        Returns:
            The extended parser.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='Use Ray to start the LLM engine in a '
                            'separate process as the server process. '
                            '(DEPRECATED. This argument is deprecated '
                            'and will be removed in a future update. '
                            'Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to force '
                            'use it. See '
                            'https://github.com/vllm-project/vllm/issues/7045.'
                            ')')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return parser
1015
1016


1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
class StoreBoolean(argparse.Action):
    """``argparse`` action that stores a bool parsed from ``true``/``false``.

    The option value is compared case-insensitively. Anything other than
    ``"true"`` or ``"false"`` raises ``ValueError``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        lowered = values.lower()
        parsed = {"true": True, "false": False}.get(lowered)
        if parsed is None:
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")
        setattr(namespace, self.dest, parsed)


1029
1030
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with every ``EngineArgs`` option."""
    parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(parser)
1032
1033
1034


def _async_engine_args_parser():
    """Return a fresh parser holding only the async-engine-specific options."""
    parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(parser, async_args_only=True)