arg_utils.py 29.5 KB
Newer Older
1
import argparse
2
3
import dataclasses
from dataclasses import dataclass
4
from typing import List, Optional, Union
5

6
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
7
8
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
9
                         TokenizerPoolConfig, VisionLanguageConfig)
10
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
11
from vllm.utils import str_to_int_tuple
12
13


14
15
16
17
18
19
def nullable_str(val: str):
    """Argparse ``type=`` converter for optional string flags.

    A falsy value (e.g. the empty string) or the literal text ``"None"`` is
    mapped to ``None`` so the flag can be explicitly unset from the command
    line; any other value is returned unchanged.
    """
    if val and val != "None":
        return val
    return None


20
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Each field is exposed as a command-line flag by :meth:`add_cli_args` and
    read back by :meth:`from_cli_args`, which copies only the attributes
    listed by ``dataclasses.fields``. Every option therefore needs a type
    annotation: an un-annotated class attribute is not a dataclass field, so
    its CLI flag would be parsed but silently dropped.
    """
    model: str
    # The CLI (--served-model-name, nargs="+") always supplies a list.
    served_model_name: Optional[Union[str, List[str]]] = None
    # Falls back to `model` in __post_init__.
    tokenizer: Optional[str] = None
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    # Deprecated in favour of max_seq_len_to_capture.
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    # Tokenizer pool; a size of 0 means synchronous tokenization. The extra
    # config arrives from the CLI as a JSON string.
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    tokenizer_pool_extra_config: Optional[Union[str, dict]] = None
    # LoRA related configs.
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    # Annotated so it is a real dataclass field; without the annotation
    # from_cli_args never forwarded the --lora-dtype value.
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    # Arrives from the CLI as a JSON string.
    model_loader_extra_config: Optional[Union[str, dict]] = None

    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None

    def __post_init__(self):
        # Use the model's own tokenizer unless one was given explicitly.
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one flag per engine option on ``parser`` and returns it.
        Flag defaults are taken from the dataclass defaults where one exists.
        """

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='Name or path of the huggingface model to use.')
        parser.add_argument(
            '--tokenizer',
            type=nullable_str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
        parser.add_argument(
            '--revision',
            type=nullable_str,
            default=EngineArgs.revision,
            help='The specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=nullable_str,
            default=EngineArgs.code_revision,
            help='The specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=nullable_str,
            default=EngineArgs.tokenizer_revision,
            help='The specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='Trust remote code from huggingface.')
        parser.add_argument('--download-dir',
                            type=nullable_str,
                            default=EngineArgs.download_dir,
                            help='Directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface.')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave which assumes tensorizer_uri is set to the location of '
            'the serialized weights.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument(
            '--quantization-param-path',
            type=nullable_str,
            default=EngineArgs.quantization_param_path,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default=EngineArgs.guided_decoding_backend,
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='Use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='Number of pipeline stages.')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='Number of tensor parallel replicas.')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='Load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models.')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile Ray workers.')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='Token block size for contiguous chunks of '
                            'tokens.')

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching.')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')

        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='Random seed for operations.')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU.')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            # '%%' is required: argparse %-formats help strings.
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=EngineArgs.num_gpu_blocks_override,
            help='If specified, ignore GPU profiling result and use this number '
            'of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='Maximum number of sequences per iteration.')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('Max number of log probs to return when logprobs is '
                  'specified in SamplingParams.'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=nullable_str,
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='Maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode. '
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
                            ')')
        # The dash-only spelling is canonical; the original mixed spelling is
        # kept as an alias. Both map to dest `max_seq_len_to_capture`.
        parser.add_argument('--max-seq-len-to-capture',
                            '--max-seq_len-to-capture',
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig.')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=nullable_str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron", "cpu"],
                            help='Device type for vLLM execution.')
        # Related to Vision-language models such as llava
        parser.add_argument(
            '--image-input-type',
            type=nullable_str,
            default=EngineArgs.image_input_type,
            choices=[
                t.name.lower() for t in VisionLanguageConfig.ImageInputType
            ],
            help=('The image input type passed into vLLM. '
                  'Should be one of "pixel_values" or "image_features".'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=EngineArgs.image_token_id,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=nullable_str,
            default=EngineArgs.image_input_shape,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=EngineArgs.image_feature_size,
            help=('The image feature size along the context dimension.'))
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        parser.add_argument(
            '--speculative-model',
            type=nullable_str,
            default=EngineArgs.speculative_model,
            help=
            'The name of the draft model to be used in speculative decoding.')

        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')

        parser.add_argument(
            '--speculative-max-model-len',
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument('--model-loader-extra-config',
                            type=nullable_str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')

        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=EngineArgs.served_model_name,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Note that this name(s) "
            "will also be used in `model_name` tag content of "
            "prometheus metrics, if multiple names provided, metrics "
            "tag will take the first one.")

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an instance from a parsed argparse namespace.

        Only attributes that are dataclass fields of ``cls`` are read, so a
        subclass automatically picks up its own extra fields. Raises
        ``AttributeError`` if ``args`` lacks one of them.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_config(self) -> 'EngineConfig':
        """Translate these arguments into the per-subsystem engine configs.

        Returns:
            An ``EngineConfig`` bundling the model, cache, parallel,
            scheduler, device, LoRA, vision-language, speculative, load and
            decoding configs. The LoRA config is ``None`` unless
            ``enable_lora`` is set; the vision-language config is ``None``
            unless ``image_input_type`` is set.

        Raises:
            ValueError: if ``image_input_type`` is given without
                ``image_token_id``, ``image_input_shape`` and
                ``image_feature_size``.
        """
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.dtype, self.seed, self.revision,
            self.code_revision, self.tokenizer_revision, self.max_model_len,
            self.quantization, self.quantization_param_path,
            self.enforce_eager, self.max_context_len_to_capture,
            self.max_seq_len_to_capture, self.max_logprobs,
            self.skip_tokenizer_init, self.served_model_name)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
                                   self.num_gpu_blocks_override,
                                   model_config.get_sliding_window(),
                                   self.enable_prefix_caching)
        parallel_config = ParallelConfig(
            self.pipeline_parallel_size, self.tensor_parallel_size,
            self.worker_use_ray, self.max_parallel_loading_workers,
            self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ), self.ray_workers_use_nsight)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
        )

        scheduler_config = SchedulerConfig(
            self.max_num_batched_tokens,
            self.max_num_seqs,
            model_config.max_model_len,
            self.use_v2_block_manager,
            # When speculative decoding is configured, its lookahead slot
            # count takes precedence over the CLI value.
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
        )

        # A non-positive (or unset) max_cpu_loras is passed on as None.
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )

        if self.image_input_type:
            # Token id 0 is a legitimate id, so check for absence explicitly
            # rather than relying on truthiness.
            if (self.image_token_id is None or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')
            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
            )
        else:
            vision_language_config = None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            vision_language_config=vision_language_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config)
