# arg_utils.py (58 KB)
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
6
                    Tuple, Type, Union, cast, get_args)
7

8
9
import torch

10
import vllm.envs as envs
11
12
13
14
15
16
17
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
                         DecodingConfig, DeviceConfig, HfOverrides, LoadConfig,
                         LoadFormat, LoRAConfig, ModelConfig,
                         ObservabilityConfig, ParallelConfig, PoolerConfig,
                         PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
                         VllmConfig)
18
from vllm.executor.executor_base import ExecutorBase
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
21
from vllm.platforms import current_platform
22
from vllm.transformers_utils.utils import check_gguf_file
23
from vllm.usage.usage_lib import UsageContext
24
from vllm.utils import FlexibleArgumentParser, StoreBoolean
25

26
if TYPE_CHECKING:
27
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
28

29
30
# Module-scoped logger created through vLLM's ``init_logger`` helper.
logger = init_logger(__name__)

31
32
# Module names accepted when collecting detailed traces (presumably validated
# against the ``collect_detailed_traces`` engine arg; the validation site is
# not visible here -- confirm).
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# Device backends offered as ``choices`` for the ``--device`` CLI flag.
DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "openvino",
    "tpu",
    "xpu",
    "hpu",
]

44

45
46
47
48
49
50
def nullable_str(val: str):
    """Map an empty string or the literal ``"None"`` to ``None``.

    Intended as an argparse ``type=`` converter so optional string flags
    (e.g. ``--tokenizer``, ``--revision``) can be explicitly cleared by
    passing ``None`` on the command line. Any other value is returned
    unchanged.
    """
    treat_as_null = (not val) or (val == "None")
    return None if treat_as_null else val


51
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dict.

    Keys are lower-cased, and whitespace around keys and values is stripped.
    Values must parse as integers. This function is used as an argparse
    ``type=`` converter (e.g. for ``--limit-mm-per-prompt``), so malformed
    input is reported as :class:`argparse.ArgumentTypeError`.

    Args:
        val: String value to be parsed, e.g. ``"image=16,video=2"``.

    Returns:
        Dictionary mapping each key to its integer value, or ``None`` if
        ``val`` is empty.

    Raises:
        argparse.ArgumentTypeError: If an item is not exactly ``KEY=VALUE``,
            has an empty key, has a non-integer value, or repeats a key with
            a different value.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts
        # An empty key (e.g. "=3") cannot name anything meaningful; reject it
        # instead of silently storing it under "".
        if not key:
            raise argparse.ArgumentTypeError(
                f"Empty key in item {item!r}; expected KEY=VALUE")

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key is tolerated only when it maps to the same value.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


86
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
87
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
88
    """Arguments for vLLM engine."""
89
    model: str = 'facebook/opt-125m'
90
    served_model_name: Optional[Union[str, List[str]]] = None
91
    tokenizer: Optional[str] = None
92
    task: TaskOption = "auto"
93
    skip_tokenizer_init: bool = False
94
    tokenizer_mode: str = 'auto'
95
    trust_remote_code: bool = False
96
    allowed_local_media_path: str = ""
97
    download_dir: Optional[str] = None
98
    load_format: str = 'auto'
99
    config_format: ConfigFormat = ConfigFormat.AUTO
100
    dtype: str = 'auto'
101
    kv_cache_dtype: str = 'auto'
102
    quantization_param_path: Optional[str] = None
103
    seed: int = 0
104
    max_model_len: Optional[int] = None
105
    worker_use_ray: bool = False
106
107
108
109
110
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
111
112
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
113
    max_parallel_loading_workers: Optional[int] = None
114
115
116
    # NOTE(kzawora): default block size for Gaudi should be 128
    # smaller sizes still work, but very inefficiently
    block_size: int = 16 if not current_platform.is_hpu() else 128
117
    enable_prefix_caching: Optional[bool] = None
118
    disable_sliding_window: bool = False
119
    use_v2_block_manager: bool = True
120
121
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
122
    gpu_memory_utilization: float = 0.90
123
    max_num_batched_tokens: Optional[int] = None
124
    max_num_seqs: int = 256
125
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
126
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
127
    revision: Optional[str] = None
128
    code_revision: Optional[str] = None
