arg_utils.py 49.1 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
6
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
                    Type, Union)
7

8
9
import torch

10
import vllm.envs as envs
11
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
12
13
                         EngineConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
14
15
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig, TokenizerPoolConfig)
16
from vllm.executor.executor_base import ExecutorBase
17
from vllm.logger import init_logger
18
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
19
from vllm.transformers_utils.utils import check_gguf_file
20
from vllm.utils import FlexibleArgumentParser
21

22
if TYPE_CHECKING:
23
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
24

25
26
# Module-level logger for this file, obtained from vLLM's `init_logger`
# helper and named after this module.
logger = init_logger(__name__)

27
28
# Module names accepted for detailed tracing.
# NOTE(review): presumably validated against the `collect_detailed_traces`
# engine argument elsewhere in this file — confirm.
ALLOWED_DETAILED_TRACE_MODULES = [
    "model",
    "worker",
    "all",
]

29

30
31
32
33
34
35
def nullable_str(val: str):
    """Map an empty string or the literal ``"None"`` to ``None``.

    Intended as an argparse ``type`` converter so that optional string
    options can be explicitly cleared from the command line. Any other
    value is returned unchanged.
    """
    is_null = not val or val == "None"
    return None if is_null else val


36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated ``KEY=VALUE`` string into a dict of ints.

    Used as the argparse ``type`` for ``--limit-mm-per-prompt``; for example
    ``"image=16,video=2"`` becomes ``{"image": 16, "video": 2}``.

    Args:
        val: The raw CLI string. An empty string (or ``None``) means
            "not set".

    Returns:
        ``None`` when ``val`` is empty, otherwise a mapping from each key
        to its integer value.

    Raises:
        ValueError: If an item is not of the form ``KEY=VALUE``, or if its
            value cannot be converted to ``int``.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        # Unpacking the wrong number of parts raises ValueError (not
        # TypeError). Catch that so items like "image" or "a=1=2" get the
        # intended message instead of a raw unpacking error.
        try:
            key, value = item.split("=")
        except ValueError as exc:
            msg = "Each item should be in the form KEY=VALUE"
            raise ValueError(msg) from exc

        try:
            out_dict[key] = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise ValueError(msg) from exc

    return out_dict


57
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
58
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
59
    """Arguments for vLLM engine."""
60
    model: str = 'facebook/opt-125m'
61
    served_model_name: Optional[Union[str, List[str]]] = None
62
    tokenizer: Optional[str] = None
63
    skip_tokenizer_init: bool = False
64
    tokenizer_mode: str = 'auto'
65
    trust_remote_code: bool = False
66
    download_dir: Optional[str] = None
67
    load_format: str = 'auto'
68
    dtype: str = 'auto'
69
    kv_cache_dtype: str = 'auto'
70
    quantization_param_path: Optional[str] = None
71
    seed: int = 0
72
    max_model_len: Optional[int] = None
73
    worker_use_ray: bool = False
74
75
76
77
78
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
79
80
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
81
    max_parallel_loading_workers: Optional[int] = None
82
    block_size: int = 16
83
    enable_prefix_caching: bool = False
84
    disable_sliding_window: bool = False
85
    use_v2_block_manager: bool = False
86
87
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
88
    gpu_memory_utilization: float = 0.90
89
    max_num_batched_tokens: Optional[int] = None
90
    max_num_seqs: int = 256
91
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
92
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
93
    revision: Optional[str] = None
94
    code_revision: Optional[str] = None
95
    rope_scaling: Optional[dict] = None
96
    rope_theta: Optional[float] = None
97
    tokenizer_revision: Optional[str] = None
98
    quantization: Optional[str] = None
99
    enforce_eager: Optional[bool] = None
100
101
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
102
    disable_custom_all_reduce: bool = False
103
    tokenizer_pool_size: int = 0
104
105
106
107
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
108
    tokenizer_pool_extra_config: Optional[dict] = None
