arg_utils.py 52.1 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
6
                    Tuple, Type, Union, cast)
7

8
9
import torch

10
import vllm.envs as envs
11
12
13
14
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
                         DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
                         LoRAConfig, ModelConfig, ObservabilityConfig,
                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
15
                         SpeculativeConfig, TokenizerPoolConfig)
16
from vllm.executor.executor_base import ExecutorBase
17
from vllm.logger import init_logger
18
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
19
from vllm.transformers_utils.utils import check_gguf_file
20
from vllm.utils import FlexibleArgumentParser
21

22
if TYPE_CHECKING:
23
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
24

25
26
# Module-level logger, named after this module.
logger = init_logger(__name__)

27
28
# Module names recognised for detailed tracing.
# NOTE(review): presumably the values accepted for `collect_detailed_traces`;
# the use site is not visible from here -- confirm.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

29
30
31
32
33
34
35
36
37
38
# Device types accepted by the `--device` CLI option (passed as its
# `choices`). "auto" is the default value of `EngineArgs.device`.
DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "openvino",
    "tpu",
    "xpu",
]

39

40
41
42
43
44
45
def nullable_str(val: str):
    """Return ``val`` unchanged, or ``None`` for an empty or "None" value.

    Intended as an argparse ``type=`` callable so that a CLI option can be
    explicitly reset to ``None`` by passing the literal string ``None``.
    """
    if val and val != "None":
        return val
    return None


46
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dictionary.

    Keys and values are lower-cased and stripped of surrounding whitespace,
    and every value must parse as an integer.

    Args:
        val: String value to be parsed, e.g. ``"image=16,video=2"``.

    Returns:
        Dictionary mapping each key to its integer value, or ``None`` when
        ``val`` is empty.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, if a value is not an integer, or if the same key
            is given two different values.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key with the same value is tolerated; only a
        # different value for an already-seen key is rejected.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


81
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
82
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
83
    """Arguments for vLLM engine."""
84
    model: str = 'facebook/opt-125m'
85
    served_model_name: Optional[Union[str, List[str]]] = None
86
    tokenizer: Optional[str] = None
87
    skip_tokenizer_init: bool = False
88
    tokenizer_mode: str = 'auto'
89
    trust_remote_code: bool = False
90
    download_dir: Optional[str] = None
91
    load_format: str = 'auto'
92
    config_format: ConfigFormat = ConfigFormat.AUTO
93
    dtype: str = 'auto'
94
    kv_cache_dtype: str = 'auto'
95
    quantization_param_path: Optional[str] = None
96
    seed: int = 0
97
    max_model_len: Optional[int] = None
98
    worker_use_ray: bool = False
99
100
101
102
103
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
104
105
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
106
    max_parallel_loading_workers: Optional[int] = None
107
    block_size: int = 16
108
    enable_prefix_caching: bool = False
109
    disable_sliding_window: bool = False
110
    use_v2_block_manager: bool = True
111
112
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
113
    gpu_memory_utilization: float = 0.90
114
    max_num_batched_tokens: Optional[int] = None
115
    max_num_seqs: int = 256
116
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
117
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
118
    revision: Optional[str] = None
119
    code_revision: Optional[str] = None
120
    rope_scaling: Optional[dict] = None
121
    rope_theta: Optional[float] = None
122
    tokenizer_revision: Optional[str] = None
123
    quantization: Optional[str] = None
124
    enforce_eager: Optional[bool] = None
125
126
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
127
    disable_custom_all_reduce: bool = False
128
    tokenizer_pool_size: int = 0
129
130
131
132
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
133
    tokenizer_pool_extra_config: Optional[dict] = None
134
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
135
136
137
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
138
139
140
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
141
    fully_sharded_loras: bool = False
142
    lora_extra_vocab_size: int = 256
143
    long_lora_scaling_factors: Optional[Tuple[float]] = None
144
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
145
    max_cpu_loras: Optional[int] = None
146
    device: str = 'auto'
147
    num_scheduler_steps: int = 1
148
    multi_step_stream_outputs: bool = True
149
    ray_workers_use_nsight: bool = False
