arg_utils.py 35.7 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
import warnings
5
from dataclasses import dataclass
6
from typing import List, Optional, Tuple, Union
7

8
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
9
10
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
11
                         TokenizerPoolConfig, VisionLanguageConfig)
12
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
13
from vllm.utils import str_to_int_tuple
14
15


16
17
18
19
20
21
def nullable_str(val: str) -> Optional[str]:
    """Argparse ``type`` converter for optional string options.

    An empty value or the literal string ``"None"`` is mapped to ``None``;
    any other value is passed through unchanged.
    """
    if val and val != "None":
        return val
    return None


22
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
23
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
24
    """Arguments for vLLM engine."""
25
    model: str
26
    served_model_name: Optional[Union[List[str]]] = None
27
    tokenizer: Optional[str] = None
28
    skip_tokenizer_init: bool = False
29
    tokenizer_mode: str = 'auto'
30
    trust_remote_code: bool = False
31
    download_dir: Optional[str] = None
32
    load_format: str = 'auto'
33
    dtype: str = 'auto'
34
    kv_cache_dtype: str = 'auto'
35
    quantization_param_path: Optional[str] = None
36
    seed: int = 0
37
    max_model_len: Optional[int] = None
38
    worker_use_ray: bool = False
39
    distributed_executor_backend: Optional[str] = None
40
41
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
42
    max_parallel_loading_workers: Optional[int] = None
43
    block_size: int = 16
44
    enable_prefix_caching: bool = False
45
    disable_sliding_window: bool = False
46
    use_v2_block_manager: bool = False
47
    swap_space: int = 4  # GiB
48
    gpu_memory_utilization: float = 0.90
49
    max_num_batched_tokens: Optional[int] = None
50
    max_num_seqs: int = 256
51
    max_logprobs: int = 5  # OpenAI default value
52
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
53
    revision: Optional[str] = None
54
    code_revision: Optional[str] = None
55
    rope_scaling: Optional[dict] = None
56
    tokenizer_revision: Optional[str] = None
57
    quantization: Optional[str] = None
58
    enforce_eager: bool = False
59
60
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
61
    disable_custom_all_reduce: bool = False
62
63
64
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
65
66
67
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
68
    fully_sharded_loras: bool = False
69
    lora_extra_vocab_size: int = 256
70
    long_lora_scaling_factors: Optional[Tuple[float]] = None
71
72
    lora_dtype = 'auto'
    max_cpu_loras: Optional[int] = None
73
    device: str = 'auto'
74
    ray_workers_use_nsight: bool = False
75
    num_gpu_blocks_override: Optional[int] = None
76
    num_lookahead_slots: int = 0
77
    model_loader_extra_config: Optional[dict] = None
78

79
80
81
82
83
    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
84
85
86
87
    image_processor: Optional[str] = None
    image_processor_revision: Optional[str] = None
    disable_image_processor: bool = False

88
    scheduler_delay_factor: float = 0.0
89
    enable_chunked_prefill: bool = False
90

91
    guided_decoding_backend: str = 'outlines'
92
93
94
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
95
    speculative_max_model_len: Optional[int] = None
96
    speculative_disable_by_batch_size: Optional[int] = None
97
98
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
99

100
101
    qlora_adapter_name_or_path: Optional[str] = None

102
    def __post_init__(self):
103
104
        if self.tokenizer is None:
            self.tokenizer = self.model
105