129
    rope_scaling: Optional[Dict[str, Any]] = None
130
    rope_theta: Optional[float] = None
131
    hf_overrides: Optional[HfOverrides] = None
132
    tokenizer_revision: Optional[str] = None
133
    quantization: Optional[str] = None
134
    enforce_eager: Optional[bool] = None
135
    max_seq_len_to_capture: int = 8192
136
    disable_custom_all_reduce: bool = False
137
    tokenizer_pool_size: int = 0
138
139
140
141
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
142
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
143
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
144
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
145
    enable_lora: bool = False
146
    enable_lora_bias: bool = False
147
148
    max_loras: int = 1
    max_lora_rank: int = 16
149
150
151
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
152
    fully_sharded_loras: bool = False
153
    lora_extra_vocab_size: int = 256
154
    long_lora_scaling_factors: Optional[Tuple[float]] = None
155
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
156
    max_cpu_loras: Optional[int] = None
157
    device: str = 'auto'
158
    num_scheduler_steps: int = 1
159
    multi_step_stream_outputs: bool = True
160
    ray_workers_use_nsight: bool = False
161
    num_gpu_blocks_override: Optional[int] = None
162
    num_lookahead_slots: int = 0
163
    model_loader_extra_config: Optional[dict] = None
164
    ignore_patterns: Optional[Union[str, List[str]]] = None
165
    preemption_mode: Optional[str] = None
166

167
    scheduler_delay_factor: float = 0.0
168
    enable_chunked_prefill: Optional[bool] = None
169

170
    guided_decoding_backend: str = 'outlines'
171
172
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
173
    speculative_model_quantization: Optional[str] = None
174
    speculative_draft_tensor_parallel_size: Optional[int] = None
175
    num_speculative_tokens: Optional[int] = None
176
    speculative_disable_mqa_scorer: Optional[bool] = False
177
    speculative_max_model_len: Optional[int] = None
178
    speculative_disable_by_batch_size: Optional[int] = None
179
180
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
181
182
183
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
184
    qlora_adapter_name_or_path: Optional[str] = None
185
    disable_logprobs_during_spec_decoding: Optional[bool] = None
186

187
    otlp_traces_endpoint: Optional[str] = None
188
    collect_detailed_traces: Optional[str] = None
189
    disable_async_output_proc: bool = False
190
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
191

192
193
    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
194
    compilation_config: Optional[CompilationConfig] = None
195
    worker_cls: str = "auto"
196

197
    def __post_init__(self):
        """Normalize user-supplied fields after dataclass initialization.

        - Falls back to ``model`` when no ``tokenizer`` is given (or it is
          falsy).
        - Resolves ``enable_prefix_caching`` from ``envs.VLLM_USE_V1`` when
          the user left it unset (``None``).
        - Converts an ``int`` or ``dict`` ``compilation_config`` into a
          ``CompilationConfig`` via ``CompilationConfig.from_cli``.
        - Loads general vLLM plugins.
        """
        if not self.tokenizer:
            self.tokenizer = self.model

        # Override the default value of enable_prefix_caching if it's not set
        # by user.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = bool(envs.VLLM_USE_V1)

        # Support `EngineArgs(compilation_config={...})` (or an int) without
        # having to manually construct a CompilationConfig object: the value
        # is round-tripped through JSON into the CLI parser.
        if isinstance(self.compilation_config, (int, dict)):
            self.compilation_config = CompilationConfig.from_cli(
                json.dumps(self.compilation_config))

        # Setup plugins. Imported locally (presumably to avoid an import
        # cycle at module load -- confirm).
        from vllm.plugins import load_general_plugins
        load_general_plugins()
216
217

    @staticmethod
218
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
219
        """Shared CLI arguments for vLLM engine."""
220

221
        # Model arguments
222
223
224
        parser.add_argument(
            '--model',
            type=str,
225
            default=EngineArgs.model,
226
            help='Name or path of the huggingface model to use.')
