arg_utils.py 52.6 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
6
                    Tuple, Type, Union, cast, get_args)
7

8
9
import torch

10
import vllm.envs as envs
11
12
13
14
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
                         DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
                         LoRAConfig, ModelConfig, ObservabilityConfig,
                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
15
                         SpeculativeConfig, TaskOption, TokenizerPoolConfig)
16
from vllm.executor.executor_base import ExecutorBase
17
from vllm.logger import init_logger
18
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
19
from vllm.transformers_utils.utils import check_gguf_file
20
from vllm.utils import FlexibleArgumentParser
21

22
if TYPE_CHECKING:
23
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
24

25
26
logger = init_logger(__name__)

27
28
# Names that may be requested for detailed tracing.
# NOTE(review): presumably used to validate the `collect_detailed_traces`
# engine argument; the consumer is not visible in this part of the file —
# confirm before relying on this description.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

29
30
31
32
33
34
35
36
37
38
# Accepted values for the `--device` command-line option; this list is passed
# as `choices` when the argument is registered in `EngineArgs.add_cli_args`.
DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "openvino",
    "tpu",
    "xpu",
]

39

40
41
42
43
44
45
def nullable_str(val: str):
    """Map an empty value or the literal string ``"None"`` to ``None``.

    Used in this file as the ``type=`` converter of several optional string
    CLI arguments.

    Args:
        val: Raw string value.

    Returns:
        ``val`` unchanged when it is non-empty and not equal to ``"None"``;
        otherwise ``None``.
    """
    # Guard clause: anything that is a real, non-"None" value passes through.
    if val and val != "None":
        return val
    return None


46
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse comma-separated ``KEY=VALUE`` pairs into a ``str -> int`` dict.

    Keys are lower-cased, and surrounding whitespace is stripped from both
    keys and values, so ``"Image = 16, video=2"`` parses to
    ``{"image": 16, "video": 2}``.

    Args:
        val: String value to be parsed, e.g. ``"image=16,video=2"``.

    Returns:
        Dictionary with the parsed values, or ``None`` if ``val`` is empty.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, if its key is empty, if its value is not an
            integer, or if the same key is given with two different values.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        # Exactly one "=" is required, and the key must not be blank;
        # otherwise an item such as "=5" would be stored under the key "".
        if len(kv_parts) != 2 or not kv_parts[0]:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key is tolerated only when the value is identical.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


81
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
82
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
83
    """Arguments for vLLM engine."""
84
    model: str = 'facebook/opt-125m'
85
    served_model_name: Optional[Union[str, List[str]]] = None
86
    tokenizer: Optional[str] = None
87
    task: TaskOption = "auto"
88
    skip_tokenizer_init: bool = False
89
    tokenizer_mode: str = 'auto'
90
    trust_remote_code: bool = False
91
    download_dir: Optional[str] = None
92
    load_format: str = 'auto'
93
    config_format: ConfigFormat = ConfigFormat.AUTO
94
    dtype: str = 'auto'
95
    kv_cache_dtype: str = 'auto'
96
    quantization_param_path: Optional[str] = None
97
    seed: int = 0
98
    max_model_len: Optional[int] = None
99
    worker_use_ray: bool = False
100
101
102
103
104
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
105
106
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
107
    max_parallel_loading_workers: Optional[int] = None
108
    block_size: int = 16
109
    enable_prefix_caching: bool = False
110
    disable_sliding_window: bool = False
111
    use_v2_block_manager: bool = True
112
113
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
114
    gpu_memory_utilization: float = 0.90
115
    max_num_batched_tokens: Optional[int] = None
116
    max_num_seqs: int = 256
117
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
118
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
119
    revision: Optional[str] = None
120
    code_revision: Optional[str] = None
121
    rope_scaling: Optional[dict] = None
122
    rope_theta: Optional[float] = None
123
    tokenizer_revision: Optional[str] = None
124
    quantization: Optional[str] = None
125
    enforce_eager: Optional[bool] = None
126
127
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
128
    disable_custom_all_reduce: bool = False
129
    tokenizer_pool_size: int = 0
130
131
132
133
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
134
    tokenizer_pool_extra_config: Optional[dict] = None
135
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
136
137
138
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
139
140
141
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
142
    fully_sharded_loras: bool = False
143
    lora_extra_vocab_size: int = 256
144
    long_lora_scaling_factors: Optional[Tuple[float]] = None
