arg_utils.py 31.8 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
from typing import List, Optional, Tuple, Union
6

7
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
8
9
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
10
                         TokenizerPoolConfig, VisionLanguageConfig)
11
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
12
from vllm.utils import str_to_int_tuple
13
14


15
16
17
18
19
20
def nullable_str(val: str):
    """Translate "no value" placeholders from the CLI into ``None``.

    Intended as an argparse ``type=`` converter for optional string flags:
    an empty/falsy value or the literal text ``"None"`` becomes ``None``;
    every other string passes through untouched.
    """
    is_missing = (not val) or val == "None"
    return None if is_missing else val


21
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Every field is populated from the CLI flag with the matching ``dest``
    by :meth:`from_cli_args`, which iterates ``dataclasses.fields``. Each
    field therefore needs a type annotation: an un-annotated class
    attribute is not a dataclass field and its flag would be silently
    ignored.
    """
    model: str
    served_model_name: Optional[List[str]] = None
    tokenizer: Optional[str] = None  # defaults to `model` in __post_init__
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False  # deprecated, see --worker-use-ray help
    distributed_executor_backend: Optional[str] = None
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 5  # OpenAI default value
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    rope_scaling: Optional[dict] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    # Deprecated in favour of max_seq_len_to_capture (see the CLI help).
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    # A JSON string when coming from the CLI (parsed downstream) or a dict.
    tokenizer_pool_extra_config: Optional[Union[str, dict]] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    # A comma-separated string from the CLI is normalized in __post_init__.
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
    # Must be annotated: without it this is a plain class attribute, not a
    # dataclass field, and from_cli_args would drop --lora-dtype.
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    # A JSON string when coming from the CLI (parsed downstream) or a dict.
    model_loader_extra_config: Optional[Union[str, dict]] = None

    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None

    def __post_init__(self):
        """Fill derived defaults and normalize CLI-string inputs.

        Raises:
            ValueError: If ``long_lora_scaling_factors`` is a string that is
                not a comma-separated list of floats.
        """
        # Fall back to the model's own tokenizer when none is given.
        if self.tokenizer is None:
            self.tokenizer = self.model
        # --long-lora-scaling-factors is parsed with ``nullable_str``, so a
        # CLI value arrives here as a string, while the field is declared
        # (and forwarded to LoRAConfig) as a tuple of floats.
        if isinstance(self.long_lora_scaling_factors, str):
            raw = self.long_lora_scaling_factors
            try:
                self.long_lora_scaling_factors = tuple(
                    float(factor) for factor in raw.split(','))
            except ValueError as err:
                raise ValueError(
                    'long_lora_scaling_factors must be a comma-separated '
                    f'list of floats, got {raw!r}.') from err

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Registers every engine flag on ``parser``. Each flag's ``dest``
        matches an :class:`EngineArgs` field so :meth:`from_cli_args` can
        read it back.

        Args:
            parser: Parser to extend; it is mutated in place.

        Returns:
            The same ``parser`` instance.
        """

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='Name or path of the huggingface model to use.')
        parser.add_argument(
            '--tokenizer',
            type=nullable_str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
        parser.add_argument(
            '--revision',
            type=nullable_str,
            default=None,
            help='The specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=nullable_str,
            default=None,
            help='The specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=nullable_str,
            default=None,
            help='The specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='Trust remote code from huggingface.')
        parser.add_argument('--download-dir',
                            type=nullable_str,
                            default=EngineArgs.download_dir,
                            help='Directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface.')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
            'section for more information.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument(
            '--quantization-param-path',
            type=nullable_str,
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default=EngineArgs.guided_decoding_backend,
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')

        # Parallel arguments
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='Number of pipeline stages.')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='Number of tensor parallel replicas.')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='Load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models.')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile Ray workers.')

        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='Token block size for contiguous chunks of '
                            'tokens.')
        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching.')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='Random seed for operations.')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU.')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            # '%%' is required: argparse %-formats help strings.
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this '
            'number of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='Maximum number of sequences per iteration.')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('Max number of log probs to return when logprobs is '
                  'specified in SamplingParams.'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')

        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=nullable_str,
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='Maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode. '
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
                            ')')
        parser.add_argument('--max-seq-len-to-capture',
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig.')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=nullable_str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')

        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed. '
                  'Pass a comma-separated list, e.g. 4.0,8.0.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron", "cpu"],
                            help='Device type for vLLM execution.')

        # Related to Vision-language models such as llava
        parser.add_argument(
            '--image-input-type',
            type=nullable_str,
            default=None,
            choices=[
                t.name.lower() for t in VisionLanguageConfig.ImageInputType
            ],
            help=('The image input type passed into vLLM. '
                  'Should be one of "pixel_values" or "image_features".'))
        parser.add_argument('--image-token-id',
                            type=int,
                            default=None,
                            help=('Input id for image token.'))
        parser.add_argument(
            '--image-input-shape',
            type=nullable_str,
            default=None,
            help=('The biggest image input shape (worst for memory footprint) '
                  'given an input type. Only used for vLLM\'s profile_run.'))
        parser.add_argument(
            '--image-feature-size',
            type=int,
            default=None,
            help=('The image feature size along the context dimension.'))
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        # Speculative decoding
        parser.add_argument(
            '--speculative-model',
            type=nullable_str,
            default=EngineArgs.speculative_model,
            help=
            'The name of the draft model to be used in speculative decoding.')
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')
        parser.add_argument(
            '--speculative-max-model-len',
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueued requests is larger than this value.')
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')
        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument('--model-loader-extra-config',
                            type=nullable_str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Note that this name(s) "
            "will also be used in `model_name` tag content of "
            "prometheus metrics, if multiple names provided, metrics "
            "tag will take the first one.")

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a parsed argparse namespace.

        Every dataclass field of ``cls`` is read from ``args`` by name, so
        ``args`` must provide an attribute for each field (as a parser built
        with :meth:`add_cli_args` does).

        Raises:
            AttributeError: If ``args`` lacks an attribute for some field.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_config(self) -> EngineConfig:
        """Assemble the full :class:`EngineConfig` from these arguments.

        Returns:
            An ``EngineConfig`` bundling the model, cache, parallel,
            scheduler, device, LoRA, vision-language, speculative-decoding,
            load and decoding configs.

        Raises:
            ValueError: If ``image_input_type`` is set without all of
                ``image_token_id``, ``image_input_shape`` and
                ``image_feature_size``, or if chunked prefill is enabled
                for a model with a sliding window.
        """
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.dtype, self.seed, self.revision,
            self.code_revision, self.rope_scaling, self.tokenizer_revision,
            self.max_model_len, self.quantization,
            self.quantization_param_path, self.enforce_eager,
            self.max_context_len_to_capture, self.max_seq_len_to_capture,
            self.max_logprobs, self.skip_tokenizer_init,
            self.served_model_name)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
                                   self.num_gpu_blocks_override,
                                   model_config.get_sliding_window(),
                                   self.enable_prefix_caching)
        parallel_config = ParallelConfig(
            self.pipeline_parallel_size,
            self.tensor_parallel_size,
            self.worker_use_ray,
            self.max_parallel_loading_workers,
            self.disable_custom_all_reduce,
            TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=(
                self.speculative_disable_by_batch_size),
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
        )

        scheduler_config = SchedulerConfig(
            self.max_num_batched_tokens,
            self.max_num_seqs,
            model_config.max_model_len,
            self.use_v2_block_manager,
            # When speculative decoding is active, its config dictates the
            # lookahead slots and overrides --num-lookahead-slots.
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
        )

        lora_config = None
        if self.enable_lora:
            lora_config = LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                long_lora_scaling_factors=self.long_lora_scaling_factors,
                lora_dtype=self.lora_dtype,
                # Non-positive values are treated as "unset".
                max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
                and self.max_cpu_loras > 0 else None)

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )

        if self.image_input_type:
            if (not self.image_token_id or not self.image_input_shape
                    or not self.image_feature_size):
                raise ValueError(
                    'Specify `image_token_id`, `image_input_shape` and '
                    '`image_feature_size` together with `image_input_type`.')
            vision_language_config = VisionLanguageConfig(
                image_input_type=VisionLanguageConfig.
                get_image_input_enum_type(self.image_input_type),
                image_token_id=self.image_token_id,
                image_input_shape=str_to_int_tuple(self.image_input_shape),
                image_feature_size=self.image_feature_size,
            )
        else:
            vision_language_config = None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled):
            raise ValueError(
                "Chunked prefill is not supported with sliding window.")

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            vision_language_config=vision_language_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config)


