arg_utils.py 58.8 KB
Newer Older
1
import argparse
2
import dataclasses
3
import json
4
from dataclasses import dataclass
5
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
6
                    Tuple, Type, Union, cast, get_args)
7

8
9
import torch

10
import vllm.envs as envs
11
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
12
13
14
15
                         DecodingConfig, DeviceConfig, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PoolerConfig, PromptAdapterConfig, SchedulerConfig,
16
17
                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
                         VllmConfig)
18
from vllm.executor.executor_base import ExecutorBase
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
21
from vllm.platforms import current_platform
22
from vllm.transformers_utils.utils import check_gguf_file
23
from vllm.usage.usage_lib import UsageContext
24
from vllm.utils import FlexibleArgumentParser, StoreBoolean
25

26
if TYPE_CHECKING:
27
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
28

29
30
logger = init_logger(__name__)

# Module names accepted for detailed trace collection.
# NOTE(review): presumably the valid values for the `collect_detailed_traces`
# engine argument; confirm against where this list is consumed.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# Device names that may be selected; "auto" matches the default of
# `EngineArgs.device`.
# NOTE(review): presumably used as the `choices` of a `--device` CLI flag;
# confirm against where this list is consumed.
DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "openvino",
    "tpu",
    "xpu",
    "hpu",
]


def nullable_str(val: str):
    """Return ``val`` unchanged, or ``None`` for empty / ``"None"`` input.

    Intended as an argparse ``type=`` converter, so that an empty value or
    the literal string ``None`` on the command line becomes ``None``.

    Args:
        val: Raw string value from the command line.

    Returns:
        ``None`` if ``val`` is falsy or equals ``"None"``; otherwise ``val``.
    """
    is_null = (not val) or (val == "None")
    return None if is_null else val


def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse a comma-separated list of ``KEY=VALUE`` pairs into a dict.

    Keys and values are lower-cased and stripped of surrounding whitespace,
    and each value must parse as an integer. For example,
    ``"image=16,video=2"`` becomes ``{"image": 16, "video": 2}``.

    Args:
        val: String value to be parsed.

    Returns:
        Dictionary mapping each key to its integer value, or ``None`` if
        ``val`` is empty (or ``None``).

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, a value is not an integer, or the same key is
            given two different values.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key is tolerated only when it maps to the same value.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
87
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
88
    """Arguments for vLLM engine."""
89
    model: str = 'facebook/opt-125m'
90
    served_model_name: Optional[Union[str, List[str]]] = None
91
    tokenizer: Optional[str] = None
92
    task: TaskOption = "auto"
93
    skip_tokenizer_init: bool = False
94
    tokenizer_mode: str = 'auto'
95
    trust_remote_code: bool = False
96
    allowed_local_media_path: str = ""
97
    download_dir: Optional[str] = None
98
    load_format: str = 'auto'
99
    config_format: ConfigFormat = ConfigFormat.AUTO
100
    dtype: str = 'auto'
101
    kv_cache_dtype: str = 'auto'
102
    quantization_param_path: Optional[str] = None
103
    seed: int = 0
104
    max_model_len: Optional[int] = None
105
    worker_use_ray: bool = False
106
107
108
109
110
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
111
    # number of P/D disaggregation (or other disaggregation) workers
112
113
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
114
    max_parallel_loading_workers: Optional[int] = None
115
116
117
    # NOTE(kzawora): default block size for Gaudi should be 128
    # smaller sizes still work, but very inefficiently
    block_size: int = 16 if not current_platform.is_hpu() else 128
118
    enable_prefix_caching: Optional[bool] = None
119
    disable_sliding_window: bool = False
120
    use_v2_block_manager: bool = True
121
122
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
123
    gpu_memory_utilization: float = 0.90
124
    max_num_batched_tokens: Optional[int] = None
125
    max_num_seqs: int = 256
126
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
127
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
128
    revision: Optional[str] = None
129
    code_revision: Optional[str] = None