109
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
110
111
112
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
113
114
115
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
116
    fully_sharded_loras: bool = False
117
    lora_extra_vocab_size: int = 256
118
    long_lora_scaling_factors: Optional[Tuple[float]] = None
119
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
120
    max_cpu_loras: Optional[int] = None
121
    device: str = 'auto'
122
    num_scheduler_steps: int = 1
123
    ray_workers_use_nsight: bool = False
124
    num_gpu_blocks_override: Optional[int] = None
125
    num_lookahead_slots: int = 0
126
    model_loader_extra_config: Optional[dict] = None
127
    ignore_patterns: Optional[Union[str, List[str]]] = None
128
    preemption_mode: Optional[str] = None
129

130
    scheduler_delay_factor: float = 0.0
131
    enable_chunked_prefill: Optional[bool] = None
132

133
    guided_decoding_backend: str = 'outlines'
134
135
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
136
    speculative_model_quantization: Optional[str] = None
137
    speculative_draft_tensor_parallel_size: Optional[int] = None
138
    num_speculative_tokens: Optional[int] = None
139
    speculative_max_model_len: Optional[int] = None
140
    speculative_disable_by_batch_size: Optional[int] = None
141
142
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
143
144
145
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
146
    qlora_adapter_name_or_path: Optional[str] = None
147
    disable_logprobs_during_spec_decoding: Optional[bool] = None
148

149
    otlp_traces_endpoint: Optional[str] = None
150
    collect_detailed_traces: Optional[str] = None
151
    disable_async_output_proc: bool = False
152
    override_neuron_config: Optional[Dict[str, Any]] = None
153

154
    def __post_init__(self):
        # When no separate tokenizer is given, reuse the model name/path,
        # matching the `--tokenizer` help text ("If unspecified, model name
        # or path will be used").
        if self.tokenizer is not None:
            return
        self.tokenizer = self.model
157
158

    @staticmethod
159
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
160
        """Shared CLI arguments for vLLM engine."""
161

162
        # Model arguments
163
164
165
        parser.add_argument(
            '--model',
            type=str,
166
            default=EngineArgs.model,
167
            help='Name or path of the huggingface model to use.')
168
169
        parser.add_argument(
            '--tokenizer',
170
            type=nullable_str,
171
            default=EngineArgs.tokenizer,
172
173
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
174
175
176
177
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
Jasmond L's avatar
Jasmond L committed
178
179
        parser.add_argument(
            '--revision',
180
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
181
            default=None,
182
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
183
184
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
185
186
        parser.add_argument(
            '--code-revision',
187
            type=nullable_str,
188
            default=None,
189
            help='The specific revision to use for the model code on '
190
191
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
192
193
        parser.add_argument(
            '--tokenizer-revision',
194
            type=nullable_str,
195
            default=None,
196
197
198
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
199
200
201
202
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
203
            choices=['auto', 'slow', 'mistral'],
204
205
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
206
207
            'always use the slow tokenizer. \n* '
            '"mistral" will always use the `mistral_common` tokenizer.')
208
209
        parser.add_argument('--trust-remote-code',
                            action='store_true',
210
                            help='Trust remote code from huggingface.')
211
        parser.add_argument('--download-dir',
212
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
213
                            default=EngineArgs.download_dir,
214
                            help='Directory to download and load the weights, '
215
                            'default to the default cache dir of '
216
                            'huggingface.')
217
218
219
220
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
221
            choices=[f.value for f in LoadFormat],
222
223
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
224
            'and fall back to the pytorch bin format if safetensors format '
225
226
227
228
229
230
231
232
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
233
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
234
235
236
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
237
238
239
240
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
241
242
243
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
244
245
246
247
248
249
250
251
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
252
253
254
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
255
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
256
            default=EngineArgs.kv_cache_dtype,
257
            help='Data type for kv cache storage. If "auto", will use model '
258
259
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
260
261
        parser.add_argument(
            '--quantization-param-path',
262
            type=nullable_str,
263
264
265
266
267
268
269
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version'
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
270
            'supported for common inference criteria.')
271
272
        parser.add_argument('--max-model-len',
                            type=int,
273
                            default=EngineArgs.max_model_len,
274
275
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
276
277
278
279
280
281
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
282
283
284
285
286
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
287
        # Parallel arguments
288
289
290
291
292
293
294
295
296
297
298
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
299
300
301
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
302
                            default=EngineArgs.pipeline_parallel_size,
303
                            help='Number of pipeline stages.')
304
305
306
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
307
                            default=EngineArgs.tensor_parallel_size,
308
                            help='Number of tensor parallel replicas.')
309
310
311
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
312
            default=EngineArgs.max_parallel_loading_workers,
313
            help='Load model sequentially in multiple batches, '
314
            'to avoid RAM OOM when using tensor '
315
            'parallel and large models.')
316
317
318
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
319
            help='If specified, use nsight to profile Ray workers.')
