arg_utils.py 38.2 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
import warnings
5
from dataclasses import dataclass
6
from typing import List, Optional, Tuple, Union
7

8
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
9
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
10
11
12
                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
                         SpeculativeConfig, TokenizerPoolConfig,
                         VisionLanguageConfig)
13
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
14
from vllm.utils import FlexibleArgumentParser, str_to_int_tuple
15
16


17
18
19
20
21
22
def nullable_str(val: str):
    """Argparse ``type=`` converter that lets a string option be unset.

    An empty value or the literal text ``"None"`` is mapped to ``None``;
    any other value is passed through unchanged.
    """
    if val and val != "None":
        return val
    return None


23
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
24
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
25
    """Arguments for vLLM engine."""
26
    model: str
27
    served_model_name: Optional[Union[List[str]]] = None
28
    tokenizer: Optional[str] = None
29
    skip_tokenizer_init: bool = False
30
    tokenizer_mode: str = 'auto'
31
    trust_remote_code: bool = False
32
    download_dir: Optional[str] = None
33
    load_format: str = 'auto'
34
    dtype: str = 'auto'
35
    kv_cache_dtype: str = 'auto'
36
    quantization_param_path: Optional[str] = None
37
    seed: int = 0
38
    max_model_len: Optional[int] = None
39
    worker_use_ray: bool = False
40
    distributed_executor_backend: Optional[str] = None
41
42
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
43
    max_parallel_loading_workers: Optional[int] = None
44
    block_size: int = 16
45
    enable_prefix_caching: bool = False
46
    disable_sliding_window: bool = False
47
    use_v2_block_manager: bool = False
48
    swap_space: int = 4  # GiB
49
    gpu_memory_utilization: float = 0.90
50
    max_num_batched_tokens: Optional[int] = None
51
    max_num_seqs: int = 256
52
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
53
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
54
    revision: Optional[str] = None
55
    code_revision: Optional[str] = None
56
    rope_scaling: Optional[dict] = None
57
    rope_theta: Optional[float] = None
58
    tokenizer_revision: Optional[str] = None
59
    quantization: Optional[str] = None
60
    enforce_eager: bool = False
61
62
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
63
    disable_custom_all_reduce: bool = False
64
65
66
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
67
68
69
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
70
    fully_sharded_loras: bool = False
71
    lora_extra_vocab_size: int = 256
72
    long_lora_scaling_factors: Optional[Tuple[float]] = None
73
    lora_dtype: str = 'auto'
74
    max_cpu_loras: Optional[int] = None
75
    device: str = 'auto'
76
    ray_workers_use_nsight: bool = False
77
    num_gpu_blocks_override: Optional[int] = None
78
    num_lookahead_slots: int = 0
79
    model_loader_extra_config: Optional[dict] = None
80
    preemption_mode: Optional[str] = None
81

82
83
84
85
86
    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
87
88
89
90
    image_processor: Optional[str] = None
    image_processor_revision: Optional[str] = None
    disable_image_processor: bool = False

91
    scheduler_delay_factor: float = 0.0
92
    enable_chunked_prefill: bool = False
93

94
    guided_decoding_backend: str = 'outlines'
95
96
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
97
    speculative_draft_tensor_parallel_size: Optional[int] = None
98
    num_speculative_tokens: Optional[int] = None
99
    speculative_max_model_len: Optional[int] = None
100
    speculative_disable_by_batch_size: Optional[int] = None
101
102
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
103

104
105
    qlora_adapter_name_or_path: Optional[str] = None

106
107
    otlp_traces_endpoint: Optional[str] = None

108
    def __post_init__(self):
109
110
        if self.tokenizer is None:
            self.tokenizer = self.model
111