145
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
146
    max_cpu_loras: Optional[int] = None
147
    device: str = 'auto'
148
    num_scheduler_steps: int = 1
149
    multi_step_stream_outputs: bool = True
150
    ray_workers_use_nsight: bool = False
151
    num_gpu_blocks_override: Optional[int] = None
152
    num_lookahead_slots: int = 0
153
    model_loader_extra_config: Optional[dict] = None
154
    ignore_patterns: Optional[Union[str, List[str]]] = None
155
    preemption_mode: Optional[str] = None
156

157
    scheduler_delay_factor: float = 0.0
158
    enable_chunked_prefill: Optional[bool] = None
159

160
    guided_decoding_backend: str = 'outlines'
161
162
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
163
    speculative_model_quantization: Optional[str] = None
164
    speculative_draft_tensor_parallel_size: Optional[int] = None
165
    num_speculative_tokens: Optional[int] = None
166
    speculative_disable_mqa_scorer: Optional[bool] = False
167
    speculative_max_model_len: Optional[int] = None
168
    speculative_disable_by_batch_size: Optional[int] = None
169
170
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
171
172
173
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
174
    qlora_adapter_name_or_path: Optional[str] = None
175
    disable_logprobs_during_spec_decoding: Optional[bool] = None
176

177
    otlp_traces_endpoint: Optional[str] = None
178
    collect_detailed_traces: Optional[str] = None
179
    disable_async_output_proc: bool = False
180
    override_neuron_config: Optional[Dict[str, Any]] = None
181
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
182
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
183

184
    def __post_init__(self):
        """Finish initialisation after the dataclass fields are assigned.

        Falls back to the model name/path when no tokenizer was given, then
        loads vLLM's general plugins.
        """
        # An unset (or empty) tokenizer defaults to the model itself.
        self.tokenizer = self.tokenizer or self.model

        # Setup plugins. The import is kept local to this method, as in the
        # original code.
        from vllm.plugins import load_general_plugins as _load_plugins
        _load_plugins()
191
192

    @staticmethod
193
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
194
        """Shared CLI arguments for vLLM engine."""
195

196
        # Model arguments
197
198
199
        parser.add_argument(
            '--model',
            type=str,
200
            default=EngineArgs.model,
201
            help='Name or path of the huggingface model to use.')
202
203
204
205
206
207
208
209
210
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
            'multiple tasks. When the model only supports one task, "auto" '
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
211
212
        parser.add_argument(
            '--tokenizer',
213
            type=nullable_str,
214
            default=EngineArgs.tokenizer,
215
216
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
217
218
219
220
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
Jasmond L's avatar
Jasmond L committed
221
222
        parser.add_argument(
            '--revision',
223
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
224
            default=None,
225
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
226
227
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
228
229
        parser.add_argument(
            '--code-revision',
230
            type=nullable_str,
231
            default=None,
232
            help='The specific revision to use for the model code on '
233
234
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
235
236
        parser.add_argument(
            '--tokenizer-revision',
237
            type=nullable_str,
238
            default=None,
239
240
241
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
242
243
244
245
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
246
            choices=['auto', 'slow', 'mistral'],
247
248
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
249
250
            'always use the slow tokenizer. \n* '
            '"mistral" will always use the `mistral_common` tokenizer.')
251
252
        parser.add_argument('--trust-remote-code',
                            action='store_true',
253
                            help='Trust remote code from huggingface.')
254
        parser.add_argument('--download-dir',
255
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
256
                            default=EngineArgs.download_dir,
257
                            help='Directory to download and load the weights, '
258
                            'default to the default cache dir of '
259
                            'huggingface.')
260
261
262
263
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
264
            choices=[f.value for f in LoadFormat],
265
266
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
267
            'and fall back to the pytorch bin format if safetensors format '
268
269
270
271
272
273
274
275
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
276
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
277
278
279
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
280
281
282
283
284
285
286
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
287
288
289
290
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
291
292
293
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
294
295
296
297
298
299
300
301
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
302
303
304
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
305
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
306
            default=EngineArgs.kv_cache_dtype,
307
            help='Data type for kv cache storage. If "auto", will use model '
308
309
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
310
311
        parser.add_argument(
            '--quantization-param-path',
312
            type=nullable_str,
313
314
315
316
317
318
319
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda version'
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
320
            'supported for common inference criteria.')
321
322
        parser.add_argument('--max-model-len',
                            type=int,
323
                            default=EngineArgs.max_model_len,
324
325
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
326
327
328
329
330
331
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
332
333
334
335
336
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
337
        # Parallel arguments
338
339
340
341
342
343
344
345
346
347
348
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
349
350
351
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
352
                            default=EngineArgs.pipeline_parallel_size,
353
                            help='Number of pipeline stages.')
354
355
356
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
357
                            default=EngineArgs.tensor_parallel_size,
358
                            help='Number of tensor parallel replicas.')
359
360
361
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
362
            default=EngineArgs.max_parallel_loading_workers,
363
            help='Load model sequentially in multiple batches, '
364
            'to avoid RAM OOM when using tensor '
365
            'parallel and large models.')
366
367
368
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
369
            help='If specified, use nsight to profile Ray workers.')