320
        # KV cache arguments
321
322
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
323
                            default=EngineArgs.block_size,
324
                            choices=[8, 16, 32],
325
                            help='Token block size for contiguous chunks of '
326
327
                            'tokens. This is ignored on neuron devices and '
                            'set to max-model-len')
328
329
330

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
331
                            help='Enables automatic prefix caching.')
332
333
334
335
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
336
337
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
338
                            help='Use BlockSpaceMangerV2.')
339
340
341
342
343
344
345
346
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
347

348
349
350
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
351
                            help='Random seed for operations.')
352
        parser.add_argument('--swap-space',
353
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
354
                            default=EngineArgs.swap_space,
355
                            help='CPU swap space size (GiB) per GPU.')
356
357
358
359
360
361
362
363
364
365
366
367
368
369
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight,'
            'which requires at least 26GB GPU memory. Note that this '
            'requires fast CPU-GPU interconnect, as part of the model is'
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
370
371
372
373
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
374
375
376
377
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
378
        parser.add_argument(
379
            '--num-gpu-blocks-override',
380
381
382
383
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
            'of GPU blocks. Used for testing preemption.')
384
385
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
386
                            default=EngineArgs.max_num_batched_tokens,
387
388
                            help='Maximum number of batched tokens per '
                            'iteration.')
389
390
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
391
                            default=EngineArgs.max_num_seqs,
392
                            help='Maximum number of sequences per iteration.')
393
394
395
396
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
397
398
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
399
400
        parser.add_argument('--disable-log-stats',
                            action='store_true',
401
                            help='Disable logging statistics.')
402
403
404
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
405
                            type=nullable_str,
406
                            choices=[*QUANTIZATION_METHODS, None],
407
                            default=EngineArgs.quantization,
408
409
410
411
412
413
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
414
415
416
417
418
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
419
420
421
422
423
424
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
425
426
427
428
429
430
431
432
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
433
                            help='Maximum context length covered by CUDA '
434
                            'graphs. When a sequence has context length '
435
                            'larger than this, we fall back to eager mode. '
436
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
437
                            ')')
438
        parser.add_argument('--max-seq-len-to-capture',
439
440
441
442
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
443
                            'larger than this, we fall back to eager mode.')
444
445
446
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
447
                            help='See ParallelConfig.')
448
449
450
451
452
453
454
455
456
457
458
459
460
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
461
                            type=nullable_str,
462
463
464
465
466
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))

482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
508
509
510
511
512
513
514
515
516
517
518
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
519
520
521
522
523
524
525
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
526
527
528
529
530
531
532
533
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
534
535
536
537
538
539
540
541
542
543
544
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
545
546
547
548
549
550
551
552
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=[
                                "auto", "cuda", "neuron", "cpu", "openvino",
                                "tpu", "xpu"
                            ],
                            help='Device type for vLLM execution.')
553
554
555
556
557
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
558

559
560
561
562
563
564
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous'
            'prompt latency) before scheduling next prompt.')
