arg_utils.py 41.9 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
6

7
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
8
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
9
                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
10
11
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig, TokenizerPoolConfig)
12
from vllm.executor.executor_base import ExecutorBase
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
15
from vllm.utils import FlexibleArgumentParser
16

17
18
19
20
if TYPE_CHECKING:
    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
        BaseTokenizerGroup)

21
22
# Module-level logger for this file, created via vLLM's `init_logger` helper.
logger = init_logger(__name__)

23

24
25
26
27
28
29
def nullable_str(val: str):
    """Normalize a command-line string, mapping "no value" to ``None``.

    Intended as an argparse ``type=`` converter (e.g. for ``--tokenizer``
    and ``--revision``), so that passing an empty string or the literal
    text ``None`` on the command line yields a real ``None``.

    Args:
        val: Raw string taken from the command line.

    Returns:
        ``val`` unchanged when it is non-empty and not equal to ``"None"``;
        otherwise ``None``.
    """
    # Only a non-empty value other than the literal "None" is kept as-is.
    if val and val != "None":
        return val
    return None


30
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
31
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
32
    """Arguments for vLLM engine."""
33
    model: str
34
    served_model_name: Optional[Union[List[str]]] = None
35
    tokenizer: Optional[str] = None
36
    skip_tokenizer_init: bool = False
37
    tokenizer_mode: str = 'auto'
38
    trust_remote_code: bool = False
39
    download_dir: Optional[str] = None
40
    load_format: str = 'auto'
41
    dtype: str = 'auto'
42
    kv_cache_dtype: str = 'auto'
43
    quantization_param_path: Optional[str] = None
44
    seed: int = 0
45
    max_model_len: Optional[int] = None
46
    worker_use_ray: bool = False
47
48
49
50
51
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
52
53
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
54
    max_parallel_loading_workers: Optional[int] = None
55
    block_size: int = 16
56
    enable_prefix_caching: bool = False
57
    disable_sliding_window: bool = False
58
    use_v2_block_manager: bool = False
59
    swap_space: int = 4  # GiB
60
    cpu_offload_gb: int = 0  # GiB
61
    gpu_memory_utilization: float = 0.90
62
    max_num_batched_tokens: Optional[int] = None
63
    max_num_seqs: int = 256
64
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
65
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
66
    revision: Optional[str] = None
67
    code_revision: Optional[str] = None
68
    rope_scaling: Optional[dict] = None
69
    rope_theta: Optional[float] = None
70
    tokenizer_revision: Optional[str] = None
71
    quantization: Optional[str] = None
72
    enforce_eager: bool = False
73
74
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
75
    disable_custom_all_reduce: bool = False
76
    tokenizer_pool_size: int = 0
77
78
79
80
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
81
    tokenizer_pool_extra_config: Optional[dict] = None
82
83
84
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
85
86
87
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
88
    fully_sharded_loras: bool = False
89
    lora_extra_vocab_size: int = 256
90
    long_lora_scaling_factors: Optional[Tuple[float]] = None
91
    lora_dtype: str = 'auto'
92
    max_cpu_loras: Optional[int] = None
93
    device: str = 'auto'
94
    ray_workers_use_nsight: bool = False
95
    num_gpu_blocks_override: Optional[int] = None
96
    num_lookahead_slots: int = 0
97
    model_loader_extra_config: Optional[dict] = None
98
    preemption_mode: Optional[str] = None
99

100
    scheduler_delay_factor: float = 0.0
101
    enable_chunked_prefill: Optional[bool] = None
102

103
    guided_decoding_backend: str = 'outlines'
104
105
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
106
    speculative_draft_tensor_parallel_size: Optional[int] = None
107
    num_speculative_tokens: Optional[int] = None
108
    speculative_max_model_len: Optional[int] = None
109
    speculative_disable_by_batch_size: Optional[int] = None
110
111
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
112
113
114
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
115
    qlora_adapter_name_or_path: Optional[str] = None
116
    disable_logprobs_during_spec_decoding: Optional[bool] = None
117

118
119
    otlp_traces_endpoint: Optional[str] = None

120
    def __post_init__(self):
121
122
        if self.tokenizer is None:
            self.tokenizer = self.model