150
    num_gpu_blocks_override: Optional[int] = None
151
    num_lookahead_slots: int = 0
152
    model_loader_extra_config: Optional[dict] = None
153
    ignore_patterns: Optional[Union[str, List[str]]] = None
154
    preemption_mode: Optional[str] = None
155

156
    scheduler_delay_factor: float = 0.0
157
    enable_chunked_prefill: Optional[bool] = None
158

159
    guided_decoding_backend: str = 'outlines'
160
161
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
162
    speculative_model_quantization: Optional[str] = None
163
    speculative_draft_tensor_parallel_size: Optional[int] = None
164
    num_speculative_tokens: Optional[int] = None
165
    speculative_disable_mqa_scorer: Optional[bool] = False
166
    speculative_max_model_len: Optional[int] = None
167
    speculative_disable_by_batch_size: Optional[int] = None
168
169
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
170
171
172
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
173
    qlora_adapter_name_or_path: Optional[str] = None
174
    disable_logprobs_during_spec_decoding: Optional[bool] = None
175

176
    otlp_traces_endpoint: Optional[str] = None
177
    collect_detailed_traces: Optional[str] = None
178
    disable_async_output_proc: bool = False
179
    override_neuron_config: Optional[Dict[str, Any]] = None
180
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
181
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
182

183
    def __post_init__(self):
        """Finish initialisation after the dataclass-generated ``__init__``.

        Falls back to ``model`` when no tokenizer was given, then loads the
        general vLLM plugins.
        """
        if not self.tokenizer:
            self.tokenizer = self.model

        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
190
191

    @staticmethod
192
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
193
        """Shared CLI arguments for vLLM engine."""
194

195
        # Model arguments
196
197
198
        parser.add_argument(
            '--model',
            type=str,
199
            default=EngineArgs.model,
200
            help='Name or path of the huggingface model to use.')
201
202
        parser.add_argument(
            '--tokenizer',
203
            type=nullable_str,
204
            default=EngineArgs.tokenizer,
205
206
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
207
208
209
210
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
Jasmond L's avatar
Jasmond L committed
211
212
        parser.add_argument(
            '--revision',
213
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
214
            default=None,
215
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
216
217
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
218
219
        parser.add_argument(
            '--code-revision',
220
            type=nullable_str,
221
            default=None,
222
            help='The specific revision to use for the model code on '
223
224
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
225
226
        parser.add_argument(
            '--tokenizer-revision',
227
            type=nullable_str,
228
            default=None,
229
230
231
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
232
233
234
235
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
236
            choices=['auto', 'slow', 'mistral'],
237
238
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
239
240
            'always use the slow tokenizer. \n* '
            '"mistral" will always use the `mistral_common` tokenizer.')
241
242
        parser.add_argument('--trust-remote-code',
                            action='store_true',
243
                            help='Trust remote code from huggingface.')
244
        parser.add_argument('--download-dir',
245
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
246
                            default=EngineArgs.download_dir,
247
                            help='Directory to download and load the weights, '
248
                            'default to the default cache dir of '
249
                            'huggingface.')
250
251
252
253
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
254
            choices=[f.value for f in LoadFormat],
255
256
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
257
            'and fall back to the pytorch bin format if safetensors format '
258
259
260
261
262
263
264
265
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
266
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
267
268
269
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
270
271
272
273
274
275
276
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
277
278
279
280
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
281
282
283
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
284
285
286
287
288
289
290
291
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
292
293
294
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
295
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
296
            default=EngineArgs.kv_cache_dtype,
297
            help='Data type for kv cache storage. If "auto", will use model '
298
299
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
300
301
        parser.add_argument(
            '--quantization-param-path',
302
            type=nullable_str,
303
304
305
306
307
308
309
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version'
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
310
            'supported for common inference criteria.')
311
312
        parser.add_argument('--max-model-len',
                            type=int,
313
                            default=EngineArgs.max_model_len,
314
315
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
316
317
318
319
320
321
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
322
323
324
325
326
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
327
        # Parallel arguments
328
329
330
331
332
333
334
335
336
337
338
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
339
340
341
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
342
                            default=EngineArgs.pipeline_parallel_size,
343
                            help='Number of pipeline stages.')
344
345
346
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
347
                            default=EngineArgs.tensor_parallel_size,
348
                            help='Number of tensor parallel replicas.')
349
350
351
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
352
            default=EngineArgs.max_parallel_loading_workers,
353
            help='Load model sequentially in multiple batches, '
354
            'to avoid RAM OOM when using tensor '
355
            'parallel and large models.')
356
357
358
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
359
            help='If specified, use nsight to profile Ray workers.')