130
    rope_scaling: Optional[Dict[str, Any]] = None
131
    rope_theta: Optional[float] = None
132
    hf_overrides: Optional[HfOverrides] = None
133
    tokenizer_revision: Optional[str] = None
134
    quantization: Optional[str] = None
135
    enforce_eager: Optional[bool] = None
136
    max_seq_len_to_capture: int = 8192
137
    disable_custom_all_reduce: bool = False
138
    tokenizer_pool_size: int = 0
139
140
141
142
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
143
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
144
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
145
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
146
    enable_lora: bool = False
147
    enable_lora_bias: bool = False
148
149
    max_loras: int = 1
    max_lora_rank: int = 16
150
151
152
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
153
    fully_sharded_loras: bool = False
154
    lora_extra_vocab_size: int = 256
155
    long_lora_scaling_factors: Optional[Tuple[float]] = None
156
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
157
    max_cpu_loras: Optional[int] = None
158
    device: str = 'auto'
159
    num_scheduler_steps: int = 1
160
    multi_step_stream_outputs: bool = True
161
    ray_workers_use_nsight: bool = False
162
    num_gpu_blocks_override: Optional[int] = None
163
    num_lookahead_slots: int = 0
164
    model_loader_extra_config: Optional[dict] = None
165
    ignore_patterns: Optional[Union[str, List[str]]] = None
166
    preemption_mode: Optional[str] = None
167

168
    scheduler_delay_factor: float = 0.0
169
    enable_chunked_prefill: Optional[bool] = None
170

171
    guided_decoding_backend: str = 'xgrammar'
172
173
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
174
    speculative_model_quantization: Optional[str] = None
175
    speculative_draft_tensor_parallel_size: Optional[int] = None
176
    num_speculative_tokens: Optional[int] = None
177
    speculative_disable_mqa_scorer: Optional[bool] = False
178
    speculative_max_model_len: Optional[int] = None
179
    speculative_disable_by_batch_size: Optional[int] = None
180
181
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
182
183
184
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
185
    qlora_adapter_name_or_path: Optional[str] = None
186
    disable_logprobs_during_spec_decoding: Optional[bool] = None
187

188
    otlp_traces_endpoint: Optional[str] = None
189
    collect_detailed_traces: Optional[str] = None
190
    disable_async_output_proc: bool = False
191
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
192

193
194
    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
195
    compilation_config: Optional[CompilationConfig] = None
196
    worker_cls: str = "auto"
197

198
199
    kv_transfer_config: Optional[KVTransferConfig] = None

200
    def __post_init__(self):
        """Fill in derived defaults and normalize user-supplied options.

        - ``tokenizer`` falls back to ``model`` when it is unset/empty.
        - ``enable_prefix_caching`` defaults to ``bool(envs.VLLM_USE_V1)``
          when the user left it as ``None``.
        - ``compilation_config`` given as an ``int`` or a ``dict`` is
          converted to a ``CompilationConfig`` via
          ``CompilationConfig.from_cli``.

        Finally, general plugins are loaded via ``load_general_plugins``.
        """
        if not self.tokenizer:
            self.tokenizer = self.model

        # Override the default value of enable_prefix_caching if it's not set
        # by user.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = bool(envs.VLLM_USE_V1)

        # Support `EngineArgs(compilation_config={...})` (or an int) without
        # having to manually construct a CompilationConfig object; both forms
        # are serialized to a string and routed through `from_cli`.
        if isinstance(self.compilation_config, int):
            self.compilation_config = CompilationConfig.from_cli(
                str(self.compilation_config))
        elif isinstance(self.compilation_config, dict):
            self.compilation_config = CompilationConfig.from_cli(
                json.dumps(self.compilation_config))

        # Setup plugins. Imported locally rather than at module level.
        from vllm.plugins import load_general_plugins
        load_general_plugins()

    @staticmethod
224
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
225
        """Shared CLI arguments for vLLM engine."""
226

227
        # Model arguments