370
        # KV cache arguments
371
372
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
373
                            default=EngineArgs.block_size,
374
                            choices=[8, 16, 32],
375
                            help='Token block size for contiguous chunks of '
376
377
                            'tokens. This is ignored on neuron devices and '
                            'set to max-model-len')
378
379
380

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
381
                            help='Enables automatic prefix caching.')
382
383
384
385
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
386
387
388
389
390
391
392
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
393
394
395
396
397
398
399
400
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
401

402
403
404
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
405
                            help='Random seed for operations.')
406
        parser.add_argument('--swap-space',
407
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
408
                            default=EngineArgs.swap_space,
409
                            help='CPU swap space size (GiB) per GPU.')
410
411
412
413
414
415
416
417
418
419
420
421
422
423
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight,'
            'which requires at least 26GB GPU memory. Note that this '
            'requires fast CPU-GPU interconnect, as part of the model is'
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
424
425
426
427
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
428
429
430
431
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
            'will use the default value of 0.9.')
432
        parser.add_argument(
433
            '--num-gpu-blocks-override',
434
435
436
437
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
            'of GPU blocks. Used for testing preemption.')
438
439
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
440
                            default=EngineArgs.max_num_batched_tokens,
441
442
                            help='Maximum number of batched tokens per '
                            'iteration.')
443
444
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
445
                            default=EngineArgs.max_num_seqs,
446
                            help='Maximum number of sequences per iteration.')
447
448
449
450
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
451
452
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
453
454
        parser.add_argument('--disable-log-stats',
                            action='store_true',
455
                            help='Disable logging statistics.')
456
457
458
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
459
                            type=nullable_str,
460
                            choices=[*QUANTIZATION_METHODS, None],
461
                            default=EngineArgs.quantization,
462
463
464
465
466
467
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
468
469
470
471
472
473
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
            'For example, {"rope_type":"dynamic","factor":2.0}')
474
475
476
477
478
479
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
480
481
482
483
484
485
486
487
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
488
                            help='Maximum context length covered by CUDA '
489
                            'graphs. When a sequence has context length '
490
                            'larger than this, we fall back to eager mode. '
491
                            '(DEPRECATED. Use --max-seq-len-to-capture instead'
492
                            ')')
493
        parser.add_argument('--max-seq-len-to-capture',
494
495
496
497
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
498
499
500
501
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
502
503
504
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
505
                            help='See ParallelConfig.')
506
507
508
509
510
511
512
513
514
515
516
517
518
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
519
                            type=nullable_str,
520
521
522
523
524
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
525
526
527
528
529
530
531
532
533
534
535
536
537
538

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
539
540
541
542
543
544
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
            help=('Overrides for the multimodal input mapping/processing,'
                  'e.g., image processor. For example: {"num_crops": 4}.'))
545

546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
572
573
574
575
576
577
578
579
580
581
582
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
583
584
585
586
587
588
589
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= than max_num_seqs. '
                  'Defaults to max_num_seqs.'))
590
591
592
593
594
595
596
597
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
598
599
600
601
602
603
604
605
606
607
608
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
609
610
611
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
612
                            choices=DEVICE_OPTIONS,
613
                            help='Device type for vLLM execution.')
614
615
616
617
618
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
619

620
621
        parser.add_argument(
            '--multi-step-stream-outputs',
622
623
624
625
626
627
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
628
629
630
631
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
632
            help='Apply a delay (of delay factor multiplied by previous '
633
            'prompt latency) before scheduling next prompt.')