112
113
    @staticmethod
    def add_cli_args_for_vlm(
            parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Register the vision-language model (e.g. LLaVA) CLI options.

        Args:
            parser: The parser to extend in place.

        Returns:
            The same ``parser`` instance, to allow call chaining.
        """
        parser.add_argument(
            '--image-input-type',
            type=nullable_str,
            default=EngineArgs.image_input_type,
            # nullable_str maps "None"/"" to None, so None must be an
            # accepted choice or argparse rejects it (same as --quantization).
            choices=[
                t.name.lower() for t in VisionLanguageConfig.ImageInputType
            ] + [None],
            help=('The image input type passed into vLLM.'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=EngineArgs.image_token_id,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=nullable_str,
            default=EngineArgs.image_input_shape,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=EngineArgs.image_feature_size,
            help=('The image feature size along the context dimension.'))
        # nullable_str (as for --tokenizer) so "None"/"" mean "unspecified"
        # instead of being used as a literal repository name.
        parser.add_argument(
            '--image-processor',
            type=nullable_str,
            default=EngineArgs.image_processor,
            help='Name or path of the huggingface image processor to use. '
            'If unspecified, model name or path will be used.')
        parser.add_argument(
            '--image-processor-revision',
            type=nullable_str,
            default=EngineArgs.image_processor_revision,
            help='Revision of the huggingface image processor version to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
        parser.add_argument(
            '--disable-image-processor',
            action='store_true',
            help='Disables the use of image processor, even if one is defined '
            'for the model on huggingface.')

        return parser

159
    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Adds one option per ``EngineArgs`` field to ``parser`` (the
        vision-language options are delegated to ``add_cli_args_for_vlm``),
        so that ``from_cli_args`` can read every field back by name. Where
        possible, defaults are taken from the ``EngineArgs`` class
        attributes to keep the CLI and the Python API in sync.

        Args:
            parser: The parser to extend in place.

        Returns:
            The same ``parser`` instance, to allow call chaining.
        """

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='Name or path of the huggingface model to use.')
        parser.add_argument(
            '--tokenizer',
            type=nullable_str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
        parser.add_argument(
            '--revision',
            type=nullable_str,
            default=EngineArgs.revision,
            help='The specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=nullable_str,
            default=EngineArgs.code_revision,
            help='The specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=nullable_str,
            default=EngineArgs.tokenizer_revision,
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='Trust remote code from huggingface.')
        parser.add_argument('--download-dir',
                            type=nullable_str,
                            default=EngineArgs.download_dir,
                            help='Directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface.')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
                'bitsandbytes'
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3).')
        parser.add_argument(
            '--quantization-param-path',
            type=nullable_str,
            default=EngineArgs.quantization_param_path,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default=EngineArgs.guided_decoding_backend,
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')

        # Parallel arguments
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='Number of pipeline stages.')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='Number of tensor parallel replicas.')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='Load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models.')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile Ray workers.')

        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='Token block size for contiguous chunks of '
                            'tokens.')
        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching.')
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='Random seed for operations.')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU.')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            # '%%' is required because argparse %-formats help strings.
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=EngineArgs.num_gpu_blocks_override,
            help='If specified, ignore GPU profiling result and use this '
            'number of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='Maximum number of sequences per iteration.')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('Max number of log probs to return when logprobs is '
                  'specified in SamplingParams.'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')

        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=nullable_str,
                            # None is listed because nullable_str maps
                            # "None"/"" to None.
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--rope-scaling',
                            default=EngineArgs.rope_scaling,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--rope-theta',
                            default=EngineArgs.rope_theta,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='Maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode. '
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
                            ')')
        parser.add_argument('--max-seq-len-to-capture',
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig.')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=nullable_str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')

        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=[
                                "auto", "cuda", "neuron", "cpu", "openvino",
                                "tpu", "xpu"
                            ],
                            help='Device type for vLLM execution.')

        # Related to Vision-language models such as llava
        parser = EngineArgs.add_cli_args_for_vlm(parser)

        # Scheduler arguments
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        # Speculative decoding arguments
        parser.add_argument(
            '--speculative-model',
            type=nullable_str,
            default=EngineArgs.speculative_model,
            help=
            'The name of the draft model to be used in speculative decoding.')
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
        parser.add_argument(
            '--speculative-max-model-len',
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueued requests is larger than this value.')
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')
        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

        # Miscellaneous
        parser.add_argument('--model-loader-extra-config',
                            type=nullable_str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
        parser.add_argument(
            '--preemption-mode',
            type=str,
            default=EngineArgs.preemption_mode,
            help='If \'recompute\', the engine performs preemption by '
            'recomputation; if \'swap\', the engine performs preemption by '
            'block swapping.')
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=EngineArgs.served_model_name,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Note that this name(s) "
            "will also be used in `model_name` tag content of "
            "prometheus metrics, if multiple names provided, metrics "
            "tag will take the first one.")
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=EngineArgs.qlora_adapter_name_or_path,
                            help='Name or path of the QLoRA adapter.')
        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=EngineArgs.otlp_traces_endpoint,
            help='Target URL to which OpenTelemetry traces will be sent.')

        return parser