123
124

    @staticmethod
125
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
126
        """Shared CLI arguments for vLLM engine."""
127

128
        # Model arguments
129
130
131
132
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
133
            help='Name or path of the huggingface model to use.')
134
135
        parser.add_argument(
            '--tokenizer',
136
            type=nullable_str,
137
            default=EngineArgs.tokenizer,
138
139
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
140
141
142
143
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
Jasmond L's avatar
Jasmond L committed
144
145
        parser.add_argument(
            '--revision',
146
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
147
            default=None,
148
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
149
150
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
151
152
        parser.add_argument(
            '--code-revision',
153
            type=nullable_str,
154
            default=None,
155
            help='The specific revision to use for the model code on '
156
157
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
158
159
        parser.add_argument(
            '--tokenizer-revision',
160
            type=nullable_str,
161
            default=None,
162
163
164
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
165
166
167
168
169
170
171
172
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
173
174
        parser.add_argument('--trust-remote-code',
                            action='store_true',
175
                            help='Trust remote code from huggingface.')
176
        parser.add_argument('--download-dir',
177
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
178
                            default=EngineArgs.download_dir,
179
                            help='Directory to download and load the weights, '
180
                            'default to the default cache dir of '
181
                            'huggingface.')
182
183
184
185
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
186
            choices=[
187
188
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
                'bitsandbytes'
189
            ],
190
191
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
192
            'and fall back to the pytorch bin format if safetensors format '
193
194
195
196
197
198
199
200
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
201
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
202
203
204
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
205
206
207
208
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
209
210
211
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
212
213
214
215
216
217
218
219
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
220
221
222
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
223
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
224
            default=EngineArgs.kv_cache_dtype,
225
            help='Data type for kv cache storage. If "auto", will use model '
226
227
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
228
229
        parser.add_argument(
            '--quantization-param-path',
230
            type=nullable_str,
231
232
233
234
235
236
237
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version'
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
238
            'supported for common inference criteria.')
239
240
        parser.add_argument('--max-model-len',
                            type=int,
241
                            default=EngineArgs.max_model_len,
242
243
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
244
245
246
247
248
249
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
250
251
252
253
254
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
255
        # Parallel arguments
256
257
258
259
260
261
262
263
264
265
266
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
267
268
269
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
270
                            default=EngineArgs.pipeline_parallel_size,
271
                            help='Number of pipeline stages.')
272
273
274
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
275
                            default=EngineArgs.tensor_parallel_size,
276
                            help='Number of tensor parallel replicas.')
277
278
279
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
280
            default=EngineArgs.max_parallel_loading_workers,
281
            help='Load model sequentially in multiple batches, '
282
            'to avoid RAM OOM when using tensor '
283
            'parallel and large models.')
284
285
286
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
287
            help='If specified, use nsight to profile Ray workers.')
288
        # KV cache arguments
289
290
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
291
                            default=EngineArgs.block_size,
292
                            choices=[8, 16, 32],
293
294
                            help='Token block size for contiguous chunks of '
                            'tokens.')
295
296
297

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
298
                            help='Enables automatic prefix caching.')
299
300
301
302
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
303
304
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
305
                            help='Use BlockSpaceMangerV2.')
306
307
308
309
310
311
312
313
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
314

315
316
317
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
318
                            help='Random seed for operations.')
319
320
        parser.add_argument('--swap-space',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
321
                            default=EngineArgs.swap_space,
322
                            help='CPU swap space size (GiB) per GPU.')
323
324
325
326
327
328
329
330
331
332
333
334
335
336
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight,'
            'which requires at least 26GB GPU memory. Note that this '
            'requires fast CPU-GPU interconnect, as part of the model is'
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
337
338
339
340
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
341
342
343
344
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
345
        parser.add_argument(
346
            '--num-gpu-blocks-override',
347
348
349
350
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
            'of GPU blocks. Used for testing preemption.')
351
352
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
353
                            default=EngineArgs.max_num_batched_tokens,
354
355
                            help='Maximum number of batched tokens per '
                            'iteration.')
356
357
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
358
                            default=EngineArgs.max_num_seqs,