634
635
        parser.add_argument(
            '--enable-chunked-prefill',
636
637
638
639
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
640
            help='If set, the prefill requests can be chunked based on the '
641
            'max_num_batched_tokens.')
642
643
644

        parser.add_argument(
            '--speculative-model',
645
            type=nullable_str,
646
            default=EngineArgs.speculative_model,
647
648
            help=
            'The name of the draft model to be used in speculative decoding.')
649
650
651
652
653
654
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
655
            help='Method used to quantize the weights of speculative model. '
656
657
658
659
660
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
661
662
663
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
664
            default=EngineArgs.num_speculative_tokens,
665
            help='The number of speculative tokens to sample from '
666
            'the draft model in speculative decoding.')
667
668
669
670
671
672
        parser.add_argument(
            '--speculative-disable-mqa-scorer',
            action='store_true',
            help=
            'If set to True, the MQA scorer will be disabled in speculative '
            ' and fall back to batch expansion')
673
674
675
676
677
678
679
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
680

681
682
        parser.add_argument(
            '--speculative-max-model-len',
683
            type=int,
684
685
686
687
688
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

689
690
691
692
693
694
695
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

696
697
698
699
700
701
702
703
704
705
706
707
708
709
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

742
743
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
744
            action=StoreBoolean,
745
            default=EngineArgs.disable_logprobs_during_spec_decoding,
746
747
            nargs="?",
            const="True",
748
749
750
751
752
753
754
755
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

756
        parser.add_argument('--model-loader-extra-config',
757
                            type=nullable_str,
758
759
760
761
762
763
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
764
765
766
767
768
769
770
771
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
            "Default to 'original/**/*' to avoid repeated loading of llama's "
            "checkpoints.")
772
        parser.add_argument(
773
            '--preemption-mode',
774
775
            type=str,
            default=None,
776
777
778
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
779

780
781
782
783
784
785
786
787
788
789
790
791
792
793
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Noted that this name(s)"
            "will also be used in `model_name` tag content of "
            "prometheus metrics, if multiple names provided, metrics"
            "tag will take the first one.")
794
795
796
797
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
798
799
800
801
802
803

        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
804
805
806
807
808
809
810
811
812
813
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
            ". It makes sense to set this only if --otlp-traces-endpoint is"
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
814

815
816
817
818
819
820
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
821
822
        parser.add_argument(
            '--override-neuron-config',
823
            type=json.loads,
824
            default=None,
825
826
            help="Override or set neuron device configuration. "
            "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'")
827

828
829
830
831
832
833
834
835
836
837
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

838
        return parser
839
840

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Create an instance of this dataclass from a parsed CLI namespace.

        Each dataclass field is looked up by name on ``args``; attributes of
        ``args`` that do not correspond to a field are ignored.
        """
        field_values = {}
        for spec in dataclasses.fields(cls):
            # Every declared field must be present on the namespace.
            field_values[spec.name] = getattr(args, spec.name)
        return cls(**field_values)
847

848
849
    def create_model_config(self) -> ModelConfig:
        """Build the ModelConfig from the model-related arguments."""
        # The tokenizer is populated in __post_init__, so it cannot be None
        # by the time this method runs; the cast only informs type checkers.
        tokenizer_name = cast(str, self.tokenizer)
        # The argument is a "disable" switch while the config expects the
        # positive form.
        async_output_enabled = not self.disable_async_output_proc

        model_config = ModelConfig(
            model=self.model,
            task=self.task,
            tokenizer=tokenizer_name,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            use_async_output_proc=async_output_enabled,
            override_neuron_config=self.override_neuron_config,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
        )
        return model_config

880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
    def create_load_config(self) -> LoadConfig:
        """Bundle the loader-related arguments into a LoadConfig."""
        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
        )
        return load_config

    def create_engine_config(self) -> EngineConfig:
        """Validate the arguments and assemble the full EngineConfig.

        Builds the device, model, cache, parallel, speculative, scheduler,
        LoRA, load, prompt-adapter, decoding and observability configs and
        wraps them in a single ``EngineConfig``.

        Note that this method mutates ``self`` in several places:
        ``quantization``/``load_format`` are forced to ``"gguf"`` for GGUF
        models, ``enable_prefix_caching`` is turned off for multimodal
        models, ``enable_chunked_prefill`` is resolved from ``None`` to a
        concrete bool, and ``model_loader_extra_config`` receives the QLoRA
        adapter path when one is given.

        Returns:
            The assembled ``EngineConfig``.

        Raises:
            ValueError: If bitsandbytes quantization / load format / QLoRA
                adapter settings are inconsistent, if multi-step scheduling
                is combined with speculative decoding or with chunked
                prefill under pipeline parallelism, if
                ``collect_detailed_traces`` names an unknown module, or if
                ``model_loader_extra_config`` is a string that is not valid
                JSON while a QLoRA adapter is set.
            AssertionError: If ``cpu_offload_gb`` is negative.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
           self.qlora_adapter_name_or_path is not None) and \
           self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes" or
            self.qlora_adapter_name_or_path is not None) and \
            self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

        # Prefix caching is force-disabled for multimodal models; only warn
        # when the user had explicitly asked for it.
        if model_config.is_multimodal_model:
            if self.enable_prefix_caching:
                logger.warning(
                    "--enable-prefix-caching is currently not "
                    "supported for multimodal models and has been disabled.")
            self.enable_prefix_caching = False

        cache_config = CacheConfig(
            # neuron needs block_size = max_model_len
            block_size=self.block_size if self.device != "neuron" else
            (self.max_model_len if self.max_model_len is not None else 0),
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            is_attention_free=model_config.is_attention_free,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.

            # Chunked prefill is currently disabled for multimodal models by
            # default.
            if use_long_context and not model_config.is_multimodal_model:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            # Anything that did not qualify above resolves to "disabled".
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization=(
                self.speculative_model_quantization),
            speculative_draft_tensor_parallel_size=(
                self.speculative_draft_tensor_parallel_size),
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size=(
                self.speculative_disable_by_batch_size),
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=(
                self.spec_decoding_acceptance_method),
            typical_acceptance_sampler_posterior_threshold=(
                self.typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                self.typical_acceptance_sampler_posterior_alpha),
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

        if not self.use_v2_block_manager:
            logger.warning(
                "[DEPRECATED] Block manager v1 has been removed, "
                "and setting --use-v2-block-manager to True or False has "
                "no effect on vLLM behavior. Please remove "
                "--use-v2-block-manager in your engine argument. "
                "If your use case is not supported by "
                "SelfAttnBlockSpaceManager (i.e. block manager v2),"
                " please file an issue with detailed information.")

        scheduler_config = SchedulerConfig(
            task=model_config.task,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            is_multimodal_model=model_config.is_multimodal_model,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
            multi_step_stream_outputs=self.multi_step_stream_outputs,
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
            policy=self.scheduling_policy,
        )
        # A LoRA config is only created when LoRA is enabled; a missing or
        # non-positive max_cpu_loras is passed through as None.
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            elif isinstance(self.model_loader_extra_config, str):
                # --model-loader-extra-config is parsed as a plain string on
                # the command line, so decode the JSON here; assigning a key
                # on the raw string would raise TypeError.
                self.model_loader_extra_config = json.loads(
                    self.model_loader_extra_config)
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

        load_config = self.create_load_config()

        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        # collect_detailed_traces is a comma-separated list of module names;
        # each entry must be one of ALLOWED_DETAILED_TRACE_MODULES.
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )

        return EngineConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
        )
1111
1112


1113
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async engine CLI options on ``parser``.

        Args:
            parser: Parser to add the options to.
            async_args_only: When True, skip the ``EngineArgs`` options and
                add only the flags declared by this class.

        Returns:
            The parser carrying the registered options.
        """
        if async_args_only:
            target = parser
        else:
            # Let the base class register its options first.
            target = EngineArgs.add_cli_args(parser)
        target.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return target
1127
1128


1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
class StoreBoolean(argparse.Action):
    """argparse action that stores 'true'/'false' (any casing) as a bool."""

    def __call__(self, parser, namespace, values, option_string=None):
        """Set ``self.dest`` on ``namespace`` to the bool named by ``values``.

        Raises:
            ValueError: If ``values`` is neither 'true' nor 'false',
                compared case-insensitively.
        """
        literals = {"true": True, "false": False}
        key = values.lower()
        if key not in literals:
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")
        setattr(namespace, self.dest, literals[key])


1141
1142
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a new parser populated with the EngineArgs CLI options."""
    parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(parser)
1144
1145
1146


def _async_engine_args_parser():
    """Return a new parser holding only the async-specific CLI options."""
    parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(parser, async_args_only=True)