227
228
229
230
231
232
233
234
235
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
            'multiple tasks. When the model only supports one task, "auto" '
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
236
237
        parser.add_argument(
            '--tokenizer',
238
            type=nullable_str,
239
            default=EngineArgs.tokenizer,
240
241
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
242
243
244
245
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
Jasmond L's avatar
Jasmond L committed
246
247
        parser.add_argument(
            '--revision',
248
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
249
            default=None,
250
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
251
252
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
253
254
        parser.add_argument(
            '--code-revision',
255
            type=nullable_str,
256
            default=None,
257
            help='The specific revision to use for the model code on '
258
259
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
260
261
        parser.add_argument(
            '--tokenizer-revision',
262
            type=nullable_str,
263
            default=None,
264
265
266
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
267
268
269
270
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
271
            choices=['auto', 'slow', 'mistral'],
272
273
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
274
275
            'always use the slow tokenizer. \n* '
            '"mistral" will always use the `mistral_common` tokenizer.')
276
277
        parser.add_argument('--trust-remote-code',
                            action='store_true',
278
                            help='Trust remote code from huggingface.')
279
280
281
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
282
283
284
285
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
286
        parser.add_argument('--download-dir',
287
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
288
                            default=EngineArgs.download_dir,
289
                            help='Directory to download and load the weights, '
290
                            'default to the default cache dir of '
291
                            'huggingface.')
292
293
294
295
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
296
            choices=[f.value for f in LoadFormat],
297
298
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
299
            'and fall back to the pytorch bin format if safetensors format '
300
301
302
303
304
305
306
307
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
308
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
309
310
311
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
312
313
314
315
316
317
318
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
319
320
321
322
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
323
324
325
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
326
327
328
329
330
331
332
333
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
334
335
336
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
337
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
338
            default=EngineArgs.kv_cache_dtype,
339
            help='Data type for kv cache storage. If "auto", will use model '
340
341
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
342
343
        parser.add_argument(
            '--quantization-param-path',
344
            type=nullable_str,
345
346
347
348
349
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
350
            'FP8_E5M2 (without scaling) is only supported on cuda version '
351
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
352
            'supported for common inference criteria.')
353
354
        parser.add_argument('--max-model-len',
                            type=int,
355
                            default=EngineArgs.max_model_len,
356
357
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
358
359
360
361
362
363
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
364
365
366
367
368
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
369
        # Parallel arguments
370
371
372
373
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
374
375
376
377
378
379
380
381
            help='Backend to use for distributed model '
            'workers, either "ray" or "mp" (multiprocessing). If the product '
            'of pipeline_parallel_size and tensor_parallel_size is less than '
            'or equal to the number of GPUs available, "mp" will be used to '
            'keep processing on a single host. Otherwise, this will default '
            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
            'and hpu only support Ray for distributed inference.')

382
383
384
385
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
386
387
388
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
389
                            default=EngineArgs.pipeline_parallel_size,
390
                            help='Number of pipeline stages.')
391
392
393
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
394
                            default=EngineArgs.tensor_parallel_size,
395
                            help='Number of tensor parallel replicas.')
396
397
398
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
399
            default=EngineArgs.max_parallel_loading_workers,
400
            help='Load model sequentially in multiple batches, '
401
            'to avoid RAM OOM when using tensor '
402
            'parallel and large models.')
403
404
405
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
406
            help='If specified, use nsight to profile Ray workers.')
407
        # KV cache arguments
408
409
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
410
                            default=EngineArgs.block_size,
411
                            choices=[8, 16, 32, 64, 128],
412
                            help='Token block size for contiguous chunks of '
413
414
                            'tokens. This is ignored on neuron devices and '
                            'set to max-model-len')
415
416
417

        parser.add_argument('--enable-prefix-caching',
                            action='store_true',
418
                            help='Enables automatic prefix caching.')
419
420
421
422
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
423
424
425
426
427
428
429
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
430
431
432
433
434
435
436
437
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
438

439
440
441
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
442
                            help='Random seed for operations.')
443
        parser.add_argument('--swap-space',
444
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
445
                            default=EngineArgs.swap_space,
446
                            help='CPU swap space size (GiB) per GPU.')
447
448
449
450
451
452
453
454
455
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
456
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
457
            'which requires at least 26GB GPU memory. Note that this '
458
            'requires fast CPU-GPU interconnect, as part of the model is '