565
566
        parser.add_argument(
            '--enable-chunked-prefill',
567
568
569
570
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
571
            help='If set, the prefill requests can be chunked based on the '
572
            'max_num_batched_tokens.')
573
574
575

        parser.add_argument(
            '--speculative-model',
576
            type=nullable_str,
577
            default=EngineArgs.speculative_model,
578
579
            help=
            'The name of the draft model to be used in speculative decoding.')
580
581
582
583
584
585
586
587
588
589
590
591
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
            help='Method used to quantize the weights of speculative model.'
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
592
593
594
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
595
            default=EngineArgs.num_speculative_tokens,
596
            help='The number of speculative tokens to sample from '
597
            'the draft model in speculative decoding.')
598
599
600
601
602
603
604
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
605

606
607
        parser.add_argument(
            '--speculative-max-model-len',
608
            type=int,
609
610
611
612
613
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

614
615
616
617
618
619
620
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

621
622
623
624
625
626
627
628
629
630
631
632
633
634
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

667
668
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
669
            action=StoreBoolean,
670
            default=EngineArgs.disable_logprobs_during_spec_decoding,
671
672
            nargs="?",
            const="True",
673
674
675
676
677
678
679
680
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

681
        parser.add_argument('--model-loader-extra-config',
682
                            type=nullable_str,
683
684
685
686
687
688
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
689
690
691
692
693
694
695
696
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
            "Default to 'original/**/*' to avoid repeated loading of llama's "
            "checkpoints.")
697
        parser.add_argument(
698
            '--preemption-mode',
699
700
            type=str,
            default=None,
701
702
703
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
704

705
706
707
708
709
710
711
712
713
714
715
716
717
718
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Noted that this name(s)"
            "will also be used in `model_name` tag content of "
            "prometheus metrics, if multiple names provided, metrics"
            "tag will take the first one.")
719
720
721
722
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
723
724
725
726
727
728

        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
729
730
731
732
733
734
735
736
737
738
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
            ". It makes sense to set this only if --otlp-traces-endpoint is"
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
739

740
741
742
743
744
745
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
746
747
748
749
750
751
752
753
754
755
        parser.add_argument(
            '--override-neuron-config',
            type=lambda configs: {
                str(key): value
                for key, value in
                (config.split(':') for config in configs.split(','))
            },
            default=None,
            help="override or set neuron device configuration.")

756
        return parser
757
758

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of ``cls`` from a parsed argparse namespace.

        Each dataclass field of ``cls`` is read from ``args`` by its field
        name, so the namespace must expose an attribute for every field
        (otherwise ``getattr`` raises ``AttributeError``).
        """
        field_values = {
            fld.name: getattr(args, fld.name)
            for fld in dataclasses.fields(cls)
        }
        return cls(**field_values)
765

766
    def create_engine_config(self) -> EngineConfig:
        """Validate these engine arguments and assemble an ``EngineConfig``.

        Side effects: this may update several attributes on ``self``:
        ``quantization``/``load_format`` (set to "gguf" for GGUF files),
        ``enable_chunked_prefill`` (resolved from ``None``),
        ``use_v2_block_manager`` (forced on for multi-step scheduling) and
        ``model_loader_extra_config`` (QLoRA adapter path merged in).

        Raises:
            ValueError: on inconsistent bitsandbytes/QLoRA settings, a
                negative ``cpu_offload_gb``, speculative decoding or chunked
                prefill combined with multi-step scheduling, an unknown
                module in ``collect_detailed_traces``, or chunked prefill
                with sliding window without the v2 block manager.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
           self.qlora_adapter_name_or_path is not None) and \
           self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes" or
            self.qlora_adapter_name_or_path is not None) and \
            self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        # Explicit raise rather than `assert`: asserts are stripped under -O.
        if self.cpu_offload_gb < 0:
            raise ValueError("CPU offload space must be non-negative"
                             f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            use_async_output_proc=not self.disable_async_output_proc,
            override_neuron_config=self.override_neuron_config)
        cache_config = CacheConfig(
            block_size=self.block_size if self.device != "neuron" else
            self.max_model_len,  # neuron needs block_size = max_model_len
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.
            if use_long_context:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                has_seqlen_agnostic_layers = (
                    model_config.contains_seqlen_agnostic_layers(
                        parallel_config))
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and not self.enable_prefix_caching
                        and not has_seqlen_agnostic_layers):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)

        if self.num_scheduler_steps > 1 and not self.use_v2_block_manager:
            self.use_v2_block_manager = True
            logger.warning(
                "Enabled BlockSpaceManagerV2 because it is "
                "required for multi-step (--num-scheduler-steps > 1)")

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization=(
                self.speculative_model_quantization),
            speculative_draft_tensor_parallel_size=(
                self.speculative_draft_tensor_parallel_size),
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=(
                self.speculative_disable_by_batch_size),
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=(
                self.spec_decoding_acceptance_method),
            typical_acceptance_sampler_posterior_threshold=(
                self.typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                self.typical_acceptance_sampler_posterior_alpha),
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill:
                raise ValueError("Chunked prefill is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")

        # Reserve enough lookahead slots for multi-step scheduling
        # (num_scheduler_steps - 1) or the user-requested amount, whichever
        # is larger; when speculative decoding is enabled its own requirement
        # takes precedence.
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        if speculative_config is not None:
            num_lookahead_slots = speculative_config.num_lookahead_slots

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            use_v2_block_manager=self.use_v2_block_manager,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            is_multimodal_model=model_config.is_multimodal_model,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
        )
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            extra_config = self.model_loader_extra_config
            if extra_config is None:
                extra_config = {}
            elif isinstance(extra_config, str):
                # --model-loader-extra-config is parsed from the CLI as a
                # JSON string; item assignment on a str would raise
                # TypeError, so decode it into a dict first.
                extra_config = json.loads(extra_config)
            else:
                # Copy so a dict supplied by the caller is not mutated.
                extra_config = dict(extra_config)
            extra_config["qlora_adapter_name_or_path"] = \
                self.qlora_adapter_name_or_path
            self.model_loader_extra_config = extra_config

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
        )

        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
        )