106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
    @staticmethod
    def add_cli_args_for_vlm(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the vision-language-model (e.g. LLaVA) CLI arguments.

        Args:
            parser: The parser to extend; it is mutated in place.

        Returns:
            The same ``parser`` object, to allow chaining.
        """
        image_input_types = [
            t.name.lower() for t in VisionLanguageConfig.ImageInputType
        ]
        parser.add_argument(
            '--image-input-type',
            type=nullable_str,
            default=EngineArgs.image_input_type,
            # ``nullable_str`` turns "" / "None" into None, and argparse
            # validates the *converted* value against ``choices``; None must
            # therefore be listed (same pattern as ``--quantization``).
            choices=[*image_input_types, None],
            help='The image input type passed into vLLM.')
        parser.add_argument('--image-token-id',
                            type=int,
                            default=EngineArgs.image_token_id,
                            help='Input id for image token.')
        parser.add_argument(
            '--image-input-shape',
            type=nullable_str,
            default=EngineArgs.image_input_shape,
            help='The biggest image input shape (worst for memory footprint) '
            'given an input type. Only used for vLLM\'s profile_run.')
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=EngineArgs.image_feature_size,
            help='The image feature size along the context dimension.')
        # Like --tokenizer / --tokenizer-revision, accept "None" to mean
        # "unspecified" instead of treating it as a repository name.
        parser.add_argument(
            '--image-processor',
            type=nullable_str,
            default=EngineArgs.image_processor,
            help='Name or path of the huggingface image processor to use. '
            'If unspecified, model name or path will be used.')
        parser.add_argument(
            '--image-processor-revision',
            type=nullable_str,
            default=EngineArgs.image_processor_revision,
            help='Revision of the huggingface image processor version to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
        parser.add_argument(
            '--disable-image-processor',
            action='store_true',
            help='Disables the use of image processor, even if one is defined '
            'for the model on huggingface.')

        return parser

153
154
    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the shared vLLM engine CLI arguments on ``parser``.

        Option names are intended to map (dashes -> underscores) onto the
        ``EngineArgs`` field names, which is what allows
        ``EngineArgs.from_cli_args`` to rebuild the dataclass from the parsed
        namespace. Where a field has a default, the option default is read
        from the ``EngineArgs`` class attribute so that CLI and programmatic
        defaults stay in sync.

        Args:
            parser: The parser to extend; it is mutated in place.

        Returns:
            The same ``parser`` object, to allow chaining.
        """

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='Name or path of the huggingface model to use.')
        parser.add_argument(
            '--tokenizer',
            type=nullable_str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer.')
        parser.add_argument(
            '--revision',
            type=nullable_str,
            default=EngineArgs.revision,
            help='The specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=nullable_str,
            default=EngineArgs.code_revision,
            help='The specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=nullable_str,
            default=EngineArgs.tokenizer_revision,
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='Trust remote code from huggingface.')
        parser.add_argument('--download-dir',
                            type=nullable_str,
                            default=EngineArgs.download_dir,
                            help='Directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface.')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
                'bitsandbytes'
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
        parser.add_argument(
            '--quantization-param-path',
            type=nullable_str,
            default=EngineArgs.quantization_param_path,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default=EngineArgs.guided_decoding_backend,
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')

        # Parallel arguments
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='Number of pipeline stages.')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='Number of tensor parallel replicas.')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='Load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models.')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile Ray workers.')

        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='Token block size for contiguous chunks of '
                            'tokens.')
        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching.')
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='Random seed for operations.')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU.')
        # NOTE: '%%' is required because argparse %-formats help strings.
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=EngineArgs.num_gpu_blocks_override,
            help='If specified, ignore GPU profiling result and use this '
            'number of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='Maximum number of sequences per iteration.')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('Max number of log probs to return when logprobs is '
                  'specified in SamplingParams.'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')

        # Quantization settings.
        # None is listed in choices because nullable_str converts "None"
        # to None before argparse validates the value against choices.
        parser.add_argument('--quantization',
                            '-q',
                            type=nullable_str,
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--rope-scaling',
                            default=EngineArgs.rope_scaling,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='Maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode. '
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
                            ')')
        parser.add_argument('--max-seq-len-to-capture',
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig.')

        # Tokenizer pool settings.
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=nullable_str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')

        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron", "cpu"],
                            help='Device type for vLLM execution.')

        # Related to Vision-language models such as llava
        parser = EngineArgs.add_cli_args_for_vlm(parser)

        # Scheduler settings.
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        # Speculative decoding settings.
        parser.add_argument(
            '--speculative-model',
            type=nullable_str,
            default=EngineArgs.speculative_model,
            help=
            'The name of the draft model to be used in speculative decoding.')
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')
        parser.add_argument(
            '--speculative-max-model-len',
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueued requests is larger than this value.')
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')
        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument('--model-loader-extra-config',
                            type=nullable_str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=EngineArgs.served_model_name,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Note that this name(s) "
            "will also be used in `model_name` tag content of "
            "prometheus metrics; if multiple names are provided, the "
            "metrics tag will take the first one.")
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=EngineArgs.qlora_adapter_name_or_path,
                            help='Name or path of the QLoRA adapter.')
        return parser
587
588

    @classmethod
589
    def from_cli_args(cls, args: argparse.Namespace):
590
591
592
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
Zhuohan Li's avatar
Zhuohan Li committed
593
594
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args
595

596
    def create_engine_config(self, ) -> EngineConfig:
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
            self.qlora_adapter_name_or_path is not None) and \
            self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes" or
            self.qlora_adapter_name_or_path is not None) and \
            self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

614
        device_config = DeviceConfig(self.device)
615
616
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
617
            self.trust_remote_code, self.dtype, self.seed, self.revision,
618
619
620
621
            self.code_revision, self.rope_scaling, self.tokenizer_revision,
            self.max_model_len, self.quantization,
            self.quantization_param_path, self.enforce_eager,
            self.max_context_len_to_capture, self.max_seq_len_to_capture,
