arg_utils.py 37.1 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
import warnings
5
from dataclasses import dataclass
6
from typing import List, Optional, Tuple, Union
7

8
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
9
10
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
11
                         TokenizerPoolConfig, VisionLanguageConfig)
12
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
13
from vllm.utils import str_to_int_tuple
14
15


16
17
18
19
20
21
def nullable_str(val: str):
    """Map an empty string or the literal ``"None"`` to ``None``.

    Intended as an argparse ``type=`` converter so that passing ``""`` or
    ``None`` on the command line yields a real ``None``; any other value is
    returned unchanged.
    """
    treat_as_null = (not val) or (val == "None")
    return None if treat_as_null else val


@dataclass
class EngineArgs:
    """Arguments for vLLM engine."""
    # Name or path of the model to load.
    model: str
    # Name(s) the model is exposed under; the CLI supplies a list
    # (``nargs="+"``).
    served_model_name: Optional[Union[str, List[str]]] = None
    # Falls back to ``model`` when left as None.
    tokenizer: Optional[str] = None
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    # Deprecated in favour of ``distributed_executor_backend="ray"``.
    worker_use_ray: bool = False
    distributed_executor_backend: Optional[str] = None
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    disable_sliding_window: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    # Deprecated; superseded by ``max_seq_len_to_capture``.
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    # Tokenizer pool (0 means synchronous tokenization).
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    # NOTE(review): the CLI stores the raw JSON string here; confirm where it
    # is parsed into a dict.
    tokenizer_pool_extra_config: Optional[dict] = None
    # LoRA settings.
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    # Variable-length tuple of scaling factors.
    # NOTE(review): the CLI passes the raw string here; confirm conversion.
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    # NOTE(review): the CLI stores the raw JSON string here; confirm where it
    # is parsed into a dict.
    model_loader_extra_config: Optional[dict] = None
    preemption_mode: Optional[str] = None

    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    image_processor: Optional[str] = None
    image_processor_revision: Optional[str] = None
    disable_image_processor: bool = False

    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None

    qlora_adapter_name_or_path: Optional[str] = None

    def __post_init__(self):
105
106
        if self.tokenizer is None:
            self.tokenizer = self.model
107