228
229
230
        parser.add_argument(
            '--model',
            type=str,
231
            default=EngineArgs.model,
232
            help='Name or path of the huggingface model to use.')
233
234
235
236
237
238
239
240
241
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
            'multiple tasks. When the model only supports one task, "auto" '
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
242
243
        parser.add_argument(
            '--tokenizer',
244
            type=nullable_str,
245
            default=EngineArgs.tokenizer,
246
247
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
248
249
250
251
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
            help='Skip initialization of tokenizer and detokenizer')
Jasmond L's avatar
Jasmond L committed
252
253
        parser.add_argument(
            '--revision',
254
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
255
            default=None,
256
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
257
258
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
259
260
        parser.add_argument(
            '--code-revision',
261
            type=nullable_str,
262
            default=None,
263
            help='The specific revision to use for the model code on '
264
265
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
266
267
        parser.add_argument(
            '--tokenizer-revision',
268
            type=nullable_str,
269
            default=None,
270
271
272
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
273
274
275
276
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
277
            choices=['auto', 'slow', 'mistral'],
278
279
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
280
281
            'always use the slow tokenizer. \n* '
            '"mistral" will always use the `mistral_common` tokenizer.')
282
283
        parser.add_argument('--trust-remote-code',
                            action='store_true',
284
                            help='Trust remote code from huggingface.')
285
286
287
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
288
289
290
291
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
292
        parser.add_argument('--download-dir',
293
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
294
                            default=EngineArgs.download_dir,
295
                            help='Directory to download and load the weights, '
296
                            'default to the default cache dir of '
297
                            'huggingface.')
298
299
300
301
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
302
            choices=[f.value for f in LoadFormat],
303
304
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
305
            'and fall back to the pytorch bin format if safetensors format '
306
307
308
309
310
311
312
313
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
314
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
315
316
317
            'section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
318
319
320
321
322
323
324
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
325
326
327
328
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
329
330
331
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
332
333
334
335
336
337
338
339
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
340
341
342
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
343
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
344
            default=EngineArgs.kv_cache_dtype,
345
            help='Data type for kv cache storage. If "auto", will use model '
346
347
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
348
349
        parser.add_argument(
            '--quantization-param-path',
350
            type=nullable_str,
351
352
353
354
355
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
356
            'FP8_E5M2 (without scaling) is only supported on cuda version '
357
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
358
            'supported for common inference criteria.')
359
360
        parser.add_argument('--max-model-len',
                            type=int,
361
                            default=EngineArgs.max_model_len,
362
363
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
364
365
366
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
367
368
            default='xgrammar',
            choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
369
            help='Which engine will be used for guided decoding'
370
            ' (JSON schema / regex etc) by default. Currently support '
371
372
            'https://github.com/outlines-dev/outlines,'
            'https://github.com/mlc-ai/xgrammar, and '
373
374
375
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
376
        # Parallel arguments
377
378
379
380
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
381
382
383
384
385
386
387
388
            help='Backend to use for distributed model '
            'workers, either "ray" or "mp" (multiprocessing). If the product '
            'of pipeline_parallel_size and tensor_parallel_size is less than '
            'or equal to the number of GPUs available, "mp" will be used to '
            'keep processing on a single host. Otherwise, this will default '
            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
            'and hpu only support Ray for distributed inference.')

389
390
391
392
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
393
394
395
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
396
                            default=EngineArgs.pipeline_parallel_size,
397
                            help='Number of pipeline stages.')
398
399
400
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
401
                            default=EngineArgs.tensor_parallel_size,
402
                            help='Number of tensor parallel replicas.')
403
404
405
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
406
            default=EngineArgs.max_parallel_loading_workers,
407
            help='Load model sequentially in multiple batches, '
408
            'to avoid RAM OOM when using tensor '
409
            'parallel and large models.')
410
411
412
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
413
            help='If specified, use nsight to profile Ray workers.')
414
        # KV cache arguments
415
416
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
417
                            default=EngineArgs.block_size,
418
                            choices=[8, 16, 32, 64, 128],
