# arg_utils.py
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
6

7
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
8
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
9
                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
10
11
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig, TokenizerPoolConfig)
12
from vllm.executor.executor_base import ExecutorBase
13
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
14
from vllm.utils import FlexibleArgumentParser
15

16
17
18
19
if TYPE_CHECKING:
    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
        BaseTokenizerGroup)

20

21
22
23
24
25
26
def nullable_str(val: str):
    """Convert a CLI string to ``None`` when it denotes "no value".

    Intended for use as an argparse ``type=`` callable.

    Args:
        val: Raw string taken from the command line.

    Returns:
        ``None`` if ``val`` is falsy (e.g. the empty string) or is exactly
        the text ``"None"``; otherwise ``val`` unchanged.
    """
    # Guard on the "real value" case; everything else collapses to None.
    if val and val != "None":
        return val
    return None


27
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Each field is one engine option. ``add_cli_args`` registers a matching
    command-line flag for the options and reads most flag defaults from the
    defaults declared here; ``from_cli_args`` builds an instance by reading
    one attribute per field from the parsed namespace.

    Field order is significant: the dataclass-generated ``__init__`` accepts
    the fields positionally in declaration order.
    """
    # --- Model / tokenizer ---
    model: str
    # NOTE(review): Union with a single member is equivalent to List[str];
    # presumably a bare ``str`` was meant to be accepted as well -- confirm
    # against the consumer of this field before widening.
    served_model_name: Optional[Union[List[str]]] = None
    # Replaced by ``model`` in __post_init__ when left as None.
    tokenizer: Optional[str] = None
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None

    # --- Parallelism ---
    worker_use_ray: bool = False
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None

    # --- KV cache / memory ---
    block_size: int = 16
    enable_prefix_caching: bool = False
    disable_sliding_window: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    # The --cpu-offload-gb flag is parsed with type=float, so fractional
    # values are possible here.
    cpu_offload_gb: float = 0  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    # Deprecated per the CLI help text; use max_seq_len_to_capture instead.
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False

    # --- Tokenizer pool ---
    tokenizer_pool_size: int = 0
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None

    # --- LoRA / prompt adapters ---
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    # Variable-length: the CLI help describes "multiple scaling factors".
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None

    # --- Device / scheduling ---
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    model_loader_extra_config: Optional[dict] = None
    preemption_mode: Optional[str] = None

    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'

    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    speculative_draft_tensor_parallel_size: Optional[int] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None

    qlora_adapter_name_or_path: Optional[str] = None

    otlp_traces_endpoint: Optional[str] = None

116
    def __post_init__(self):
        """Fill in ``tokenizer`` from ``model`` when none was supplied.

        Runs automatically after the dataclass-generated ``__init__``.
        An explicitly provided tokenizer is left untouched.
        """
        if self.tokenizer is None:
            self.tokenizer = self.model
119
120

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one command-line flag per engine option on ``parser``.
        Most defaults are read from the corresponding ``EngineArgs`` field
        default so the CLI and programmatic construction stay in sync.

        Args:
            parser: The argument parser to extend. It is modified in place.

        Returns:
            The same ``parser`` object, to allow call chaining.
        """

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='Name or path of the huggingface model to use.')
        parser.add_argument(
            '--tokenizer',
            type=nullable_str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
        parser.add_argument(
            '--revision',
            type=nullable_str,
            default=None,
            help='The specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=nullable_str,
            default=None,
            help='The specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=nullable_str,
            default=None,
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='Trust remote code from huggingface.')
        parser.add_argument('--download-dir',
                            type=nullable_str,
                            default=EngineArgs.download_dir,
                            help='Directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface.')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
                'bitsandbytes'
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
        parser.add_argument(
            '--quantization-param-path',
            type=nullable_str,
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')

        # Parallel arguments
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='Number of pipeline stages.')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='Number of tensor parallel replicas.')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='Load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models.')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile Ray workers.')

        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='Token block size for contiguous chunks of '
                            'tokens.')

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching.')
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')

        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='Random seed for operations.')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU.')
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
            'which requires at least 26GB GPU memory. Note that this '
            'requires fast CPU-GPU interconnect, as part of the model is '
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            # '%%' is required: argparse %-formats help strings.
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number '
            'of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='Maximum number of sequences per iteration.')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('Max number of log probs to return when logprobs is '
                  'specified in SamplingParams.'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')

        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=nullable_str,
                            # None is listed because nullable_str can yield it.
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='Maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode. '
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
                            ')')
        parser.add_argument('--max-seq-len-to-capture',
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig.')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=nullable_str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')

        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=[
                                "auto", "cuda", "neuron", "cpu", "openvino",
                                "tpu", "xpu"
                            ],
                            help='Device type for vLLM execution.')

        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        # Speculative decoding arguments
        parser.add_argument(
            '--speculative-model',
            type=nullable_str,
            default=EngineArgs.speculative_model,
            help=
            'The name of the draft model to be used in speculative decoding.')
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')

        parser.add_argument(
            '--speculative-max-model-len',
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

        parser.add_argument('--model-loader-extra-config',
                            type=nullable_str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
        parser.add_argument(
            '--preemption-mode',
            type=str,
            default=None,
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by block '
            'swapping.')

        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Noted that this name(s) "
            "will also be used in `model_name` tag content of "
            "prometheus metrics, if multiple names provided, metrics "
            "tag will take the first one.")
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')

        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')

        return parser
637
638

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of ``cls`` from a parsed CLI namespace.

        Each dataclass field declared on ``cls`` is read from the attribute
        of the same name on ``args``; attributes of ``args`` that do not
        correspond to a field are ignored.

        Args:
            args: Namespace holding one attribute per dataclass field.

        Returns:
            A new ``cls`` instance populated from ``args``.

        Raises:
            AttributeError: If ``args`` lacks an attribute for some field.
        """
        field_values = {}
        for field in dataclasses.fields(cls):
            field_values[field.name] = getattr(args, field.name)
        return cls(**field_values)
645

646
    def create_engine_config(self, ) -> EngineConfig:
        """Validate these arguments and assemble the engine configuration.

        Builds each sub-config (model, cache, parallel, speculative,
        scheduler, LoRA, load, prompt adapter, decoding, observability) from
        the fields of this object and bundles them into an ``EngineConfig``.

        Side effect: when ``qlora_adapter_name_or_path`` is set,
        ``self.model_loader_extra_config`` is updated (and created or decoded
        from JSON if necessary) so that it carries the adapter path.

        Returns:
            The assembled ``EngineConfig``.

        Raises:
            ValueError: If the bitsandbytes quantization / load-format /
                QLoRA settings are inconsistent, if ``cpu_offload_gb`` is
                negative, if ``model_loader_extra_config`` is a string that
                does not decode to a JSON object while a QLoRA adapter is
                requested, or if chunked prefill is combined with a sliding
                window without the v2 block manager.
        """
        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        has_qlora_adapter = self.qlora_adapter_name_or_path is not None
        if ((self.quantization == "bitsandbytes" or has_qlora_adapter)
                and self.load_format != "bitsandbytes"):
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if ((self.load_format == "bitsandbytes" or has_qlora_adapter)
                and self.quantization != "bitsandbytes"):
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        # Raise instead of assert so the check survives ``python -O`` and
        # matches the other argument validations in this method.
        if self.cpu_offload_gb < 0:
            raise ValueError("CPU offload space must be non-negative"
                             f", but got {self.cpu_offload_gb}")

        multimodal_config = MultiModalConfig()

        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            multimodal_config=multimodal_config)
        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_draft_tensor_parallel_size=(
                self.speculative_draft_tensor_parallel_size),
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=(
                self.speculative_disable_by_batch_size),
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=(
                self.spec_decoding_acceptance_method),
            typical_acceptance_sampler_posterior_threshold=(
                self.typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                self.typical_acceptance_sampler_posterior_alpha),
        )

        # With speculative decoding configured, its lookahead-slot count
        # takes precedence over the one given in these arguments.
        if speculative_config is None:
            num_lookahead_slots = self.num_lookahead_slots
        else:
            num_lookahead_slots = speculative_config.num_lookahead_slots

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            use_v2_block_manager=self.use_v2_block_manager,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
        )

        lora_config = None
        if self.enable_lora:
            # A missing or non-positive CPU LoRA count is passed on as None.
            max_cpu_loras = (self.max_cpu_loras if self.max_cpu_loras
                             and self.max_cpu_loras > 0 else None)
            lora_config = LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                long_lora_scaling_factors=self.long_lora_scaling_factors,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=max_cpu_loras)

        if self.qlora_adapter_name_or_path:
            extra_config = self.model_loader_extra_config
            if extra_config is None:
                extra_config = {}
            elif isinstance(extra_config, str):
                # The CLI hands this option over as a raw JSON string, which
                # must be decoded before the adapter path can be added.
                extra_config = json.loads(extra_config)
                if not isinstance(extra_config, dict):
                    raise ValueError(
                        "model_loader_extra_config must be a JSON object "
                        "when a QLoRA adapter is specified, but got "
                        f"{type(extra_config).__name__}")
            extra_config["qlora_adapter_name_or_path"] = (
                self.qlora_adapter_name_or_path)
            self.model_loader_extra_config = extra_config

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )

        prompt_adapter_config = None
        if self.enable_prompt_adapter:
            prompt_adapter_config = PromptAdapterConfig(
                max_prompt_adapters=self.max_prompt_adapters,
                max_prompt_adapter_token=self.max_prompt_adapter_token)

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint)

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            multimodal_config=multimodal_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
        )
809
810


811
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # Set by the --engine-use-ray flag.
    engine_use_ray: bool = False
    # Set by the --disable-log-requests flag.
    disable_log_requests: bool = False
    # Set by --max-log-len; None is documented in the help as "Unlimited".
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine command-line options on ``parser``.

        Args:
            parser: Parser to extend.
            async_args_only: When False (the default) the options of
                ``EngineArgs`` are registered first; when True only the
                three async-specific options are added.

        Returns:
            The parser carrying the registered options.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        parser.add_argument(
            '--engine-use-ray',
            action='store_true',
            help=('Use Ray to start the LLM engine in a separate process '
                  'as the server process.'))
        parser.add_argument(
            '--disable-log-requests',
            action='store_true',
            help='Disable logging requests.')
        parser.add_argument(
            '--max-log-len',
            type=int,
            default=None,
            help=('Max number of prompt characters or prompt ID numbers '
                  'being printed in log.\n\nDefault: Unlimited'))
        return parser
837
838
839
840


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a new parser populated with the ``EngineArgs`` CLI options."""
    parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(parser)
842
843
844


def _async_engine_args_parser():
    """Return a new parser holding only the async-specific CLI options."""
    parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(parser, async_args_only=True)