360
        # KV cache arguments
361
362
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
363
                            default=EngineArgs.block_size,
364
                            choices=[8, 16, 32],
365
                            help='Token block size for contiguous chunks of '
366
367
                            'tokens. This is ignored on neuron devices and '
                            'set to max-model-len')
368
369
370

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
371
                            help='Enables automatic prefix caching.')
372
373
374
375
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
376
377
378
379
380
381
382
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
383
384
385
386
387
388
389
390
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
391

392
393
394
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
395
                            help='Random seed for operations.')
396
        parser.add_argument('--swap-space',
397
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
398
                            default=EngineArgs.swap_space,
399
                            help='CPU swap space size (GiB) per GPU.')
400
401
402
403
404
405
406
407
408
409
410
411
412
413
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight,'
            'which requires at least 26GB GPU memory. Note that this '
            'requires fast CPU-GPU interconnect, as part of the model is'
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
414
415
416
417
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
418
419
420
421
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
422
        parser.add_argument(
423
            '--num-gpu-blocks-override',
424
425
426
427
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
            'of GPU blocks. Used for testing preemption.')
428
429
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
430
                            default=EngineArgs.max_num_batched_tokens,
431
432
                            help='Maximum number of batched tokens per '
                            'iteration.')
433
434
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
435
                            default=EngineArgs.max_num_seqs,
436
                            help='Maximum number of sequences per iteration.')
437
438
439
440
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
441
442
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
443
444
        parser.add_argument('--disable-log-stats',
                            action='store_true',
445
                            help='Disable logging statistics.')
446
447
448
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
449
                            type=nullable_str,
450
                            choices=[*QUANTIZATION_METHODS, None],
451
                            default=EngineArgs.quantization,
452
453
454
455
456
457
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
458
459
460
461
462
463
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
            'For example, {"rope_type":"dynamic","factor":2.0}')
464
465
466
467
468
469
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
470
471
472
473
474
475
476
477
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
478
                            help='Maximum context length covered by CUDA '
479
                            'graphs. When a sequence has context length '
480
                            'larger than this, we fall back to eager mode. '
481
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
482
                            ')')
483
        parser.add_argument('--max-seq-len-to-capture',
484
485
486
487
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
488
489
490
491
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
492
493
494
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
495
                            help='See ParallelConfig.')
496
497
498
499
500
501
502
503
504
505
506
507
508
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
509
                            type=nullable_str,
510
511
512
513
514
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
515
516
517
518
519
520
521
522
523
524
525
526
527
528

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
529
530
531
532
533
534
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
            help=('Overrides for the multimodal input mapping/processing,'
                  'e.g., image processor. For example: {"num_crops": 4}.'))
535

536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
562
563
564
565
566
567
568
569
570
571
572
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
573
574
575
576
577
578
579
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
580
581
582
583
584
585
586
587
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
588
589
590
591
592
593
594
595
596
597
598
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
599
600
601
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
602
                            choices=DEVICE_OPTIONS,
603
                            help='Device type for vLLM execution.')
604
605
606
607
608
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
609

610
611
        parser.add_argument(
            '--multi-step-stream-outputs',
612
613
614
615
616
617
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
618
619
620
621
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
622
            help='Apply a delay (of delay factor multiplied by previous '
623
            'prompt latency) before scheduling next prompt.')
624
625
        parser.add_argument(
            '--enable-chunked-prefill',
626
627
628
629
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
630
            help='If set, the prefill requests can be chunked based on the '
631
            'max_num_batched_tokens.')
