arg_utils.py 33.7 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
from typing import List, Optional, Tuple, Union
6

7
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
8
9
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
10
                         TokenizerPoolConfig, VisionLanguageConfig)
11
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
12
from vllm.utils import str_to_int_tuple
13
14


15
16
17
18
19
20
def nullable_str(val: str):
    """Argparse ``type=`` converter that maps "no value" spellings to None.

    Returns None when ``val`` is falsy (e.g. the empty string) or is exactly
    the literal text "None"; otherwise returns ``val`` unchanged.
    """
    return val if val and val != "None" else None


21
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Every annotated attribute below is a dataclass field. ``from_cli_args``
    enumerates ``dataclasses.fields`` to copy values out of the parsed
    argparse namespace, so an attribute must carry a type annotation to be
    picked up from the command line.
    """
    model: str
    served_model_name: Optional[Union[List[str]]] = None
    tokenizer: Optional[str] = None
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    distributed_executor_backend: Optional[str] = None
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    disable_sliding_window: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    rope_scaling: Optional[dict] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    long_lora_scaling_factors: Optional[Tuple[float]] = None
    # Must be annotated: an unannotated assignment is only a class attribute,
    # is not a dataclass field, and would therefore be skipped when values
    # are copied from the CLI namespace (making --lora-dtype a no-op).
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    model_loader_extra_config: Optional[dict] = None

    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None

    qlora_adapter_name_or_path: Optional[str] = None

97
    def __post_init__(self):
98
99
        if self.tokenizer is None:
            self.tokenizer = self.model
100
101
102

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one flag per engine option on ``parser`` and returns the
        same parser so calls can be chained. Where a default is taken from
        ``EngineArgs.<name>``, it is the class-level default of the
        corresponding attribute.

        Args:
            parser: The argument parser to extend in place.

        Returns:
            The ``parser`` that was passed in.
        """

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='Name or path of the huggingface model to use.')
        parser.add_argument(
            '--tokenizer',
            type=nullable_str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
        parser.add_argument(
            '--revision',
            type=nullable_str,
            default=None,
            help='The specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=nullable_str,
            default=None,
            help='The specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=nullable_str,
            default=None,
            help='The specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='Trust remote code from huggingface.')
        parser.add_argument('--download-dir',
                            type=nullable_str,
                            default=EngineArgs.download_dir,
                            help='Directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface.')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
                'bitsandbytes'
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
        parser.add_argument(
            '--quantization-param-path',
            type=nullable_str,
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            # Read the default from the dataclass, like the other flags, so
            # the CLI and programmatic defaults cannot drift apart.
            default=EngineArgs.guided_decoding_backend,
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')

        # Parallel arguments
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='Number of pipeline stages.')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='Number of tensor parallel replicas.')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='Load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models.')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile Ray workers.')

        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='Token block size for contiguous chunks of '
                            'tokens.')

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching.')
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')

        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='Random seed for operations.')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU.')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            # '%%' is required: argparse %-formats help strings.
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this '
            'number of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='Maximum number of sequences per iteration.')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('Max number of log probs to return when logprobs is '
                  'specified in SamplingParams.'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')

        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=nullable_str,
                            # None is a valid choice because nullable_str
                            # converts the text "None" to None.
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='Maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode. '
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
                            ')')
        parser.add_argument('--max-seq-len-to-capture',
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig.')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=nullable_str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')

        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron", "cpu"],
                            help='Device type for vLLM execution.')

        # Related to Vision-language models such as llava
        parser.add_argument(
            '--image-input-type',
            type=nullable_str,
            default=None,
            choices=[
                t.name.lower() for t in VisionLanguageConfig.ImageInputType
            ],
            help=('The image input type passed into vLLM. '
                  'Should be one of "pixel_values" or "image_features".'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=None,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=nullable_str,
            default=None,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=None,
            help=('The image feature size along the context dimension.'))
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        # Speculative decoding arguments
        parser.add_argument(
            '--speculative-model',
            type=nullable_str,
            default=EngineArgs.speculative_model,
            help=
            'The name of the draft model to be used in speculative decoding.')

        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')

        parser.add_argument(
            '--speculative-max-model-len',
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument('--model-loader-extra-config',
                            type=nullable_str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')

        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Noted that this name(s) "
            "will also be used in `model_name` tag content of "
            "prometheus metrics, if multiple names provided, metrics "
            "tag will take the first one.")
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
        return parser
556
557

    @classmethod
558
    def from_cli_args(cls, args: argparse.Namespace):
559
560
561
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
Zhuohan Li's avatar
Zhuohan Li committed
562
563
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args
564

565
    def create_engine_config(self, ) -> EngineConfig:
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
            self.qlora_adapter_name_or_path is not None) and \
            self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes" or
            self.qlora_adapter_name_or_path is not None) and \
            self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