419
                            help='Token block size for contiguous chunks of '
420
421
                            'tokens. This is ignored on neuron devices and '
                            'set to max-model-len')
422

423
424
425
426
427
428
429
        parser.add_argument(
            "--enable-prefix-caching",
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.enable_prefix_caching,
            help="Enables automatic prefix caching. "
            "Use --no-enable-prefix-caching to disable explicitly.",
        )
430
431
432
433
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
434
435
436
437
438
439
440
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
441
442
443
444
445
446
447
448
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
449

450
451
452
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
453
                            help='Random seed for operations.')
454
        parser.add_argument('--swap-space',
455
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
456
                            default=EngineArgs.swap_space,
457
                            help='CPU swap space size (GiB) per GPU.')
458
459
460
461
462
463
464
465
466
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
467
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
468
            'which requires at least 26GB GPU memory. Note that this '
469
            'requires fast CPU-GPU interconnect, as part of the model is '
470
471
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
472
473
474
475
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
476
477
478
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
479
480
481
482
483
            'will use the default value of 0.9. This is a global gpu memory '
            'utilization limit, for example if 50%% of the gpu memory is '
            'already used before vLLM starts and --gpu-memory-utilization is '
            'set to 0.9, then only 40%% of the gpu memory will be allocated '
            'to the model executor.')
484
        parser.add_argument(
485
            '--num-gpu-blocks-override',
486
487
488
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
489
            ' of GPU blocks. Used for testing preemption.')
490
491
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
492
                            default=EngineArgs.max_num_batched_tokens,
493
494
                            help='Maximum number of batched tokens per '
                            'iteration.')
495
496
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
497
                            default=EngineArgs.max_num_seqs,
498
                            help='Maximum number of sequences per iteration.')
499
500
501
502
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
503
504
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
505
506
        parser.add_argument('--disable-log-stats',
                            action='store_true',
507
                            help='Disable logging statistics.')
508
509
510
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
511
                            type=nullable_str,
512
                            choices=[*QUANTIZATION_METHODS, None],
513
                            default=EngineArgs.quantization,
514
515
516
517
518
519
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
520
521
522
523
524
525
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
            'For example, {"rope_type":"dynamic","factor":2.0}')
526
527
528
529
530
531
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
532
533
534
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
535
                            help='Extra arguments for the HuggingFace config. '
536
537
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
538
539
540
541
542
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
543
        parser.add_argument('--max-seq-len-to-capture',
544
545
546
547
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
548
549
550
551
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
552
553
554
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
555
                            help='See ParallelConfig.')
556
557
558
559
560
561
562
563
564
565
566
567
568
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
569
                            type=nullable_str,
570
571
572
573
574
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
575
576
577
578
579
580
581
582
583
584
585
586
587
588

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
589
590
591
592
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
593
            help=('Overrides for the multimodal input mapping/processing, '
594
                  'e.g., image processor. For example: {"num_crops": 4}.'))
595

596
597
598
599
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
600
601
602
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
622
            choices=['auto', 'float16', 'bfloat16'],
623
624
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
625
626
627
628
629
630
631
632
633
634
635
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
636
637
638
639
640
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
641
642
                  'Must be >= than max_loras. '
                  'Defaults to max_loras.'))
643
644
645
646
647
648
649
650
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
651
652
653
654
655
656
657
658
659
660
661
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
662
663
664
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
665
                            choices=DEVICE_OPTIONS,
666
                            help='Device type for vLLM execution.')
667
668
669
670
671
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
672

673
674
        parser.add_argument(
            '--multi-step-stream-outputs',
675
676
677
678
679
680
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
681
682
683
684
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
685
            help='Apply a delay (of delay factor multiplied by previous '
686
            'prompt latency) before scheduling next prompt.')
687
688
        parser.add_argument(
            '--enable-chunked-prefill',
689
690
691
692
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
693
            help='If set, the prefill requests can be chunked based on the '
694
            'max_num_batched_tokens.')
