arg_utils.py 37.5 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
from typing import List, Optional, Tuple, Union
6

7
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
8
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
9
                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
10
11
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig, TokenizerPoolConfig)
12
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
13
from vllm.utils import FlexibleArgumentParser
14
15


16
17
18
19
20
21
def nullable_str(val: str):
    """Turn an empty string or the literal ``"None"`` into ``None``.

    Used as an argparse ``type=`` converter so a user can explicitly unset an
    optional string option; any other value is passed through unchanged.
    """
    is_unset = (not val) or val == "None"
    return None if is_unset else val


22
@dataclass
class EngineArgs:
    """Arguments for vLLM engine.

    Most field defaults below are reused as the defaults of the matching
    command-line options, so the Python API and the CLI stay in sync.
    """
    model: str
    # A single name or a list of names (the CLI produces a list via nargs="+").
    served_model_name: Optional[Union[str, List[str]]] = None
    tokenizer: Optional[str] = None
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    # Deprecated: use distributed_executor_backend="ray" instead.
    worker_use_ray: bool = False
    distributed_executor_backend: Optional[str] = None
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    disable_sliding_window: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    # Deprecated in favour of max_seq_len_to_capture.
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    tokenizer_pool_type: str = "ray"
    # A dict, or (when set from the CLI) a JSON string to be parsed into one.
    tokenizer_pool_extra_config: Optional[Union[str, dict]] = None

    # LoRA configuration.
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16

    # Prompt-adapter configuration.
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0

    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    # Variable-length tuple of scaling factors.
    # NOTE(review): --long-lora-scaling-factors is parsed with nullable_str,
    # so CLI users land a string here rather than a tuple — confirm the
    # downstream LoRA config handles that.
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    # A dict, or (when set from the CLI) a JSON string to be parsed into one.
    model_loader_extra_config: Optional[Union[str, dict]] = None
    preemption_mode: Optional[str] = None

    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'

    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    speculative_draft_tensor_parallel_size: Optional[int] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
    qlora_adapter_name_or_path: Optional[str] = None

    otlp_traces_endpoint: Optional[str] = None
    def __post_init__(self):
104
105
        if self.tokenizer is None:
            self.tokenizer = self.model
106
107

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Shared CLI arguments for vLLM engine.

        Adds an option for each ``EngineArgs`` field to ``parser``; the
        option dest names match the field names, which is what
        ``from_cli_args`` relies on to read them back.

        Args:
            parser: The parser to extend in place.

        Returns:
            The same ``parser``, so calls can be chained.
        """
        # Defaults are read from the dataclass (``EngineArgs.<field>``) so the
        # CLI and the Python API cannot drift apart. ``--model`` is the one
        # exception: the field has no default, so the CLI supplies one.
        # Optional string options use ``nullable_str`` so "None" or "" unsets
        # them explicitly.

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='Name or path of the huggingface model to use.')
        parser.add_argument(
            '--tokenizer',
            type=nullable_str,
            default=EngineArgs.tokenizer,
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
        parser.add_argument(
            '--revision',
            type=nullable_str,
            default=EngineArgs.revision,
            help='The specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=nullable_str,
            default=EngineArgs.code_revision,
            help='The specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=nullable_str,
            default=EngineArgs.tokenizer_revision,
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='Trust remote code from huggingface.')
        parser.add_argument('--download-dir',
                            type=nullable_str,
                            default=EngineArgs.download_dir,
                            help='Directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface.')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
                'bitsandbytes'
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
        parser.add_argument(
            '--quantization-param-path',
            type=nullable_str,
            default=EngineArgs.quantization_param_path,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default=EngineArgs.guided_decoding_backend,
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')

        # Parallel arguments
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        # Kept only for backward compatibility with older launch scripts.
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='Number of pipeline stages.')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='Number of tensor parallel replicas.')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='Load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models.')
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
            help='If specified, use nsight to profile Ray workers.')

        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='Token block size for contiguous chunks of '
                            'tokens.')
        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
                            help='Enables automatic prefix caching.')
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceManagerV2.')
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='Random seed for operations.')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU.')
        # '%%' is required: argparse %-formats help strings.
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
        parser.add_argument(
            '--num-gpu-blocks-override',
            type=int,
            default=EngineArgs.num_gpu_blocks_override,
            help='If specified, ignore GPU profiling result and use this '
            'number of GPU blocks. Used for testing preemption.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='Maximum number of sequences per iteration.')
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
            help=('Max number of log probs to return when logprobs is '
                  'specified in SamplingParams.'))
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')

        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=nullable_str,
                            choices=[*QUANTIZATION_METHODS, None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--rope-scaling',
                            default=EngineArgs.rope_scaling,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--rope-theta',
                            default=EngineArgs.rope_theta,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='Maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode. '
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
                            ')')
        parser.add_argument('--max-seq-len-to-capture',
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig.')
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
                            type=nullable_str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')

        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))

        # Prompt-adapter related configs
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=[
                                "auto", "cuda", "neuron", "cpu", "openvino",
                                "tpu", "xpu"
                            ],
                            help='Device type for vLLM execution.')

        # Scheduler arguments
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help='Apply a delay (of delay factor multiplied by previous '
            'prompt latency) before scheduling next prompt.')
        parser.add_argument(
            '--enable-chunked-prefill',
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')

        # Speculative decoding arguments
        parser.add_argument(
            '--speculative-model',
            type=nullable_str,
            default=EngineArgs.speculative_model,
            help=
            'The name of the draft model to be used in speculative decoding.')
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
        parser.add_argument(
            '--speculative-max-model-len',
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')
        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')
        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')
        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

        # Miscellaneous arguments
        parser.add_argument('--model-loader-extra-config',
                            type=nullable_str,
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
        parser.add_argument(
            '--preemption-mode',
            type=str,
            default=EngineArgs.preemption_mode,
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=EngineArgs.served_model_name,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Note that this name(s) "
            "will also be used in `model_name` tag content of "
            "prometheus metrics; if multiple names are provided, the "
            "metrics tag will take the first one.")
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=EngineArgs.qlora_adapter_name_or_path,
                            help='Name or path of the QLoRA adapter.')
        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=EngineArgs.otlp_traces_endpoint,
            help='Target URL to which OpenTelemetry traces will be sent.')

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from a parsed argparse namespace.

        Each dataclass field of ``cls`` is looked up on ``args`` by its field
        name. Any extra attributes in the namespace are ignored. A missing
        attribute raises ``AttributeError``.
        """
        init_kwargs = {}
        for field in dataclasses.fields(cls):
            init_kwargs[field.name] = getattr(args, field.name)
        return cls(**init_kwargs)