612
613


614
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments extended with options for the asynchronous engine.

    Adds three fields on top of ``EngineArgs``; each has a matching flag
    registered by ``add_cli_args``.
    """
    # Use Ray to start the LLM engine in a separate process (--engine-use-ray).
    engine_use_ray: bool = False
    # Turn off request logging (--disable-log-requests).
    disable_log_requests: bool = False
    # Limit on prompt characters / prompt ID numbers printed in the log;
    # None means unlimited (--max-log-len).
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser,
                     async_args_only: bool = False) -> argparse.ArgumentParser:
        """Register the async-engine CLI flags on ``parser``.

        Args:
            parser: Parser to extend.
            async_args_only: When True, skip the shared ``EngineArgs`` flags
                and register only the three async-specific ones.

        Returns:
            The parser with the flags added.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        # (flag, add_argument keyword arguments), registered in this order.
        async_flags = (
            ('--engine-use-ray',
             dict(action='store_true',
                  help='Use Ray to start the LLM engine in a '
                  'separate process as the server process.')),
            ('--disable-log-requests',
             dict(action='store_true', help='Disable logging requests.')),
            ('--max-log-len',
             dict(type=int,
                  default=None,
                  help='Max number of prompt characters or prompt '
                  'ID numbers being printed in log.'
                  '\n\nDefault: Unlimited')),
        )
        for flag, options in async_flags:
            parser.add_argument(flag, **options)
        return parser
640
641
642
643
644
645
646
647
648
649


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with every EngineArgs flag."""
    fresh_parser = argparse.ArgumentParser()
    return EngineArgs.add_cli_args(fresh_parser)


def _async_engine_args_parser():
    """Return a fresh parser holding only the AsyncEngineArgs-specific flags."""
    fresh_parser = argparse.ArgumentParser()
    return AsyncEngineArgs.add_cli_args(fresh_parser, async_args_only=True)