359
                            help='Maximum number of sequences per iteration.')
360
361
362
363
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
364
365
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
366
367
        parser.add_argument('--disable-log-stats',
                            action='store_true',
368
                            help='Disable logging statistics.')
369
370
371
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
372
                            type=nullable_str,
373
                            choices=[*QUANTIZATION_METHODS, None],
374
                            default=EngineArgs.quantization,
375
376
377
378
379
380
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
381
382
383
384
385
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
386
387
388
389
390
391
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
392
393
394
395
396
397
398
399
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
400
                            help='Maximum context length covered by CUDA '
401
                            'graphs. When a sequence has context length '
402
                            'larger than this, we fall back to eager mode. '
403
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
404
                            ')')
405
        parser.add_argument('--max-seq-len-to-capture',
406
407
408
409
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
410
                            'larger than this, we fall back to eager mode.')
411
412
413
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
414
                            help='See ParallelConfig.')
415
416
417
418
419
420
421
422
423
424
425
426
427
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
428
                            type=nullable_str,
429
430
431
432
433
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
460
461
462
463
464
465
466
467
468
469
470
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
471
472
473
474
475
476
477
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
478
479
480
481
482
483
484
485
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
486
487
488
489
490
491
492
493
494
495
496
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
497
498
499
500
501
502
503
504
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=[
                                "auto", "cuda", "neuron", "cpu", "openvino",
                                "tpu", "xpu"
                            ],
                            help='Device type for vLLM execution.')
505

506
507
508
509
510
511
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous'
            'prompt latency) before scheduling next prompt.')
512
513
        parser.add_argument(
            '--enable-chunked-prefill',
514
515
516
517
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
518
            help='If set, the prefill requests can be chunked based on the '
519
            'max_num_batched_tokens.')
520
521
522

        parser.add_argument(
            '--speculative-model',
523
            type=nullable_str,
524
            default=EngineArgs.speculative_model,
525
526
527
528
529
            help=
            'The name of the draft model to be used in speculative decoding.')
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
530
            default=EngineArgs.num_speculative_tokens,
531
            help='The number of speculative tokens to sample from '
532
            'the draft model in speculative decoding.')
533
534
535
536
537
538
539
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
540