632
633
634

        parser.add_argument(
            '--speculative-model',
635
            type=nullable_str,
636
            default=EngineArgs.speculative_model,
637
638
            help=
            'The name of the draft model to be used in speculative decoding.')
639
640
641
642
643
644
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
645
            help='Method used to quantize the weights of speculative model. '
646
647
648
649
650
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
651
652
653
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
654
            default=EngineArgs.num_speculative_tokens,
655
            help='The number of speculative tokens to sample from '
656
            'the draft model in speculative decoding.')
657
658
659
660
661
662
        parser.add_argument(
            '--speculative-disable-mqa-scorer',
            action='store_true',
            help=
            'If set to True, the MQA scorer will be disabled in speculative '
            ' and fall back to batch expansion')
663
664
665
666
667
668
669
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
670

671
672
        parser.add_argument(
            '--speculative-max-model-len',
673
            type=int,
674
675
676
677
678
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

679
680
681
682
683
684
685
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

686
687
688
689
690
691
692
693
694
695
696
697
698
699
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

732
733
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
734
            action=StoreBoolean,
735
            default=EngineArgs.disable_logprobs_during_spec_decoding,
736
737
            nargs="?",
            const="True",
738
739
740
741
742
743
744
745
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

746
        parser.add_argument('--model-loader-extra-config',
747
                            type=nullable_str,
748
749
750
751
752
753
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
754
755
756
757
758
759
760
761
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
            "Default to 'original/**/*' to avoid repeated loading of llama's "
            "checkpoints.")
762
        parser.add_argument(
763
            '--preemption-mode',
764
765
            type=str,
            default=None,
766
767
768
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
769

770
771
772
773
774
775
776
777
778
779
780
781
782
783
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Noted that this name(s)"
            "will also be used in `model_name` tag content of "
            "prometheus metrics, if multiple names provided, metrics"
            "tag will take the first one.")
784
785
786
787
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
788
789
790
791
792
793

        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
794
795
796
797
798
799
800
801
802
803
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
            ". It makes sense to set this only if --otlp-traces-endpoint is"
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
804

805
806
807
808
809
810
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
811
812
        parser.add_argument(
            '--override-neuron-config',
813
            type=json.loads,
814
            default=None,
815
816
            help="Override or set neuron device configuration. "
            "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'")
817

818
819
820
821
822
823
824
825
826
827
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

828
        return parser
829
830

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of this dataclass from parsed CLI arguments.

        Each dataclass field of ``cls`` is looked up by name on ``args`` and
        passed to the constructor as a keyword argument. Attributes on
        ``args`` that do not correspond to a field are ignored.

        Args:
            args: Parsed namespace exposing one attribute per field of
                ``cls``.

        Returns:
            A new ``cls`` instance populated from ``args``.

        Raises:
            AttributeError: If ``args`` lacks an attribute for some field.
        """
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args
837

838
839
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` from the model-related fields of ``self``.

        All values are forwarded unchanged, except that
        ``use_async_output_proc`` is the negation of
        ``self.disable_async_output_proc``.

        Returns:
            The ``ModelConfig`` constructed from this object's fields.
        """
        return ModelConfig(
            model=self.model,
            # We know this is not None because we set it in __post_init__
            tokenizer=cast(str, self.tokenizer),
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            # The CLI flag is phrased negatively; the config is positive.
            use_async_output_proc=not self.disable_async_output_proc,
            override_neuron_config=self.override_neuron_config,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
        )

869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` from the loader-related fields of ``self``.

        Forwards ``load_format``, ``download_dir``,
        ``model_loader_extra_config`` and ``ignore_patterns`` unchanged.
        """
        loader_kwargs = dict(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
        )
        return LoadConfig(**loader_kwargs)

    def create_engine_config(self) -> EngineConfig:
        """Validate the arguments and assemble the full ``EngineConfig``.

        Besides building the individual sub-configs, this method updates
        ``self`` in place in a few cases:

        * ``quantization`` and ``load_format`` are both set to ``"gguf"``
          when ``check_gguf_file(self.model)`` is true.
        * ``enable_prefix_caching`` is forced to ``False`` for multimodal
          models.
        * ``enable_chunked_prefill`` is resolved from ``None`` to a concrete
          ``True``/``False``.
        * ``model_loader_extra_config`` receives the QLoRA adapter path when
          one is given.

        Returns:
            The ``EngineConfig`` aggregating all sub-configs.

        Raises:
            ValueError: If the bitsandbytes quantization / load format /
                QLoRA adapter settings are inconsistent, if multi-step
                scheduling is combined with speculative decoding or with
                chunked prefill under pipeline parallelism, or if
                ``collect_detailed_traces`` names an unknown module.
            AssertionError: If ``cpu_offload_gb`` is negative.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if ((self.quantization == "bitsandbytes"
             or self.qlora_adapter_name_or_path is not None)
                and self.load_format != "bitsandbytes"):
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if ((self.load_format == "bitsandbytes"
             or self.qlora_adapter_name_or_path is not None)
                and self.quantization != "bitsandbytes"):
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

        # Prefix caching is switched off for multimodal models; warn only
        # when the user had explicitly asked for it.
        if model_config.is_multimodal_model:
            if self.enable_prefix_caching:
                logger.warning(
                    "--enable-prefix-caching is currently not "
                    "supported for multimodal models and has been disabled.")
            self.enable_prefix_caching = False

        cache_config = CacheConfig(
            # neuron needs block_size = max_model_len
            block_size=self.block_size if self.device != "neuron" else
            (self.max_model_len if self.max_model_len is not None else 0),
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            is_attention_free=model_config.is_attention_free,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.

            # Chunked prefill is currently disabled for multimodal models by
            # default.
            if use_long_context and not model_config.is_multimodal_model:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            # Anything still undecided at this point resolves to disabled.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization = \
                self.speculative_model_quantization,
            speculative_draft_tensor_parallel_size = \
                self.speculative_draft_tensor_parallel_size,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

        if not self.use_v2_block_manager:
            logger.warning(
                "[DEPRECATED] Block manager v1 has been removed, "
                "and setting --use-v2-block-manager to True or False has "
                "no effect on vLLM behavior. Please remove "
                "--use-v2-block-manager in your engine argument. "
                "If your use case is not supported by "
                "SelfAttnBlockSpaceManager (i.e. block manager v2),"
                " please file an issue with detailed information.")

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            is_multimodal_model=model_config.is_multimodal_model,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
            multi_step_stream_outputs=self.multi_step_stream_outputs,
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
            policy=self.scheduling_policy,
        )
        # A LoRA config exists only when LoRA is enabled; a falsy or
        # non-positive max_cpu_loras is passed on as None.
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        # Hand the QLoRA adapter path to the loader via its extra config.
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

        load_config = self.create_load_config()

        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        # collect_detailed_traces is a comma-separated list of module names.
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )

        return EngineConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
        )
1100
1101


1102
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI options on ``parser``.

        Args:
            parser: Parser to extend.
            async_args_only: When False (the default), the options from
                ``EngineArgs.add_cli_args`` are registered first. When
                True, only ``--disable-log-requests`` is added.

        Returns:
            The parser with the options registered.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return parser
1116
1117


1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
class StoreBoolean(argparse.Action):
    """argparse action that stores a textual 'true'/'false' as a bool."""

    def __call__(self, parser, namespace, values, option_string=None):
        """Store ``True`` or ``False`` on ``namespace`` under ``self.dest``.

        The comparison is case-insensitive. Any value other than
        'true' or 'false' raises ``ValueError`` and leaves the namespace
        untouched.
        """
        normalized = values.lower()
        if normalized not in ("true", "false"):
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")
        setattr(namespace, self.dest, normalized == "true")


1130
1131
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a new parser with the ``EngineArgs`` CLI options registered."""
    return EngineArgs.add_cli_args(FlexibleArgumentParser())
1133
1134
1135


def _async_engine_args_parser():
    """Return a new parser holding only the async-specific CLI options.

    ``async_args_only=True`` skips the ``EngineArgs`` options, so the
    resulting parser contains just what ``AsyncEngineArgs`` adds.
    """
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                        async_args_only=True)