695
696
697

        parser.add_argument(
            '--speculative-model',
698
            type=nullable_str,
699
            default=EngineArgs.speculative_model,
700
701
            help=
            'The name of the draft model to be used in speculative decoding.')
702
703
704
705
706
707
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
708
            help='Method used to quantize the weights of speculative model. '
709
710
711
712
713
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
714
715
716
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
717
            default=EngineArgs.num_speculative_tokens,
718
            help='The number of speculative tokens to sample from '
719
            'the draft model in speculative decoding.')
720
721
722
723
724
725
        parser.add_argument(
            '--speculative-disable-mqa-scorer',
            action='store_true',
            help=
            'If set to True, the MQA scorer will be disabled in speculative '
            ' and fall back to batch expansion')
726
727
728
729
730
731
732
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
733

734
735
        parser.add_argument(
            '--speculative-max-model-len',
736
            type=int,
737
738
739
740
741
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

742
743
744
745
746
747
748
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

749
750
751
752
753
754
755
756
757
758
759
760
761
762
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

795
796
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
797
            action=StoreBoolean,
798
            default=EngineArgs.disable_logprobs_during_spec_decoding,
799
800
            nargs="?",
            const="True",
801
802
803
804
805
806
807
808
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

809
        parser.add_argument('--model-loader-extra-config',
810
                            type=nullable_str,
811
812
813
814
815
816
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
817
818
819
820
821
822
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
823
            "Default to `original/**/*` to avoid repeated loading of llama's "
824
            "checkpoints.")
825
        parser.add_argument(
826
            '--preemption-mode',
827
828
            type=str,
            default=None,
829
830
831
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
832

833
834
835
836
837
838
839
840
841
842
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
843
            "same as the `--model` argument. Noted that this name(s) "
844
            "will also be used in `model_name` tag content of "
845
            "prometheus metrics, if multiple names provided, metrics "
846
            "tag will take the first one.")
847
848
849
850
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
851
852
853
854
855
856

        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
857
858
859
860
861
862
863
864
865
866
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
            ". It makes sense to set this only if --otlp-traces-endpoint is"
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
867

868
869
870
871
872
873
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
874

875
876
877
878
879
880
881
882
883
884
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

885
        parser.add_argument(
886
887
            '--override-neuron-config',
            type=json.loads,
888
            default=None,
889
890
            help="Override or set neuron device configuration. "
            "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'")
891
        parser.add_argument(
892
893
            '--override-pooler-config',
            type=PoolerConfig.from_json,
894
            default=None,
895
896
            help="Override or set the pooling method in the embedding model. "
            "e.g. {\"pooling_type\": \"mean\", \"normalize\": false}.'")
897

898
899
900
901
902
903
904
905
906
907
908
909
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
910
911
912
913
                            'use a JSON string.\n'
                            'Following the convention of traditional '
                            'compilers, using -O without space is also '
                            'supported. -O3 is equivalent to -O 3.')
914

915
916
917
918
919
920
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

921
922
923
924
925
926
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')

927
        return parser