459
460
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
461
462
463
464
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
465
466
467
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
468
469
470
471
472
            'will use the default value of 0.9. This is a global gpu memory '
            'utilization limit, for example if 50%% of the gpu memory is '
            'already used before vLLM starts and --gpu-memory-utilization is '
            'set to 0.9, then only 40%% of the gpu memory will be allocated '
            'to the model executor.')
473
        parser.add_argument(
474
            '--num-gpu-blocks-override',
475
476
477
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
478
            ' of GPU blocks. Used for testing preemption.')
479
480
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
481
                            default=EngineArgs.max_num_batched_tokens,
482
483
                            help='Maximum number of batched tokens per '
                            'iteration.')
484
485
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
486
                            default=EngineArgs.max_num_seqs,
487
                            help='Maximum number of sequences per iteration.')
488
489
490
491
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
492
493
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
494
495
        parser.add_argument('--disable-log-stats',
                            action='store_true',
496
                            help='Disable logging statistics.')
497
498
499
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
500
                            type=nullable_str,
501
                            choices=[*QUANTIZATION_METHODS, None],
502
                            default=EngineArgs.quantization,
503
504
505
506
507
508
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
509
510
511
512
513
514
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
            'For example, {"rope_type":"dynamic","factor":2.0}')
515
516
517
518
519
520
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
521
522
523
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
524
                            help='Extra arguments for the HuggingFace config. '
525
526
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
527
528
529
530
531
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
532
        parser.add_argument('--max-seq-len-to-capture',
533
534
535
536
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
537
538
539
540
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
541
542
543
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
544
                            help='See ParallelConfig.')
545
546
547
548
549
550
551
552
553
554
555
556
557
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
558
                            type=nullable_str,
559
560
561
562
563
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
564
565
566
567
568
569
570
571
572
573
574
575
576
577

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
578
579
580
581
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
582
            help=('Overrides for the multimodal input mapping/processing, '
583
                  'e.g., image processor. For example: {"num_crops": 4}.'))
584

585
586
587
588
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
589
590
591
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
611
            choices=['auto', 'float16', 'bfloat16'],
612
613
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
614
615
616
617
618
619
620
621
622
623
624
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
625
626
627
628
629
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
630
631
                  'Must be >= than max_loras. '
                  'Defaults to max_loras.'))
632
633
634
635
636
637
638
639
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
640
641
642
643
644
645
646
647
648
649
650
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
651
652
653
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
654
                            choices=DEVICE_OPTIONS,
655
                            help='Device type for vLLM execution.')
656
657
658
659
660
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
661

662
663
        parser.add_argument(
            '--multi-step-stream-outputs',
664
665
666
667
668
669
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
670
671
672
673
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
674
            help='Apply a delay (of delay factor multiplied by previous '
675
            'prompt latency) before scheduling next prompt.')
676
677
        parser.add_argument(
            '--enable-chunked-prefill',
678
679
680
681
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
682
            help='If set, the prefill requests can be chunked based on the '
683
            'max_num_batched_tokens.')
684
685
686

        parser.add_argument(
            '--speculative-model',
687
            type=nullable_str,
688
            default=EngineArgs.speculative_model,
689
690
            help=
            'The name of the draft model to be used in speculative decoding.')
691
692
693
694
695
696
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
697
            help='Method used to quantize the weights of speculative model. '
698
699
700
701
702
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
703
704
705
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
706
            default=EngineArgs.num_speculative_tokens,
707
            help='The number of speculative tokens to sample from '
708
            'the draft model in speculative decoding.')
709
710
711
712
713
714
        parser.add_argument(
            '--speculative-disable-mqa-scorer',
            action='store_true',
            help=
            'If set to True, the MQA scorer will be disabled in speculative '
            ' and fall back to batch expansion')
715
716
717
718
719
720
721
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
722