1010
1011


1012
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # Start the LLM engine in a separate Ray process (deprecated; see the
    # --engine-use-ray help text).
    engine_use_ray: bool = False
    # Disable logging of incoming requests.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine command-line flags on ``parser``.

        Args:
            parser: The parser to register flags on.
            async_args_only: When True, only the async-specific flags
                (``--engine-use-ray``, ``--disable-log-requests``) are added;
                otherwise the full ``EngineArgs`` flag set is added first.

        Returns:
            The parser with the flags registered.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='Use Ray to start the LLM engine in a '
                            'separate process as the server process. '
                            '(DEPRECATED. This argument is deprecated '
                            'and will be removed in a future update. '
                            'Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to force '
                            'use it. See '
                            'https://github.com/vllm-project/vllm/issues/7045.'
                            ')')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return parser
1037
1038


1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
class StoreBoolean(argparse.Action):
    """argparse action that stores a bool parsed from 'true' or 'false'.

    The comparison is case-insensitive. ``values`` must be a string, so when
    the option is declared with ``nargs="?"`` a string ``const`` (e.g.
    ``"True"``) should be supplied. Any other spelling raises ``ValueError``.
    NOTE(review): argparse only turns ``argparse.ArgumentError`` into a
    usage message; this ``ValueError`` propagates out of ``parse_args``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        lowered = values.lower()
        if lowered not in ("true", "false"):
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")
        setattr(namespace, self.dest, lowered == "true")


1051
1052
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with every EngineArgs flag."""
    fresh_parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(fresh_parser)
1054
1055
1056


def _async_engine_args_parser():
    """Return a fresh parser holding only the async-engine-specific flags."""
    fresh_parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(fresh_parser, async_args_only=True)