663
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Extends :class:`EngineArgs` with options specific to the async engine.
    """
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    max_log_len: Optional[int] = None  # None = unlimited (see --max-log-len)

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser,
                     async_args_only: bool = False) -> argparse.ArgumentParser:
        """Register the async-engine CLI flags on ``parser``.

        Args:
            parser: Parser to extend; it is mutated in place.
            async_args_only: If True, skip the shared :class:`EngineArgs`
                flags and add only the async-specific ones.

        Returns:
            The same ``parser`` instance.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='Use Ray to start the LLM engine in a '
                            'separate process from the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        parser.add_argument('--max-log-len',
                            type=int,
                            default=AsyncEngineArgs.max_log_len,
                            help='Max number of prompt characters or prompt '
                            'ID numbers being printed in log.'
                            '\n\nDefault: Unlimited')
        return parser
689
690
691
692
693
694
695
696
697
698


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser carrying every synchronous engine flag."""
    fresh_parser = argparse.ArgumentParser()
    return EngineArgs.add_cli_args(fresh_parser)


def _async_engine_args_parser():
    """Return a fresh parser carrying only the async-specific engine flags."""
    fresh_parser = argparse.ArgumentParser()
    return AsyncEngineArgs.add_cli_args(fresh_parser, async_args_only=True)