723
724
        parser.add_argument(
            '--speculative-max-model-len',
725
            type=int,
726
727
728
729
730
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

731
732
733
734
735
736
737
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

738
739
740
741
742
743
744
745
746
747
748
749
750
751
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

784
785
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
786
            action=StoreBoolean,
787
            default=EngineArgs.disable_logprobs_during_spec_decoding,
788
789
            nargs="?",
            const="True",
790
791
792
793
794
795
796
797
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

798
        parser.add_argument('--model-loader-extra-config',
799
                            type=nullable_str,
800
801
802
803
804
805
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
806
807
808
809
810
811
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
812
            "Default to `original/**/*` to avoid repeated loading of llama's "
813
            "checkpoints.")
814
        parser.add_argument(
815
            '--preemption-mode',
816
817
            type=str,
            default=None,
818
819
820
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
821

822
823
824
825
826
827
828
829
830
831
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
832
            "same as the `--model` argument. Noted that this name(s) "
833
            "will also be used in `model_name` tag content of "
834
            "prometheus metrics, if multiple names provided, metrics "
835
            "tag will take the first one.")
836
837
838
839
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
840
841
842
843
844
845

        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
846
847
848
849
850
851
852
853
854
855
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
            ". It makes sense to set this only if --otlp-traces-endpoint is"
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
856

857
858
859
860
861
862
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
863

864
865
866
867
868
869
870
871
872
873
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

874
        parser.add_argument(
875
876
            '--override-neuron-config',
            type=json.loads,
877
            default=None,
878
879
            help="Override or set neuron device configuration. "
            "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'")
880
        parser.add_argument(
881
882
            '--override-pooler-config',
            type=PoolerConfig.from_json,
883
            default=None,
884
885
            help="Override or set the pooling method in the embedding model. "
            "e.g. {\"pooling_type\": \"mean\", \"normalize\": false}.'")
886

887
888
889
890
891
892
893
894
895
896
897
898
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
899
900
901
902
                            'use a JSON string.\n'
                            'Following the convention of traditional '
                            'compilers, using -O without space is also '
                            'supported. -O3 is equivalent to -O 3.')
903

904
905
906
907
908
909
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')

910
        return parser
911
912

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from a parsed argparse namespace.

        Only the dataclass fields declared on ``cls`` are read from ``args``;
        any additional attributes on the namespace are ignored. A field that
        is missing from ``args`` raises ``AttributeError`` (from ``getattr``).
        """
        field_names = (field.name for field in dataclasses.fields(cls))
        init_kwargs = {name: getattr(args, name) for name in field_names}
        return cls(**init_kwargs)
919

920
921
    def create_model_config(self) -> ModelConfig:
        """Build the ``ModelConfig`` from the model-related engine arguments.

        Each keyword is copied from the attribute of the same name on
        ``self``. The one exception is ``use_async_output_proc``, which is
        the negation of ``disable_async_output_proc``.
        """
        # ``__post_init__`` sets ``tokenizer``, so it is not None here;
        # ``cast`` only narrows the static type.
        tokenizer = cast(str, self.tokenizer)
        use_async_output_proc = not self.disable_async_output_proc
        return ModelConfig(
            # Checkpoint source and identity.
            model=self.model,
            revision=self.revision,
            code_revision=self.code_revision,
            trust_remote_code=self.trust_remote_code,
            config_format=self.config_format,
            hf_overrides=self.hf_overrides,
            task=self.task,
            served_model_name=self.served_model_name,
            # Tokenizer.
            tokenizer=tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            tokenizer_revision=self.tokenizer_revision,
            skip_tokenizer_init=self.skip_tokenizer_init,
            # Numerics and quantization.
            dtype=self.dtype,
            seed=self.seed,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            # Context length and positional scaling.
            max_model_len=self.max_model_len,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            disable_sliding_window=self.disable_sliding_window,
            # Execution behaviour.
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            use_async_output_proc=use_async_output_proc,
            # Media / multimodal inputs.
            allowed_local_media_path=self.allowed_local_media_path,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            mm_processor_kwargs=self.mm_processor_kwargs,
            # Backend-specific overrides.
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
        )

954
955
956
957
958
959
960
961
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` from the model-loading engine arguments.

        Forwards ``load_format``, ``download_dir``,
        ``model_loader_extra_config`` and ``ignore_patterns`` unchanged.
        """
        load_kwargs = dict(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
        )
        return LoadConfig(**load_kwargs)