928
929

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Create an instance of ``cls`` from a parsed CLI namespace.

        Each dataclass field of ``cls`` is read from ``args`` by name.
        Namespace attributes that are not fields are ignored. A field that is
        missing from the namespace raises ``AttributeError``.
        """
        field_names = (field.name for field in dataclasses.fields(cls))
        init_kwargs = {name: getattr(args, name) for name in field_names}
        return cls(**init_kwargs)

    def create_model_config(self) -> ModelConfig:
        """Assemble the :class:`ModelConfig` described by these engine args.

        Every field is forwarded unchanged, with one exception:
        ``use_async_output_proc`` is the negation of
        ``disable_async_output_proc``.
        """
        # __post_init__ already filled in ``tokenizer``, so it is not None
        # here; the cast exists only for the type checker.
        tokenizer = cast(str, self.tokenizer)
        return ModelConfig(
            # Model / tokenizer identity.
            model=self.model,
            task=self.task,
            tokenizer=tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            # HF config overrides.
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            hf_overrides=self.hf_overrides,
            tokenizer_revision=self.tokenizer_revision,
            # Execution / quantization knobs.
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
        )

    def create_load_config(self) -> LoadConfig:
        """Build the :class:`LoadConfig` that controls weight loading.

        The format, download directory, loader-specific extra config and
        ignore patterns are taken directly from these engine args.
        """
        load_kwargs = dict(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
        )
        return LoadConfig(**load_kwargs)

    def create_engine_config(self,
                             usage_context: Optional[UsageContext] = None
                             ) -> VllmConfig:
        """Validate these engine args and assemble the full ``VllmConfig``.

        Builds the model, cache, parallel, speculative, scheduler, LoRA,
        load, prompt-adapter, decoding and observability sub-configs and
        wraps them in a single :class:`VllmConfig`.

        This method mutates ``self`` along the way:

        * ``quantization`` / ``load_format`` become ``"gguf"`` for GGUF files.
        * ``enable_prefix_caching`` is forced off for multimodal models.
        * ``enable_chunked_prefill`` is resolved to a bool when left ``None``.
        * ``model_loader_extra_config`` receives the QLoRA adapter path.

        When ``VLLM_USE_V1`` is set, V1-specific overrides are applied to the
        args before assembly and to the resulting config afterwards.

        Args:
            usage_context: How the engine is being used. It is only consulted
                by the V1 argument overrides.

        Returns:
            The assembled ``VllmConfig``.

        Raises:
            ValueError: On incompatible argument combinations. These are:
                bitsandbytes/QLoRA against the load format or quantization,
                chunked prefill with embedding models, multi-step scheduling
                with speculative decoding or with chunked prefill plus
                pipeline parallelism, and unknown detailed-trace modules.
            AssertionError: If ``cpu_offload_gb`` is negative.
        """
        if envs.VLLM_USE_V1:
            self._override_v1_engine_args(usage_context)

        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
           self.qlora_adapter_name_or_path is not None) and \
           self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes" or
            self.qlora_adapter_name_or_path is not None) and \
            self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

        if model_config.is_multimodal_model:
            if self.enable_prefix_caching:
                logger.warning(
                    "--enable-prefix-caching is currently not "
                    "supported for multimodal models and has been disabled.")
            self.enable_prefix_caching = False

        cache_config = CacheConfig(
            # neuron needs block_size = max_model_len
            block_size=self.block_size if self.device != "neuron" else
            (self.max_model_len if self.max_model_len is not None else 0),
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            is_attention_free=model_config.is_attention_free,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
        )

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.

            # Chunked prefill is currently disabled for multimodal models by
            # default.
            if use_long_context and not model_config.is_multimodal_model:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.task != "embedding"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif self.enable_chunked_prefill and model_config.task == "embedding":
            msg = "Chunked prefill is not supported for embedding models"
            raise ValueError(msg)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization=(
                self.speculative_model_quantization),
            speculative_draft_tensor_parallel_size=(
                self.speculative_draft_tensor_parallel_size),
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size=(
                self.speculative_disable_by_batch_size),
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=(
                self.spec_decoding_acceptance_method),
            typical_acceptance_sampler_posterior_threshold=(
                self.typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                self.typical_acceptance_sampler_posterior_alpha),
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        # Reminder: Please update docs/source/usage/compatibility_matrix.rst
        # If the feature combo become valid
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

        if not self.use_v2_block_manager:
            logger.warning(
                "[DEPRECATED] Block manager v1 has been removed, "
                "and setting --use-v2-block-manager to True or False has "
                "no effect on vLLM behavior. Please remove "
                "--use-v2-block-manager in your engine argument. "
                "If your use case is not supported by "
                "SelfAttnBlockSpaceManager (i.e. block manager v2),"
                " please file an issue with detailed information.")

        scheduler_config = SchedulerConfig(
            task=model_config.task,
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            is_multimodal_model=model_config.is_multimodal_model,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
            multi_step_stream_outputs=self.multi_step_stream_outputs,
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
            policy=self.scheduling_policy)
        lora_config = LoRAConfig(
            bias_enabled=self.enable_lora_bias,
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        if self.qlora_adapter_name_or_path:
            extra_config = self.model_loader_extra_config
            if extra_config is None:
                extra_config = {}
            elif isinstance(extra_config, str):
                # --model-loader-extra-config is accepted from the CLI as a
                # JSON string; parse it so the adapter path can be merged in
                # rather than failing with a TypeError on item assignment.
                extra_config = json.loads(extra_config)
            extra_config["qlora_adapter_name_or_path"] = (
                self.qlora_adapter_name_or_path)
            self.model_loader_extra_config = extra_config

        load_config = self.create_load_config()

        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
        observability_config = ObservabilityConfig(
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )

        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
            compilation_config=self.compilation_config,
            kv_transfer_config=self.kv_transfer_config,
        )

        if envs.VLLM_USE_V1:
            self._override_v1_engine_config(config)
        return config

    def _override_v1_engine_args(
            self, usage_context: Optional[UsageContext]) -> None:
        """
        Override the EngineArgs's args based on the usage context for V1.

        The override applies only when ``max_num_batched_tokens`` was left
        unset (``None``). In that case:

        * ``UsageContext.LLM_CLASS`` sets ``max_num_batched_tokens=8192``.
        * ``UsageContext.OPENAI_API_SERVER`` sets
          ``max_num_batched_tokens=2048``.
        * Both set ``max_num_seqs=1024``, overwriting any value already
          present.

        Any other context, including ``None``, leaves the args untouched.

        Args:
            usage_context: How the engine is being used. The annotation is
                ``Optional`` because ``create_engine_config`` forwards its
                own optional (default ``None``) argument here.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"

        if self.max_num_batched_tokens is None:
            # When no user override, set the default values based on the
            # usage context.
            if usage_context == UsageContext.LLM_CLASS:
                logger.warning("Setting max_num_batched_tokens to 8192 "
                               "for LLM_CLASS usage context.")
                self.max_num_seqs = 1024
                self.max_num_batched_tokens = 8192
            elif usage_context == UsageContext.OPENAI_API_SERVER:
                logger.warning("Setting max_num_batched_tokens to 2048 "
                               "for OPENAI_API_SERVER usage context.")
                self.max_num_seqs = 1024
                self.max_num_batched_tokens = 2048

    def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
        """
        Override the EngineConfig's configs based on the usage context for V1.

        Currently this only forces prefix caching off for multimodal models.
        The warning is emitted only when prefix caching was actually enabled,
        so users who never turned it on do not see a misleading
        "has been disabled" message.

        Args:
            engine_config: The assembled config; it is modified in place.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"
        # TODO (ywang96): Enable APC by default when VLM supports it.
        if engine_config.model_config.is_multimodal_model:
            cache_config = engine_config.cache_config
            if cache_config.enable_prefix_caching:
                logger.warning(
                    "Prefix caching is currently not supported for multimodal "
                    "models and has been disabled.")
            # Always pin the flag to False, whatever its previous value.
            cache_config.enable_prefix_caching = False


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Engine arguments for the asynchronous vLLM engine.

    This adds a single extra option, ``disable_log_requests``, on top of
    :class:`EngineArgs`.
    """
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI flags on ``parser``.

        Unless ``async_args_only`` is true, every :class:`EngineArgs` flag is
        registered first. The parser returned by that step is the one that
        receives ``--disable-log-requests`` and is returned.
        """
        if async_args_only:
            target = parser
        else:
            target = EngineArgs.add_cli_args(parser)
        target.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return target


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser carrying every ``EngineArgs`` CLI flag."""
    fresh_parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(fresh_parser)


def _async_engine_args_parser():
    """Return a fresh parser carrying only the async-engine-specific flags."""
    fresh_parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(fresh_parser, async_args_only=True)