583
        device_config = DeviceConfig(self.device)
584
585
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
586
            self.trust_remote_code, self.dtype, self.seed, self.revision,
587
588
589
590
            self.code_revision, self.rope_scaling, self.tokenizer_revision,
            self.max_model_len, self.quantization,
            self.quantization_param_path, self.enforce_eager,
            self.max_context_len_to_capture, self.max_seq_len_to_capture,
591
592
            self.max_logprobs, self.disable_sliding_window,
            self.skip_tokenizer_init, self.served_model_name)
593
594
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
595
                                   self.swap_space, self.kv_cache_dtype,
596
                                   self.num_gpu_blocks_override,
597
598
                                   model_config.get_sliding_window(),
                                   self.enable_prefix_caching)
599
        parallel_config = ParallelConfig(
600
601
602
603
            self.pipeline_parallel_size,
            self.tensor_parallel_size,
            self.worker_use_ray,
            self.max_parallel_loading_workers,
604
605
606
607
608
            self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
609
610
611
            ),
            self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)
612
613
614
615
616
617
618

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            num_speculative_tokens=self.num_speculative_tokens,
619
620
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
621
622
623
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
624
625
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
626
627
        )

628
629
630
631
632
        scheduler_config = SchedulerConfig(
            self.max_num_batched_tokens,
            self.max_num_seqs,
            model_config.max_model_len,
            self.use_v2_block_manager,
633
634
635
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
636
637
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
638
            embedding_mode=model_config.embedding_mode,
639
        )
640
641
642
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
643
            fully_sharded_loras=self.fully_sharded_loras,
644
            lora_extra_vocab_size=self.lora_extra_vocab_size,
645
            long_lora_scaling_factors=self.long_lora_scaling_factors,
646
647
648
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
649

650
651
652
653
654
655
656
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

657
658
659
660
        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
661
662
        )

663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
        if self.image_input_type:
            if (not self.image_token_id or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')
            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
            )
        else:
            vision_language_config = None

679
680
681
        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

682
        if (model_config.get_sliding_window() is not None
683
684
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
685
            raise ValueError(
686
687
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")
688

689
690
691
692
693
694
695
        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            vision_language_config=vision_language_config,
696
                            speculative_config=speculative_config,
697
698
                            load_config=load_config,
                            decoding_config=decoding_config)
699
700


701
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Adds three fields on top of ``EngineArgs``; each one is exposed as a
    command-line flag by ``add_cli_args``:

    * ``engine_use_ray`` -- ``--engine-use-ray``
    * ``disable_log_requests`` -- ``--disable-log-requests``
    * ``max_log_len`` -- ``--max-log-len``
    """
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser,
                     async_args_only: bool = False) -> argparse.ArgumentParser:
        """Register the async-engine command-line flags on a parser.

        Args:
            parser: Parser that receives the flags.
            async_args_only: When true, only the three async-specific flags
                are added. Otherwise the parser is first passed through
                ``EngineArgs.add_cli_args`` and the parser it returns is the
                one that gets extended.

        Returns:
            The parser that the async-specific flags were added to.
        """
        if async_args_only:
            target = parser
        else:
            target = EngineArgs.add_cli_args(parser)

        # Flags are registered in table order, which is also the order in
        # which argparse lists them in its help output.
        async_flags = (
            ('--engine-use-ray',
             dict(action='store_true',
                  help='Use Ray to start the LLM engine in a '
                  'separate process as the server process.')),
            ('--disable-log-requests',
             dict(action='store_true',
                  help='Disable logging requests.')),
            ('--max-log-len',
             dict(type=int,
                  default=None,
                  help='Max number of prompt characters or prompt '
                  'ID numbers being printed in log.'
                  '\n\nDefault: Unlimited')),
        )
        for flag, options in async_flags:
            target.add_argument(flag, **options)
        return target
727
728
729
730
731
732
733
734
735
736


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a new ``ArgumentParser`` after ``EngineArgs.add_cli_args``.

    Used by sphinx to build the documentation.
    """
    blank_parser = argparse.ArgumentParser()
    return EngineArgs.add_cli_args(blank_parser)


def _async_engine_args_parser():
    """Return a new ``ArgumentParser`` after ``AsyncEngineArgs.add_cli_args``.

    The call passes ``async_args_only=True``. Used by sphinx to build the
    documentation.
    """
    blank_parser = argparse.ArgumentParser()
    return AsyncEngineArgs.add_cli_args(blank_parser, async_args_only=True)