108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
    @staticmethod
    def add_cli_args_for_vlm(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the vision-language-model options on ``parser``.

        Returns the same parser so the call can be chained.
        """
        # ``None`` must be listed among the choices: argparse validates the
        # value *after* ``nullable_str`` converts "None"/"" to ``None``, so
        # without it an explicit ``--image-input-type None`` is rejected.
        parser.add_argument('--image-input-type',
                            type=nullable_str,
                            default=None,
                            choices=[
                                *(t.name.lower() for t in
                                  VisionLanguageConfig.ImageInputType),
                                None,
                            ],
                            help=('The image input type passed into vLLM.'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=None,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=nullable_str,
            default=None,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=None,
            help=('The image feature size along the context dimension.'))
        # nullable_str matches --tokenizer / --tokenizer-revision, so that
        # "None" falls back to the default instead of being used as a name.
        parser.add_argument(
            '--image-processor',
            type=nullable_str,
            default=EngineArgs.image_processor,
            help='Name or path of the huggingface image processor to use. '
            'If unspecified, model name or path will be used.')
        parser.add_argument(
            '--image-processor-revision',
            type=nullable_str,
            default=EngineArgs.image_processor_revision,
            help='Revision of the huggingface image processor version to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
        parser.add_argument(
            '--disable-image-processor',
            action='store_true',
            help='Disables the use of image processor, even if one is defined '
            'for the model on huggingface.')

        return parser

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers the engine's command-line options on ``parser`` (the
        vision-language options are added via
        ``EngineArgs.add_cli_args_for_vlm``) and returns the same parser.
        """

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='Name or path of the huggingface model to use.')
        parser.add_argument(
            '--tokenizer',
            type=nullable_str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
        parser.add_argument(
            '--revision',
            type=nullable_str,
            default=None,
            help='The specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=nullable_str,
            default=None,
            help='The specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=nullable_str,
            default=None,
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='Trust remote code from huggingface.')
        parser.add_argument('--download-dir',
                            type=nullable_str,
                            default=EngineArgs.download_dir,
                            help='Directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface.')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
                'bitsandbytes'
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3).')
        parser.add_argument(
            '--quantization-param-path',
            type=nullable_str,
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default=EngineArgs.guided_decoding_backend,
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')

        # Parallel arguments
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='Number of pipeline stages.')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='Number of tensor parallel replicas.')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='Load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models.')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile Ray workers.')

        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='Token block size for contiguous chunks of '
                            'tokens.')
        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching.')
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')

        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='Random seed for operations.')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU.')
        # '%%' is required: argparse %-formats help strings.
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number '
            'of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='Maximum number of sequences per iteration.')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('Max number of log probs to return when logprobs is '
                  'specified in SamplingParams.'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')

        # Quantization settings.
        # None is listed among the choices so that nullable_str's "None" ->
        # None conversion passes argparse's choices validation.
        parser.add_argument('--quantization',
                            '-q',
                            type=nullable_str,
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='Maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode. '
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
                            ')')
        parser.add_argument('--max-seq-len-to-capture',
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig.')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=nullable_str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')

        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
        parser.add_argument(
            "--device",
            type=str,
            default=EngineArgs.device,
            choices=["auto", "cuda", "neuron", "cpu", "tpu", "xpu"],
            help='Device type for vLLM execution.')

        # Related to Vision-language models such as llava
        parser = EngineArgs.add_cli_args_for_vlm(parser)

        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        # Speculative decoding arguments
        parser.add_argument(
            '--speculative-model',
            type=nullable_str,
            default=EngineArgs.speculative_model,
            help=
            'The name of the draft model to be used in speculative decoding.')
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')
        parser.add_argument(
            '--speculative-max-model-len',
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')
        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument('--model-loader-extra-config',
                            type=nullable_str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
        # The underscore spelling is kept for backward compatibility; the
        # dashed alias matches every other flag. argparse derives dest from
        # the first long option, so it stays ``preemption_mode``.
        parser.add_argument(
            '--preemption_mode',
            '--preemption-mode',
            type=str,
            default=EngineArgs.preemption_mode,
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Note that this name(s) "
            "will also be used in `model_name` tag content of "
            "prometheus metrics, if multiple names provided, metrics "
            "tag will take the first one.")
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
        return parser

    @classmethod
605
    def from_cli_args(cls, args: argparse.Namespace):
606
607
608
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
Zhuohan Li's avatar
Zhuohan Li committed
609
610
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args
611

612
    def create_engine_config(self) -> EngineConfig:
        """Validate these arguments and assemble the full ``EngineConfig``.

        Builds the device, model, cache, parallel, speculative, scheduler,
        LoRA, load, vision-language and decoding configs from the fields of
        this object and bundles them into one ``EngineConfig``.

        Side effects:
            * When a QLoRA adapter is given, ``self.model_loader_extra_config``
              is replaced by a new dict that also carries the adapter path.
            * When image input is enabled, ``self.image_processor`` may be
              defaulted to ``self.model`` or cleared to ``None``.

        Raises:
            ValueError: if bitsandbytes quantization, the bitsandbytes load
                format and a QLoRA adapter are not used consistently; if
                ``image_input_type`` is set without ``image_token_id``,
                ``image_input_shape`` and ``image_feature_size``; or if
                chunked prefill is combined with a sliding-window model
                while the v2 block manager is disabled.
        """
        # Treat an empty adapter path exactly like "no adapter", both for the
        # bitsandbytes consistency checks and for forwarding the adapter to
        # the model loader below (previously "" tripped the checks but was
        # then silently ignored).
        has_qlora_adapter = bool(self.qlora_adapter_name_or_path)

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes"
                or has_qlora_adapter) and self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes"
                or has_qlora_adapter) and self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name)
        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching)
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        # May be None; the scheduler below falls back to
        # self.num_lookahead_slots in that case.
        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
        )

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            use_v2_block_manager=self.use_v2_block_manager,
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
        )

        lora_config = None
        if self.enable_lora:
            # Non-positive (or unset) max_cpu_loras is normalized to None.
            cpu_loras = (self.max_cpu_loras if self.max_cpu_loras
                         and self.max_cpu_loras > 0 else None)
            lora_config = LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                long_lora_scaling_factors=self.long_lora_scaling_factors,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=cpu_loras)

        if has_qlora_adapter:
            # The extra config may be a dict or a JSON string (presumably the
            # raw value from the CLI -- confirm against add_cli_args). Item
            # assignment on a string raised TypeError, so parse it first. Build
            # a new dict instead of mutating one the caller passed in.
            extra_config = self.model_loader_extra_config or {}
            if isinstance(extra_config, str):
                extra_config = json.loads(extra_config)
            self.model_loader_extra_config = {
                **extra_config,
                "qlora_adapter_name_or_path": self.qlora_adapter_name_or_path,
            }

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )

        vision_language_config = None
        if self.image_input_type:
            # `is None` rather than truthiness: 0 is a valid token id.
            if (self.image_token_id is None or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')

            # Default the image processor to the model's own.
            if self.image_processor is None:
                self.image_processor = self.model
            if self.disable_image_processor:
                if self.image_processor != self.model:
                    warnings.warn(
                        "You've specified an image processor "
                        f"({self.image_processor}) but also disabled "
                        "it via `--disable-image-processor`.",
                        stacklevel=2)

                self.image_processor = None

            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
                image_processor=self.image_processor,
                image_processor_revision=self.image_processor_revision,
            )

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        # Sliding window + chunked prefill is only allowed with the v2 block
        # manager.
        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            vision_language_config=vision_language_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config)


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments plus the options specific to the async engine.

    On top of every ``EngineArgs`` field, this adds switches for running the
    engine via Ray in a separate process, turning off request logging, and
    capping how much of each prompt is logged.
    """
    # Start the LLM engine with Ray in a process separate from the server.
    engine_use_ray: bool = False
    # When True, incoming requests are not logged.
    disable_log_requests: bool = False
    # Max prompt characters / prompt token ids to log; None means unlimited.
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser,
                     async_args_only: bool = False) -> argparse.ArgumentParser:
        """Register the async-engine options on ``parser`` and return it.

        Unless ``async_args_only`` is set, all ``EngineArgs`` options are
        registered first via ``EngineArgs.add_cli_args``.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        # On/off switches: passing the flag sets the field to True.
        boolean_flags = (
            ('--engine-use-ray',
             'Use Ray to start the LLM engine in a '
             'separate process as the server process.'),
            ('--disable-log-requests', 'Disable logging requests.'),
        )
        for flag, help_text in boolean_flags:
            parser.add_argument(flag, action='store_true', help=help_text)

        parser.add_argument('--max-log-len',
                            type=int,
                            default=None,
                            help='Max number of prompt characters or prompt '
                            'ID numbers being printed in log.'
                            '\n\nDefault: Unlimited')
        return parser


# The helpers below are used by Sphinx to build the documentation.
def _engine_args_parser():
    """Return a fresh parser carrying every ``EngineArgs`` option."""
    parser = argparse.ArgumentParser()
    return EngineArgs.add_cli_args(parser)


def _async_engine_args_parser():
    """Return a fresh parser carrying only the async-engine options."""
    parser = argparse.ArgumentParser()
    return AsyncEngineArgs.add_cli_args(parser, async_args_only=True)


def _vlm_engine_args_parser():
    """Return a fresh parser with the options from ``add_cli_args_for_vlm``."""
    parser = argparse.ArgumentParser()
    return EngineArgs.add_cli_args_for_vlm(parser)