arg_utils.py 30.6 KB
Newer Older
1
import argparse
2
3
import dataclasses
from dataclasses import dataclass
4
from typing import List, Optional, Union
5

6
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
7
8
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
9
                         TokenizerPoolConfig, VisionLanguageConfig)
10
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
11
from vllm.utils import str_to_int_tuple
12
13


14
15
16
17
18
19
def nullable_str(val: str):
    """Argparse ``type=`` converter that maps "no value" to ``None``.

    An empty string (or any falsy input) and the literal text ``"None"``
    both become ``None``; every other string is returned unchanged.
    """
    treat_as_null = (not val) or val == "None"
    return None if treat_as_null else val


20
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Each field corresponds to a CLI flag registered in :meth:`add_cli_args`.
    :meth:`from_cli_args` copies the parsed values back into an instance by
    field name, and :meth:`create_engine_config` turns the instance into an
    ``EngineConfig``.  Because of that name-based round trip, every option
    MUST be a real (annotated) dataclass field.
    """
    model: str
    served_model_name: Optional[Union[List[str]]] = None
    # Defaults to `model` (see __post_init__).
    tokenizer: Optional[str] = None
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    distributed_executor_backend: Optional[str] = None
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    # Deprecated in favour of max_seq_len_to_capture.
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    # When set from the CLI this arrives as a JSON string (see the flag help).
    tokenizer_pool_extra_config: Optional[dict] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    # Must be annotated: an un-annotated class attribute is not a dataclass
    # field, so from_cli_args would silently drop the --lora-dtype value.
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    # When set from the CLI this arrives as a JSON string (see the flag help).
    model_loader_extra_config: Optional[dict] = None

    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None

    def __post_init__(self):
        """Fill in derived defaults: the tokenizer defaults to the model."""
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers one flag per EngineArgs field on ``parser`` (whose ``dest``
        equals the field name) and returns the same parser.
        """
        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='Name or path of the huggingface model to use.')
        parser.add_argument(
            '--tokenizer',
            type=nullable_str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
        parser.add_argument(
            '--revision',
            type=nullable_str,
            default=EngineArgs.revision,
            help='The specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=nullable_str,
            default=EngineArgs.code_revision,
            help='The specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=nullable_str,
            default=EngineArgs.tokenizer_revision,
            help='The specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='Trust remote code from huggingface.')
        parser.add_argument('--download-dir',
                            type=nullable_str,
                            default=EngineArgs.download_dir,
                            help='Directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface.')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
            'section for more information.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument(
            '--quantization-param-path',
            type=nullable_str,
            default=EngineArgs.quantization_param_path,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default=EngineArgs.guided_decoding_backend,
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
        # Parallel arguments
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='Number of pipeline stages.')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='Number of tensor parallel replicas.')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='Load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models.')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile Ray workers.')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='Token block size for contiguous chunks of '
                            'tokens.')

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching.')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')

        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='Random seed for operations.')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU.')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            # '%%' because argparse applies %-formatting to help strings.
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=EngineArgs.num_gpu_blocks_override,
            help='If specified, ignore GPU profiling result and use this '
            'number of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='Maximum number of sequences per iteration.')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('Max number of log probs to return when logprobs is '
                  'specified in SamplingParams.'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=nullable_str,
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='Maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode. '
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
                            ')')
        # The hyphenated spelling is canonical; the legacy mixed
        # '--max-seq_len-to-capture' spelling is kept as an alias. Both map
        # to dest='max_seq_len_to_capture'.
        parser.add_argument('--max-seq-len-to-capture',
                            '--max-seq_len-to-capture',
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig.')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=nullable_str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron", "cpu"],
                            help='Device type for vLLM execution.')
        # Related to Vision-language models such as llava
        parser.add_argument(
            '--image-input-type',
            type=nullable_str,
            default=EngineArgs.image_input_type,
            choices=[
                t.name.lower() for t in VisionLanguageConfig.ImageInputType
            ],
            help=('The image input type passed into vLLM. '
                  'Should be one of "pixel_values" or "image_features".'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=EngineArgs.image_token_id,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=nullable_str,
            default=EngineArgs.image_input_shape,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=EngineArgs.image_feature_size,
            help=('The image feature size along the context dimension.'))
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        parser.add_argument(
            '--speculative-model',
            type=nullable_str,
            default=EngineArgs.speculative_model,
            help=
            'The name of the draft model to be used in speculative decoding.')

        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')

        parser.add_argument(
            '--speculative-max-model-len',
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument('--model-loader-extra-config',
                            type=nullable_str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')

        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=EngineArgs.served_model_name,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Note that this name(s) "
            "will also be used in `model_name` tag content of "
            "prometheus metrics, if multiple names provided, metrics "
            "tag will take the first one.")

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a namespace produced by :meth:`add_cli_args`.

        Every dataclass field name is looked up on ``args``, so the namespace
        must carry an attribute for each field (AttributeError otherwise).
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_config(self) -> EngineConfig:
        """Translate these arguments into the engine's sub-configs.

        Raises:
            ValueError: if ``image_input_type`` is set without
                ``image_token_id``, ``image_input_shape`` and
                ``image_feature_size``, or if chunked prefill is enabled for
                a model whose config reports a sliding window. The individual
                config constructors may raise their own validation errors.
        """
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.dtype, self.seed, self.revision,
            self.code_revision, self.tokenizer_revision, self.max_model_len,
            self.quantization, self.quantization_param_path,
            self.enforce_eager, self.max_context_len_to_capture,
            self.max_seq_len_to_capture, self.max_logprobs,
            self.skip_tokenizer_init, self.served_model_name)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
                                   self.num_gpu_blocks_override,
                                   model_config.get_sliding_window(),
                                   self.enable_prefix_caching)
        # NOTE(review): self.distributed_executor_backend is parsed from the
        # CLI but not forwarded to ParallelConfig here; confirm whether
        # ParallelConfig should receive it.
        parallel_config = ParallelConfig(
            self.pipeline_parallel_size, self.tensor_parallel_size,
            self.worker_use_ray, self.max_parallel_loading_workers,
            self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ), self.ray_workers_use_nsight)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
        )

        scheduler_config = SchedulerConfig(
            self.max_num_batched_tokens,
            self.max_num_seqs,
            model_config.max_model_len,
            self.use_v2_block_manager,
            # When speculative decoding is configured, its lookahead slot
            # count overrides the user-supplied value.
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
        )

        lora_config = None
        if self.enable_lora:
            # Zero / negative / unset all mean "no explicit CPU LoRA cap".
            max_cpu_loras = (self.max_cpu_loras if self.max_cpu_loras
                             and self.max_cpu_loras > 0 else None)
            lora_config = LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                lora_dtype=self.lora_dtype,
                max_cpu_loras=max_cpu_loras)

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )

        if self.image_input_type:
            # The token id is checked against None explicitly: 0 is a valid
            # token id and must not be treated as "missing".
            if (self.image_token_id is None or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')
            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
            )
        else:
            vision_language_config = None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled):
            raise ValueError(
                "Chunked prefill is not supported with sliding window.")

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            vision_language_config=vision_language_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config)
636
637