622
623
            self.max_logprobs, self.disable_sliding_window,
            self.skip_tokenizer_init, self.served_model_name)
624
625
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
626
                                   self.swap_space, self.kv_cache_dtype,
627
                                   self.num_gpu_blocks_override,
628
629
                                   model_config.get_sliding_window(),
                                   self.enable_prefix_caching)
630
        parallel_config = ParallelConfig(
631
632
633
634
            self.pipeline_parallel_size,
            self.tensor_parallel_size,
            self.worker_use_ray,
            self.max_parallel_loading_workers,
635
636
637
638
639
            self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
640
641
642
            ),
            self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)
643
644
645
646
647
648
649

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            num_speculative_tokens=self.num_speculative_tokens,
650
651
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
652
653
654
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
655
656
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
657
658
        )

659
660
661
662
663
        scheduler_config = SchedulerConfig(
            self.max_num_batched_tokens,
            self.max_num_seqs,
            model_config.max_model_len,
            self.use_v2_block_manager,
664
665
666
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
667
668
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
669
            embedding_mode=model_config.embedding_mode,
670
        )
671
672
673
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
674
            fully_sharded_loras=self.fully_sharded_loras,
675
            lora_extra_vocab_size=self.lora_extra_vocab_size,
676
            long_lora_scaling_factors=self.long_lora_scaling_factors,
677
678
679
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
680

681
682
683
684
685
686
687
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

688
689
690
691
        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
692
693
        )

694
695
696
697
698
699
        if self.image_input_type:
            if (not self.image_token_id or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')
700
701
702
703
704
705
706
707
708
709
710
711
712

            if self.image_processor is None:
                self.image_processor = self.model
            if self.disable_image_processor:
                if self.image_processor != self.model:
                    warnings.warn(
                        "You've specified an image processor "
                        f"({self.image_processor}) but also disabled "
                        "it via `--disable-image-processor`.",
                        stacklevel=2)

                self.image_processor = None

713
714
715
716
717
718
            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
719
720
                image_processor=self.image_processor,
                image_processor_revision=self.image_processor_revision,
721
722
723
724
            )
        else:
            vision_language_config = None

725
726
727
        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

728
        if (model_config.get_sliding_window() is not None
729
730
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
731
            raise ValueError(
732
733
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")
734

735
736
737
738
739
740
741
        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            vision_language_config=vision_language_config,
742
                            speculative_config=speculative_config,
743
744
                            load_config=load_config,
                            decoding_config=decoding_config)
745
746


747
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends ``EngineArgs`` with three options that only the async engine
    reads; ``add_cli_args`` exposes them as command-line flags.
    """
    # Whether to start the LLM engine in a separate Ray process
    # (``--engine-use-ray``).
    engine_use_ray: bool = False
    # Whether to turn off request logging (``--disable-log-requests``).
    disable_log_requests: bool = False
    # Maximum number of prompt characters / prompt token ids printed in the
    # log; None means unlimited (``--max-log-len``).
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser,
                     async_args_only: bool = False) -> argparse.ArgumentParser:
        """Register the async-engine command-line flags on ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: When True, only the async-specific flags are
                added. Otherwise ``EngineArgs.add_cli_args`` runs first, and
                the parser it returns is the one that gets extended.

        Returns:
            The extended parser.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        # (flag, add_argument keyword options), registered in this order.
        async_flags = (
            ('--engine-use-ray',
             dict(action='store_true',
                  help='Use Ray to start the LLM engine in a '
                  'separate process as the server process.')),
            ('--disable-log-requests',
             dict(action='store_true', help='Disable logging requests.')),
            ('--max-log-len',
             dict(type=int,
                  default=None,
                  help='Max number of prompt characters or prompt '
                  'ID numbers being printed in log.'
                  '\n\nDefault: Unlimited')),
        )
        for flag, options in async_flags:
            parser.add_argument(flag, **options)
        return parser
773
774
775
776
777
778
779
780
781
782


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated by ``EngineArgs.add_cli_args``."""
    parser = argparse.ArgumentParser()
    return EngineArgs.add_cli_args(parser)


def _async_engine_args_parser():
    """Return a fresh parser carrying only the async-engine flags."""
    parser = argparse.ArgumentParser()
    return AsyncEngineArgs.add_cli_args(parser, async_args_only=True)
783
784
785
786


def _vlm_engine_args_parser():
    """Return a fresh parser populated by ``EngineArgs.add_cli_args_for_vlm``."""
    parser = argparse.ArgumentParser()
    return EngineArgs.add_cli_args_for_vlm(parser)