622
623

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from a parsed argparse namespace.

        Each dataclass field declared on ``cls`` is read from ``args`` by the
        same name, so ``args`` must carry an attribute for every field
        (a missing one raises ``AttributeError`` from ``getattr``).
        """
        field_names = (field.name for field in dataclasses.fields(cls))
        kwargs = {name: getattr(args, name) for name in field_names}
        return cls(**kwargs)

    def create_engine_config(self) -> EngineConfig:
        """Translate these engine arguments into a combined ``EngineConfig``.

        Builds the device, model, cache, parallel, speculative, scheduler,
        LoRA, load, decoding and observability configs, plus a
        vision-language config when ``image_input_type`` is set, and bundles
        them into one ``EngineConfig``.

        Side effects: may set ``self.image_processor`` and rebind
        ``self.model_loader_extra_config`` on this instance.

        Raises:
            ValueError: if the bitsandbytes quantization / load format /
                QLoRA adapter settings disagree, if the image input options
                are only partially specified, or if chunked prefill is
                enabled for a sliding-window model without the v2 block
                manager.
        """
        # A None or empty adapter name means "no QLoRA adapter". Use one
        # predicate so validation and loader-config injection agree.
        has_qlora_adapter = (self.qlora_adapter_name_or_path is not None
                             and self.qlora_adapter_name_or_path != "")

        # bitsandbytes quantization needs a specific model loader, so the
        # quant method and the load format must be consistent; a QLoRA
        # adapter requires both of them to be "bitsandbytes".
        if (self.quantization == "bitsandbytes"
                or has_qlora_adapter) and self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes"
                or has_qlora_adapter) and self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        if self.image_input_type:
            if (not self.image_token_id or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')

            # Default the image processor to the model itself.
            if self.image_processor is None:
                self.image_processor = self.model
            if self.disable_image_processor:
                if self.image_processor != self.model:
                    warnings.warn(
                        "You've specified an image processor "
                        f"({self.image_processor}) but also disabled "
                        "it via `--disable-image-processor`.",
                        stacklevel=2)

                self.image_processor = None

            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
                image_processor=self.image_processor,
                image_processor_revision=self.image_processor_revision,
            )
        else:
            vision_language_config = None

        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            multimodal_config=vision_language_config)

        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching)

        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_draft_tensor_parallel_size=(
                self.speculative_draft_tensor_parallel_size),
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=(
                self.speculative_disable_by_batch_size),
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
        )

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            use_v2_block_manager=self.use_v2_block_manager,
            # With speculative decoding enabled, its lookahead slot count
            # takes precedence over the user-supplied value.
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
        )

        lora_config = None
        if self.enable_lora:
            lora_config = LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                long_lora_scaling_factors=self.long_lora_scaling_factors,
                lora_dtype=self.lora_dtype,
                # Unset (None/0) or negative values are passed on as None.
                max_cpu_loras=(self.max_cpu_loras if self.max_cpu_loras
                               and self.max_cpu_loras > 0 else None))

        if has_qlora_adapter:
            # Build a new dict instead of mutating in place, so a dict the
            # caller passed in is not modified behind their back.
            self.model_loader_extra_config = {
                **(self.model_loader_extra_config or {}),
                "qlora_adapter_name_or_path": self.qlora_adapter_name_or_path,
            }

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint)

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            vision_language_config=vision_language_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
        )


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments plus the options used by the asynchronous engine."""
    # Exposed on the CLI as --engine-use-ray (start the engine via Ray).
    engine_use_ray: bool = False
    # Exposed on the CLI as --disable-log-requests.
    disable_log_requests: bool = False
    # Limit on prompt characters / prompt IDs printed in logs; None means
    # unlimited.
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI options on ``parser`` and return it.

        Unless ``async_args_only`` is true, the shared ``EngineArgs`` options
        are registered first; the async-specific ones are always added.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        # (flag, add_argument keyword options), registered in this order.
        async_options = (
            ('--engine-use-ray',
             dict(action='store_true',
                  help='Use Ray to start the LLM engine in a '
                  'separate process as the server process.')),
            ('--disable-log-requests',
             dict(action='store_true', help='Disable logging requests.')),
            ('--max-log-len',
             dict(type=int,
                  default=None,
                  help='Max number of prompt characters or prompt '
                  'ID numbers being printed in log.'
                  '\n\nDefault: Unlimited')),
        )
        for flag, options in async_options:
            parser.add_argument(flag, **options)
        return parser


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with every ``EngineArgs`` option."""
    fresh_parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(fresh_parser)


def _async_engine_args_parser():
    """Return a fresh parser holding only the async-engine-specific options."""
    fresh_parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(fresh_parser, async_args_only=True)


def _vlm_engine_args_parser():
    """Return a fresh parser populated via ``EngineArgs.add_cli_args_for_vlm``."""
    fresh_parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args_for_vlm(fresh_parser)