638
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends :class:`EngineArgs` with three options that only the async
    engine front-end consumes.
    """
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    # None leaves prompt logging untruncated.
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser,
                     async_args_only: bool = False) -> argparse.ArgumentParser:
        """Register the async-engine flags (plus the shared ones by default).

        When ``async_args_only`` is true, only the three async-specific flags
        are added; otherwise the shared EngineArgs flags are registered first.
        Returns the parser that was populated.
        """
        target = (parser if async_args_only else
                  EngineArgs.add_cli_args(parser))

        target.add_argument(
            '--engine-use-ray',
            action='store_true',
            help=('Use Ray to start the LLM engine in a separate process '
                  'as the server process.'))
        target.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        target.add_argument(
            '--max-log-len',
            type=int,
            default=AsyncEngineArgs.max_log_len,
            help=('Max number of prompt characters or prompt ID numbers '
                  'being printed in log.\n\nDefault: Unlimited'))
        return target
664
665
666
667
668
669
670
671
672
673


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser carrying every shared EngineArgs flag."""
    fresh_parser = argparse.ArgumentParser()
    return EngineArgs.add_cli_args(fresh_parser)


def _async_engine_args_parser():
    """Return a fresh parser carrying only the async-engine-specific flags."""
    fresh_parser = argparse.ArgumentParser()
    return AsyncEngineArgs.add_cli_args(fresh_parser, async_args_only=True)