618

619
    def create_engine_config(self) -> EngineConfig:
        """Validate these arguments and assemble the full engine config.

        Builds the model, device, cache, parallel, speculative, scheduler,
        LoRA, load, prompt-adapter, decoding, observability and multimodal
        sub-configs from the fields of this object. It then bundles them
        into a single ``EngineConfig``.

        This method does not modify ``self``. When a QLoRA adapter is given,
        the adapter path is added to a *copy* of
        ``model_loader_extra_config`` before being passed to ``LoadConfig``.

        Returns:
            The assembled ``EngineConfig``.

        Raises:
            ValueError: if bitsandbytes quantization and the bitsandbytes
                load format are not used together, or if a QLoRA adapter is
                given without both of them. Also raised if chunked prefill is
                enabled for a model with a sliding window while the v2 block
                manager is disabled. The sub-config constructors may raise
                their own errors as well.
        """
        # bitsandbytes quantization needs a specific model loader, so the
        # quantization method and the load format must agree. A QLoRA
        # adapter requires both of them to be "bitsandbytes".
        uses_qlora = self.qlora_adapter_name_or_path is not None
        if ((self.quantization == "bitsandbytes" or uses_qlora)
                and self.load_format != "bitsandbytes"):
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if ((self.load_format == "bitsandbytes" or uses_qlora)
                and self.quantization != "bitsandbytes"):
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        multimodal_config = MultiModalConfig()

        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            multimodal_config=multimodal_config)

        # The cache config takes its sliding window from the model config
        # built above, so it must be constructed afterwards.
        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching)

        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        # The name suggests this may return None when speculative decoding
        # is not requested; the scheduler config below handles that case.
        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_draft_tensor_parallel_size=(
                self.speculative_draft_tensor_parallel_size),
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=(
                self.speculative_disable_by_batch_size),
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=(
                self.spec_decoding_acceptance_method),
            typical_acceptance_sampler_posterior_threshold=(
                self.typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                self.typical_acceptance_sampler_posterior_alpha),
        )

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            use_v2_block_manager=self.use_v2_block_manager,
            # With speculative decoding, the speculative config decides how
            # many lookahead slots are needed.
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
        )

        lora_config = None
        if self.enable_lora:
            lora_config = LoRAConfig(
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                long_lora_scaling_factors=self.long_lora_scaling_factors,
                lora_dtype=self.lora_dtype,
                # Unset, zero or negative values are all passed as None.
                max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
                and self.max_cpu_loras > 0 else None)

        # Forward the QLoRA adapter to the model loader. Work on a copy so
        # that building a config never mutates ``self`` or a dict the caller
        # may share between several argument objects.
        model_loader_extra_config = self.model_loader_extra_config
        if self.qlora_adapter_name_or_path:
            model_loader_extra_config = dict(model_loader_extra_config or {})
            model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=model_loader_extra_config,
        )

        prompt_adapter_config = None
        if self.enable_prompt_adapter:
            prompt_adapter_config = PromptAdapterConfig(
                max_prompt_adapters=self.max_prompt_adapters,
                max_prompt_adapter_token=self.max_prompt_adapter_token)

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint)

        # Chunked prefill cannot be combined with a sliding window unless the
        # v2 block manager is enabled.
        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            multimodal_config=multimodal_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
        )
775
776


777
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments plus the extra options used by the async engine.

    Every field of :class:`EngineArgs` is inherited. The fields below are
    appended after them.
    """
    # Start the engine through Ray, in a process apart from the server.
    engine_use_ray: bool = False
    # Turn off per-request logging.
    disable_log_requests: bool = False
    # Cap on prompt characters / prompt token IDs written to the log;
    # None means no cap.
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI options on ``parser`` and return it.

        Unless ``async_args_only`` is true, the base ``EngineArgs`` options
        are registered first via ``EngineArgs.add_cli_args``.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        add_option = parser.add_argument
        add_option(
            '--engine-use-ray',
            action='store_true',
            help=('Use Ray to start the LLM engine in a separate process '
                  'as the server process.'))
        add_option('--disable-log-requests',
                   action='store_true',
                   help='Disable logging requests.')
        add_option(
            '--max-log-len',
            type=int,
            default=None,
            help=('Max number of prompt characters or prompt ID numbers '
                  'being printed in log.\n\nDefault: Unlimited'))
        return parser
803
804
805
806


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Build a new parser populated with the EngineArgs CLI options."""
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    return parser
808
809
810


def _async_engine_args_parser():
    """Build a new parser holding only the async-engine-specific options."""
    parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(parser, async_args_only=True)