962
963
964
965
966
967
    def create_engine_config(self,
                             usage_context: Optional[UsageContext] = None
                             ) -> VllmConfig:
        """Validate the engine args and assemble the complete ``VllmConfig``.

        This method may mutate ``self`` while resolving defaults:

        * it sets ``quantization`` and ``load_format`` to ``"gguf"`` for GGUF
          checkpoints;
        * it forces ``enable_prefix_caching`` off for multimodal models;
        * it resolves an unset (``None``) ``enable_chunked_prefill`` to a bool;
        * it injects a non-empty QLoRA adapter path into
          ``model_loader_extra_config``.

        When ``envs.VLLM_USE_V1`` is set, V1 overrides are applied to the
        args before construction and to the config afterwards.

        Args:
            usage_context: Where the engine is being created from. It is only
                forwarded to the V1 argument overrides.

        Returns:
            The assembled ``VllmConfig``.

        Raises:
            ValueError: On any of these:
                * incompatible quantization / load format / QLoRA adapter;
                * chunked prefill with an embedding model;
                * speculative decoding with multi-step scheduling;
                * multi-step chunked prefill with pipeline parallelism;
                * an unknown module in ``collect_detailed_traces``.
            AssertionError: If ``cpu_offload_gb`` is negative.
        """
        if envs.VLLM_USE_V1:
            self._override_v1_engine_args(usage_context)

        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # An empty adapter path means "no adapter". Use one flag for both the
        # bitsandbytes checks and the loader-config injection so they agree.
        has_qlora_adapter = bool(self.qlora_adapter_name_or_path)

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if ((self.quantization == "bitsandbytes" or has_qlora_adapter)
                and self.load_format != "bitsandbytes"):
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if ((self.load_format == "bitsandbytes" or has_qlora_adapter)
                and self.quantization != "bitsandbytes"):
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

        # Prefix caching is always turned off for multimodal models; only
        # warn when the user had actually asked for it.
        if model_config.is_multimodal_model:
            if self.enable_prefix_caching:
                logger.warning(
                    "--enable-prefix-caching is currently not "
                    "supported for multimodal models and has been disabled.")
            self.enable_prefix_caching = False

        cache_config = CacheConfig(
            # neuron needs block_size = max_model_len
            block_size=self.block_size if self.device != "neuron" else
            (self.max_model_len if self.max_model_len is not None else 0),
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            is_attention_free=model_config.is_attention_free,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
        )

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.

            # Chunked prefill is currently disabled for multimodal models by
            # default.
            if use_long_context and not model_config.is_multimodal_model:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.task != "embedding"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif self.enable_chunked_prefill and model_config.task == "embedding":
            msg = "Chunked prefill is not supported for embedding models"
            raise ValueError(msg)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization=self.speculative_model_quantization,
            speculative_draft_tensor_parallel_size=(
                self.speculative_draft_tensor_parallel_size),
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size=(
                self.speculative_disable_by_batch_size),
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=(
                self.spec_decoding_acceptance_method),
            typical_acceptance_sampler_posterior_threshold=(
                self.typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                self.typical_acceptance_sampler_posterior_alpha),
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")

        # Speculative decoding dictates its own lookahead; otherwise take the
        # larger of the user value and what multi-step scheduling needs.
        if speculative_config is not None:
            num_lookahead_slots = speculative_config.num_lookahead_slots
        else:
            num_lookahead_slots = max(self.num_lookahead_slots,
                                      self.num_scheduler_steps - 1)

        if not self.use_v2_block_manager:
            logger.warning(
                "[DEPRECATED] Block manager v1 has been removed, "
                "and setting --use-v2-block-manager to True or False has "
                "no effect on vLLM behavior. Please remove "
                "--use-v2-block-manager in your engine argument. "
                "If your use case is not supported by "
                "SelfAttnBlockSpaceManager (i.e. block manager v2),"
                " please file an issue with detailed information.")

        scheduler_config = SchedulerConfig(
            task=model_config.task,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            is_multimodal_model=model_config.is_multimodal_model,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
            multi_step_stream_outputs=self.multi_step_stream_outputs,
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
            policy=self.scheduling_policy)

        lora_config = None
        if self.enable_lora:
            lora_config = LoRAConfig(
                bias_enabled=self.enable_lora_bias,
                max_lora_rank=self.max_lora_rank,
                max_loras=self.max_loras,
                fully_sharded_loras=self.fully_sharded_loras,
                lora_extra_vocab_size=self.lora_extra_vocab_size,
                long_lora_scaling_factors=self.long_lora_scaling_factors,
                lora_dtype=self.lora_dtype,
                # Non-positive / unset values fall back to None.
                max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
                and self.max_cpu_loras > 0 else None)

        if has_qlora_adapter:
            # The CLI passes --model-loader-extra-config as a JSON string
            # (``nullable_str``), so parse it before adding the adapter key;
            # indexing the raw string would raise TypeError.
            extra_config = self.model_loader_extra_config
            if extra_config is None:
                extra_config = {}
            elif isinstance(extra_config, str):
                extra_config = json.loads(extra_config)
            extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
            self.model_loader_extra_config = extra_config

        load_config = self.create_load_config()

        prompt_adapter_config = None
        if self.enable_prompt_adapter:
            prompt_adapter_config = PromptAdapterConfig(
                max_prompt_adapters=self.max_prompt_adapters,
                max_prompt_adapter_token=self.max_prompt_adapter_token)

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        detailed_trace_modules: List[str] = []
        if self.collect_detailed_traces is not None:
            # Tolerate whitespace around commas, e.g. "model, worker".
            detailed_trace_modules = [
                module.strip()
                for module in self.collect_detailed_traces.split(",")
            ]
        for module in detailed_trace_modules:
            if module not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {module} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )

        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
            compilation_config=self.compilation_config,
        )

        if envs.VLLM_USE_V1:
            self._override_v1_engine_config(config)
        return config

    def _override_v1_engine_args(
            self, usage_context: Optional[UsageContext]) -> None:
        """
        Override the EngineArgs's args based on the usage context for V1.

        The overrides apply only when ``max_num_batched_tokens`` was left
        unset (``None``):

        * ``UsageContext.LLM_CLASS``: ``max_num_seqs`` becomes 1024 and
          ``max_num_batched_tokens`` becomes 8192.
        * ``UsageContext.OPENAI_API_SERVER``: ``max_num_seqs`` becomes 1024
          and ``max_num_batched_tokens`` becomes 2048.

        Any other context, including ``None`` (the default passed by
        ``create_engine_config``), leaves the args untouched.

        NOTE(review): ``max_num_seqs`` is overwritten in these branches even
        if the user set it explicitly. Confirm this is intended.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"

        if self.max_num_batched_tokens is not None:
            # The user chose a token budget; keep their settings.
            return

        if usage_context == UsageContext.LLM_CLASS:
            logger.warning("Setting max_num_batched_tokens to 8192 "
                           "for LLM_CLASS usage context.")
            self.max_num_seqs = 1024
            self.max_num_batched_tokens = 8192
        elif usage_context == UsageContext.OPENAI_API_SERVER:
            logger.warning("Setting max_num_batched_tokens to 2048 "
                           "for OPENAI_API_SERVER usage context.")
            self.max_num_seqs = 1024
            self.max_num_batched_tokens = 2048

    def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
        """
        Apply V1-specific overrides to an already-built ``VllmConfig``.

        Currently this only forces prefix caching off for multimodal models.
        The warning is emitted only when prefix caching was actually enabled
        on ``engine_config.cache_config``. If it was already off, nothing
        changes and nothing is logged.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"
        # TODO (ywang96): Enable APC by default when VLM supports it.
        if engine_config.model_config.is_multimodal_model:
            cache_config = engine_config.cache_config
            if cache_config.enable_prefix_caching:
                logger.warning(
                    "Prefix caching is currently not supported for multimodal "
                    "models and has been disabled.")
            cache_config.enable_prefix_caching = False

1235

1236
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine.

    Adds ``disable_log_requests`` on top of the ``EngineArgs`` fields.
    """
    # Set from the ``--disable-log-requests`` CLI flag.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI flags on ``parser`` and return it.

        Unless ``async_args_only`` is true, every ``EngineArgs`` flag is
        registered first, so the returned parser covers both sets.
        """
        if async_args_only:
            target = parser
        else:
            target = EngineArgs.add_cli_args(parser)
        target.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return target
1250
1251
1252
1253


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with every ``EngineArgs`` flag."""
    parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(parser)
1255
1256
1257


def _async_engine_args_parser():
    """Return a fresh parser holding only the async-engine-specific flags."""
    parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(parser, async_args_only=True)