arg_utils.py 28.7 KB
Newer Older
1
import argparse
2
3
import dataclasses
from dataclasses import dataclass
4
from typing import Optional
5

6
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
7
8
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
9
                         TokenizerPoolConfig, VisionLanguageConfig)
10
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
11
from vllm.utils import str_to_int_tuple
12
13


14
15
16
17
18
19
def nullable_str(val: str):
    """Map an empty string or the literal ``"None"`` to ``None``.

    Intended for use as an argparse ``type=`` converter: any other value is
    returned unchanged, so ``--flag None`` or ``--flag ""`` yields ``None``.
    """
    is_null = (not val) or (val == "None")
    return None if is_null else val


20
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Each field corresponds to a command-line flag registered by
    ``add_cli_args``. ``from_cli_args`` builds an instance from a parsed
    namespace, and ``create_engine_config`` converts the instance into an
    ``EngineConfig``.
    """
    model: str
    tokenizer: Optional[str] = None
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    # Must be annotated to be a dataclass field. Without the annotation it is
    # a plain class attribute that dataclasses.fields() (and therefore
    # from_cli_args) skips, so --lora-dtype would be silently ignored.
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    model_loader_extra_config: Optional[dict] = None

    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None

    def __post_init__(self):
        # Default the tokenizer to the model name/path when not given.
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one flag per EngineArgs field on ``parser`` and returns the
        same parser. Most defaults are read from the EngineArgs class
        attributes so the CLI and the dataclass stay in sync.
        """

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='Name or path of the huggingface model to use.')
        parser.add_argument(
            '--tokenizer',
            type=nullable_str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
        parser.add_argument(
            '--revision',
            type=nullable_str,
            default=None,
            help='The specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=nullable_str,
            default=None,
            help='The specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=nullable_str,
            default=None,
            help='The specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='Trust remote code from huggingface.')
        parser.add_argument('--download-dir',
                            type=nullable_str,
                            default=EngineArgs.download_dir,
                            help='Directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface.')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave which assumes tensorizer_uri is set to the location of '
            'the serialized weights.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument(
            '--quantization-param-path',
            type=nullable_str,
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default=EngineArgs.guided_decoding_backend,
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='Use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='Number of pipeline stages.')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='Number of tensor parallel replicas.')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='Load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models.')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile Ray workers.')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='Token block size for contiguous chunks of '
                            'tokens.')

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching.')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')

        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='Random seed for operations.')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU.')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            # '%%' is required: argparse %-formats help strings.
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this '
            'number of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='Maximum number of sequences per iteration.')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('Max number of log probs to return when logprobs is '
                  'specified in SamplingParams.'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=nullable_str,
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='Maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode. '
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
                            ')')
        # The kebab-case spelling is canonical. The historical mixed spelling
        # '--max-seq_len-to-capture' is kept as an alias so existing command
        # lines still work; both store into dest 'max_seq_len_to_capture'.
        parser.add_argument('--max-seq-len-to-capture',
                            '--max-seq_len-to-capture',
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig.')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        # NOTE(review): this arrives as a JSON string (see help text) although
        # the field is annotated Optional[dict]; presumably parsed downstream
        # by TokenizerPoolConfig.create_config -- confirm.
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=nullable_str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        # NOTE(review): LoRAConfig likely validates max_cpu_loras against
        # max_loras rather than max_num_seqs -- confirm and fix this help.
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron", "cpu"],
                            help='Device type for vLLM execution.')
        # Related to Vision-language models such as llava
        parser.add_argument(
            '--image-input-type',
            type=nullable_str,
            default=None,
            choices=[
                t.name.lower() for t in VisionLanguageConfig.ImageInputType
            ],
            help=('The image input type passed into vLLM. '
                  'Should be one of "pixel_values" or "image_features".'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=None,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=nullable_str,
            default=None,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=None,
            help=('The image feature size along the context dimension.'))
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        parser.add_argument(
            '--speculative-model',
            type=nullable_str,
            default=EngineArgs.speculative_model,
            help=
            'The name of the draft model to be used in speculative decoding.')

        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')

        parser.add_argument(
            '--speculative-max-model-len',
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument('--model-loader-extra-config',
                            type=nullable_str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        """Build an instance of ``cls`` from a parsed argparse namespace.

        Every dataclass field of ``cls`` is read from ``args`` by name, so
        ``args`` must provide an attribute for each field (as a parser built
        with ``cls.add_cli_args`` does).
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_config(self) -> EngineConfig:
        """Translate these arguments into an ``EngineConfig``.

        Builds the device, model, cache, parallel, speculative, scheduler,
        LoRA (only when ``enable_lora``), load, vision-language (only when
        ``image_input_type`` is set) and decoding sub-configs.

        Raises:
            ValueError: if ``image_input_type`` is set without
                ``image_token_id``, ``image_input_shape`` and
                ``image_feature_size``.
        """
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.dtype, self.seed, self.revision,
            self.code_revision, self.tokenizer_revision, self.max_model_len,
            self.quantization, self.quantization_param_path,
            self.enforce_eager, self.max_context_len_to_capture,
            self.max_seq_len_to_capture, self.max_logprobs,
            self.skip_tokenizer_init)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
                                   self.num_gpu_blocks_override,
                                   model_config.get_sliding_window(),
                                   self.enable_prefix_caching)
        parallel_config = ParallelConfig(
            self.pipeline_parallel_size, self.tensor_parallel_size,
            self.worker_use_ray, self.max_parallel_loading_workers,
            self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ), self.ray_workers_use_nsight)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
        )

        scheduler_config = SchedulerConfig(
            self.max_num_batched_tokens,
            self.max_num_seqs,
            model_config.max_model_len,
            self.use_v2_block_manager,
            # When speculative decoding is enabled, its config dictates the
            # number of lookahead slots instead of the CLI value.
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
        )

        lora_config = None
        if self.enable_lora:
            # A missing or non-positive max_cpu_loras is passed on as None.
            max_cpu_loras = (self.max_cpu_loras if self.max_cpu_loras
                             and self.max_cpu_loras > 0 else None)
            lora_config = LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=max_cpu_loras)

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )

        if self.image_input_type:
            # image_token_id is compared with None rather than tested for
            # truthiness because 0 is a legitimate token id.
            if (self.image_token_id is None or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')
            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
            )
        else:
            vision_language_config = None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            vision_language_config=vision_language_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config)
596
597


598
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends EngineArgs with the options that only the async engine uses.
    """
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser,
                     async_args_only: bool = False) -> argparse.ArgumentParser:
        """Register the async-engine CLI flags on ``parser`` and return it.

        Args:
            parser: The parser to extend.
            async_args_only: If True, register only the async-specific
                flags; otherwise the shared EngineArgs flags are added first.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='Use Ray to start the LLM engine in a '
                            'separate process from the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        # Default read from the dataclass so CLI and dataclass stay in sync.
        parser.add_argument('--max-log-len',
                            type=int,
                            default=AsyncEngineArgs.max_log_len,
                            help='Max number of prompt characters or prompt '
                            'ID numbers being printed in log.'
                            '\n\nDefault: Unlimited')
        return parser
624
625
626
627
628
629
630
631
632
633


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with all EngineArgs CLI flags."""
    fresh_parser = argparse.ArgumentParser()
    return EngineArgs.add_cli_args(fresh_parser)


def _async_engine_args_parser():
    """Return a fresh parser populated with only the async-specific flags."""
    fresh_parser = argparse.ArgumentParser()
    return AsyncEngineArgs.add_cli_args(fresh_parser, async_args_only=True)