541
542
        parser.add_argument(
            '--speculative-max-model-len',
543
            type=int,
544
545
546
547
548
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

549
550
551
552
553
554
555
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

556
557
558
559
560
561
562
563
564
565
566
567
568
569
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

602
603
604
605
606
607
608
609
610
611
612
613
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
            type=bool,
            default=EngineArgs.disable_logprobs_during_spec_decoding,
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

614
        parser.add_argument('--model-loader-extra-config',
615
                            type=nullable_str,
616
617
618
619
620
621
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
622
        parser.add_argument(
623
            '--preemption-mode',
624
625
626
627
628
            type=str,
            default=None,
            help='If \'recompute\', the engine performs preemption by block '
            'swapping; If \'swap\', the engine performs preemption by block '
            'swapping.')
629

630
631
632
633
634
635
636
637
638
639
640
641
642
643
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Noted that this name(s)"
            "will also be used in `model_name` tag content of "
            "prometheus metrics, if multiple names provided, metrics"
            "tag will take the first one.")
644
645
646
647
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
648
649
650
651
652
653
654

        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')

655
        return parser
656
657

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of ``cls`` from a parsed argparse namespace.

        Each dataclass field of ``cls`` is looked up on ``args`` by the same
        name. Attributes of ``args`` that are not fields of ``cls`` are
        ignored.

        Args:
            args: The namespace returned by ``parser.parse_args()``.

        Returns:
            A new ``cls`` instance populated from ``args``.

        Raises:
            AttributeError: If ``args`` has no attribute for some field.
        """
        # Only dataclass fields are forwarded to the constructor, so extra
        # namespace entries (e.g. server-only flags) do not break ``cls()``.
        field_values = {
            field.name: getattr(args, field.name)
            for field in dataclasses.fields(cls)
        }
        return cls(**field_values)
664

665
    def create_engine_config(self) -> EngineConfig:
        """Validate these arguments and assemble them into an EngineConfig.

        Side effects:
            * If ``self.enable_chunked_prefill`` is ``None`` it is resolved
              to a bool. It becomes ``True`` only for models with
              ``max_model_len > 32768`` on CUDA with no sliding window, no
              speculative model, and LoRA, prompt adapters and prefix
              caching all disabled; otherwise it becomes ``False``.
            * When a non-empty QLoRA adapter path is given, it is merged into
              ``self.model_loader_extra_config``, which ends up as a new dict.
              A JSON string value is parsed first, and a caller-supplied dict
              is copied rather than mutated.

        Returns:
            The combined ``EngineConfig``.

        Raises:
            ValueError: If bitsandbytes quantization, the bitsandbytes load
                format and a QLoRA adapter are not used consistently, or if
                a sliding-window model would run chunked prefill without the
                v2 block manager.
            AssertionError: If ``cpu_offload_gb`` is negative.
        """

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
            self.qlora_adapter_name_or_path is not None) and \
            self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes" or
            self.qlora_adapter_name_or_path is not None) and \
            self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        multimodal_config = MultiModalConfig()

        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            multimodal_config=multimodal_config)
        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.
            if use_long_context:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and not self.enable_prefix_caching):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_draft_tensor_parallel_size=(
                self.speculative_draft_tensor_parallel_size),
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=(
                self.speculative_disable_by_batch_size),
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=(
                self.spec_decoding_acceptance_method),
            typical_acceptance_sampler_posterior_threshold=(
                self.typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                self.typical_acceptance_sampler_posterior_alpha),
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            use_v2_block_manager=self.use_v2_block_manager,
            # Speculative decoding dictates its own lookahead slot count.
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
        )

        lora_config = None
        if self.enable_lora:
            lora_config = LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                long_lora_scaling_factors=self.long_lora_scaling_factors,
                lora_dtype=self.lora_dtype,
                # Non-positive / unset values mean "no CPU LoRA cache limit".
                max_cpu_loras=(self.max_cpu_loras if self.max_cpu_loras
                               and self.max_cpu_loras > 0 else None))

        if self.qlora_adapter_name_or_path:
            extra_config = self.model_loader_extra_config
            if not extra_config:
                extra_config = {}
            elif isinstance(extra_config, str):
                # --model-loader-extra-config arrives from the CLI as a JSON
                # string; parse it so the adapter path can be merged in
                # (item assignment on the raw string would raise TypeError).
                extra_config = json.loads(extra_config)
            else:
                # Copy so a dict supplied by the caller is not mutated.
                extra_config = dict(extra_config)
            extra_config["qlora_adapter_name_or_path"] = (
                self.qlora_adapter_name_or_path)
            self.model_loader_extra_config = extra_config

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )

        prompt_adapter_config = None
        if self.enable_prompt_adapter:
            prompt_adapter_config = PromptAdapterConfig(
                max_prompt_adapters=self.max_prompt_adapters,
                max_prompt_adapter_token=self.max_prompt_adapter_token)

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint)

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            multimodal_config=multimodal_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
        )
861
862


863
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # Start the LLM engine with Ray in a process separate from the server.
    engine_use_ray: bool = False
    # Turn off logging of incoming requests.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine command-line flags on ``parser``.

        Args:
            parser: The parser to add the flags to.
            async_args_only: If True, add only the two async-specific flags;
                otherwise add every ``EngineArgs`` flag first.

        Returns:
            The parser with the flags registered.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='Use Ray to start the LLM engine in a '
                            'separate process from the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return parser
882
883


884
885
886
887
888
889
890
891
892
893
894
895
class StoreBoolean(argparse.Action):
    """argparse action that turns a textual 'true'/'false' value into a bool.

    Matching is case-insensitive. Any other value raises ``ValueError``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        normalized = values.lower()
        if normalized not in ("true", "false"):
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")
        # Only "true" or "false" can reach this point.
        setattr(namespace, self.dest, normalized == "true")


896
897
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a new parser with every ``EngineArgs`` flag registered."""
    return EngineArgs.add_cli_args(FlexibleArgumentParser())
899
900
901


def _async_engine_args_parser():
    """Return a new parser with only the async-engine-specific flags."""
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                        async_args_only=True)