"docs/vscode:/vscode.git/clone" did not exist on "cb506ecb5afa58670c6105b589696e6e176f60aa"
arg_utils.py 63.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import argparse
4
import dataclasses
5
import json
6
from dataclasses import dataclass
7
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
8
                    Tuple, Type, Union, cast, get_args)
9

10
11
import torch

12
import vllm.envs as envs
13
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
14
15
                         DecodingConfig, DeviceConfig, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
16
17
18
19
                         ModelConfig, ModelImpl, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
                         SchedulerConfig, SpeculativeConfig, TaskOption,
                         TokenizerPoolConfig, VllmConfig)
20
from vllm.executor.executor_base import ExecutorBase
21
from vllm.logger import init_logger
22
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
23
from vllm.transformers_utils.utils import check_gguf_file
24
from vllm.usage.usage_lib import UsageContext
25
from vllm.utils import FlexibleArgumentParser, StoreBoolean
26

27
if TYPE_CHECKING:
28
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
29

30
31
# Module-level logger, named after this module via ``init_logger``.
logger = init_logger(__name__)

32
33
# NOTE(review): presumably the accepted values for detailed trace collection
# (``collect_detailed_traces``); confirm at the use site.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

# NOTE(review): presumably the accepted values for the ``device`` engine
# argument, whose default is "auto"; confirm at the use site.
DEVICE_OPTIONS = [
    "auto",
    "cuda",
    "neuron",
    "cpu",
    "openvino",
    "tpu",
    "xpu",
    "hpu",
]

45

46
47
48
49
50
51
def nullable_str(val: str):
    """Convert an empty string or the literal ``"None"`` into ``None``.

    Used as an ``argparse`` ``type=`` converter so that optional string
    options can be explicitly cleared; any other value is returned as-is.
    """
    is_unset = (not val) or val == "None"
    return None if is_unset else val


52
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    """Parse comma-separated ``KEY=VALUE`` pairs into a dictionary.

    Keys are lower-cased and stripped of surrounding whitespace; values must
    parse as integers. Intended as an ``argparse`` ``type=`` converter, e.g.
    ``--limit-mm-per-prompt image=16,video=2``.

    Args:
        val: String value to be parsed.

    Returns:
        Dictionary mapping each key to its integer value, or ``None`` if
        ``val`` is empty.

    Raises:
        argparse.ArgumentTypeError: If an item is not of the form
            ``KEY=VALUE``, a value is not an integer, or the same key is
            given two different values.
    """
    if not val:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        kv_parts = [part.lower().strip() for part in item.split("=")]
        if len(kv_parts) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, value = kv_parts

        try:
            parsed_value = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise argparse.ArgumentTypeError(msg) from exc

        # Repeating a key is tolerated only when it maps to the same value.
        if key in out_dict and out_dict[key] != parsed_value:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        out_dict[key] = parsed_value

    return out_dict


87
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
88
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
89
    """Arguments for vLLM engine."""
    model: str = 'facebook/opt-125m'
    served_model_name: Optional[Union[str, List[str]]] = None
    # Falls back to ``model`` in ``__post_init__`` when unset/empty.
    tokenizer: Optional[str] = None
    task: TaskOption = "auto"
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    allowed_local_media_path: str = ""
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    config_format: ConfigFormat = ConfigFormat.AUTO
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    seed: int = 0
    max_model_len: Optional[int] = None
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
    # Parallelism configuration.
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: Optional[int] = None
    # ``None`` means "not set by user"; resolved from envs.VLLM_USE_V1 in
    # ``__post_init__``.
    enable_prefix_caching: Optional[bool] = None
    disable_sliding_window: bool = False
    # Deprecated: block manager v2 is always used; this flag has no effect.
    use_v2_block_manager: bool = True
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    # ``None`` means "not set by user"; resolved (256 for V0, 1024 for V1) in
    # ``__post_init__``.
    max_num_seqs: Optional[int] = None
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    rope_scaling: Optional[Dict[str, Any]] = None
    rope_theta: Optional[float] = None
    hf_overrides: Optional[HfOverrides] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: Optional[bool] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
    # Multimodal configuration.
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
    disable_mm_preprocessor_cache: bool = False
    # LoRA configuration.
    enable_lora: bool = False
    enable_lora_bias: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    # Variable-length tuple of scaling factors (``Tuple[float]`` would mean
    # exactly one element).
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
    max_cpu_loras: Optional[int] = None
    # If True, the base layer weights are merged with the LoRA weights.
    merge_lora: bool = False
    # LoRA module names; if None, chosen according to the model architecture.
    lora_target_modules: Optional[List[str]] = None
    device: str = 'auto'
    num_scheduler_steps: int = 1
    multi_step_stream_outputs: bool = True
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    model_loader_extra_config: Optional[dict] = None
    ignore_patterns: Optional[Union[str, List[str]]] = None
    preemption_mode: Optional[str] = None

    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: Optional[bool] = None

    guided_decoding_backend: str = 'xgrammar'
    logits_processor_pattern: Optional[str] = None
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    speculative_model_quantization: Optional[str] = None
    speculative_draft_tensor_parallel_size: Optional[int] = None
    num_speculative_tokens: Optional[int] = None
    num_speculative_heads: Optional[int] = None
    speculative_disable_mqa_scorer: Optional[bool] = False
    speculative_max_model_len: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
    qlora_adapter_name_or_path: Optional[str] = None
    disable_logprobs_during_spec_decoding: Optional[bool] = None

    # Observability configuration.
    otlp_traces_endpoint: Optional[str] = None
    collect_detailed_traces: Optional[str] = None
    disable_async_output_proc: bool = False
    scheduling_policy: Literal["fcfs", "priority"] = "fcfs"

    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
    # May also be given as an int or dict; converted to a CompilationConfig
    # in ``__post_init__``.
    compilation_config: Optional[CompilationConfig] = None
    worker_cls: str = "auto"

    kv_transfer_config: Optional[KVTransferConfig] = None

    generation_config: Optional[str] = None
    override_generation_config: Optional[Dict[str, Any]] = None
    enable_sleep_mode: bool = False
    model_impl: str = "auto"

    calculate_kv_scales: Optional[bool] = None

    # Number of MoE expert-parallel replicas (``--moe-ep-size``).
    moe_ep_size: int = 1

211
    def __post_init__(self):
        """Resolve defaults that depend on other fields or the environment.

        - ``tokenizer`` falls back to ``model`` when unset or empty.
        - ``enable_prefix_caching`` and ``max_num_seqs`` receive defaults
          derived from ``envs.VLLM_USE_V1`` when left as ``None``.
        - An ``int`` or ``dict`` ``compilation_config`` is converted into a
          ``CompilationConfig`` via ``CompilationConfig.from_cli``.
        - General plugins are loaded.
        """
        if not self.tokenizer:
            self.tokenizer = self.model

        # Override the default value of enable_prefix_caching if it's not set
        # by user.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = bool(envs.VLLM_USE_V1)

        # Override max_num_seqs if it's not set by user.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024

        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
        # NOTE(review): a dict is passed through ``str()`` (Python repr, not
        # JSON), so this relies on ``from_cli`` accepting that form — confirm.
        if isinstance(self.compilation_config, (int, dict)):
            self.compilation_config = CompilationConfig.from_cli(
                str(self.compilation_config))

        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
234
235

    @staticmethod
236
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
237
        """Shared CLI arguments for vLLM engine."""
238

239
        # Model arguments
240
241
242
        parser.add_argument(
            '--model',
            type=str,
243
            default=EngineArgs.model,
244
            help='Name or path of the huggingface model to use.')
245
246
247
248
249
250
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
251
            'multiple tasks. When the model only supports one task, ``"auto"`` '
252
253
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
254
255
        parser.add_argument(
            '--tokenizer',
256
            type=nullable_str,
257
            default=EngineArgs.tokenizer,
258
259
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
260
261
262
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
263
            help='Skip initialization of tokenizer and detokenizer.')
Jasmond L's avatar
Jasmond L committed
264
265
        parser.add_argument(
            '--revision',
266
            type=nullable_str,
Jasmond L's avatar
Jasmond L committed
267
            default=None,
268
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
269
270
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
271
272
        parser.add_argument(
            '--code-revision',
273
            type=nullable_str,
274
            default=None,
275
            help='The specific revision to use for the model code on '
276
277
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
278
279
        parser.add_argument(
            '--tokenizer-revision',
280
            type=nullable_str,
281
            default=None,
282
283
284
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
285
286
287
288
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
289
            choices=['auto', 'slow', 'mistral'],
290
291
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
292
293
            'always use the slow tokenizer. \n* '
            '"mistral" will always use the `mistral_common` tokenizer.')
294
295
        parser.add_argument('--trust-remote-code',
                            action='store_true',
296
                            help='Trust remote code from huggingface.')
297
298
299
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
300
301
302
303
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
304
        parser.add_argument('--download-dir',
305
                            type=nullable_str,
Zhuohan Li's avatar
Zhuohan Li committed
306
                            default=EngineArgs.download_dir,
307
                            help='Directory to download and load the weights, '
308
                            'default to the default cache dir of '
309
                            'huggingface.')
310
311
312
313
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
314
            choices=[f.value for f in LoadFormat],
315
316
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
317
            'and fall back to the pytorch bin format if safetensors format '
318
319
320
321
322
323
324
325
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
326
            'CoreWeave. See the Tensorize vLLM Model script in the Examples '
327
            'section for more information.\n'
328
329
            '* "runai_streamer" will load the Safetensors weights using Run:ai'
            'Model Streamer \n'
330
331
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
332
333
334
335
336
337
338
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
339
340
341
342
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
343
344
345
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
346
347
348
349
350
351
352
353
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
354
355
356
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
357
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
358
            default=EngineArgs.kv_cache_dtype,
359
            help='Data type for kv cache storage. If "auto", will use model '
360
361
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
362
363
        parser.add_argument('--max-model-len',
                            type=int,
364
                            default=EngineArgs.max_model_len,
365
366
                            help='Model context length. If unspecified, will '
                            'be automatically derived from the model config.')
367
368
369
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
370
371
            default='xgrammar',
            choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
372
            help='Which engine will be used for guided decoding'
373
            ' (JSON schema / regex etc) by default. Currently support '
374
            'https://github.com/outlines-dev/outlines, '
375
            'https://github.com/mlc-ai/xgrammar, and '
376
377
378
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
379
380
381
382
383
384
385
386
        parser.add_argument(
            '--logits-processor-pattern',
            type=nullable_str,
            default=None,
            help='Optional regex pattern specifying valid logits processor '
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
387
388
389
390
391
392
393
394
395
396
397
398
        parser.add_argument(
            '--model-impl',
            type=str,
            default=EngineArgs.model_impl,
            choices=[f.value for f in ModelImpl],
            help='Which implementation of the model to use.\n\n'
            '* "auto" will try to use the vLLM implementation if it exists '
            'and fall back to the Transformers implementation if no vLLM '
            'implementation is available.\n'
            '* "vllm" will use the vLLM model implementation.\n'
            '* "transformers" will use the Transformers model '
            'implementation.\n')
399
        # Parallel arguments
400
401
        parser.add_argument(
            '--distributed-executor-backend',
402
            choices=['ray', 'mp', 'uni', 'external_launcher'],
403
            default=EngineArgs.distributed_executor_backend,
404
405
406
407
408
409
            help='Backend to use for distributed model '
            'workers, either "ray" or "mp" (multiprocessing). If the product '
            'of pipeline_parallel_size and tensor_parallel_size is less than '
            'or equal to the number of GPUs available, "mp" will be used to '
            'keep processing on a single host. Otherwise, this will default '
            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
410
            'only supports Ray for distributed inference.')
411

412
413
414
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
415
                            default=EngineArgs.pipeline_parallel_size,
416
                            help='Number of pipeline stages.')
417
418
419
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
420
                            default=EngineArgs.tensor_parallel_size,
421
                            help='Number of tensor parallel replicas.')
王敏's avatar
王敏 committed
422
423
424
425
        parser.add_argument('--moe-ep-size',
                            type=int,
                            default=EngineArgs.moe_ep_size,
                            help='Number of moe expert parallel replicas.')
426
427
428
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
429
            default=EngineArgs.max_parallel_loading_workers,
430
            help='Load model sequentially in multiple batches, '
431
            'to avoid RAM OOM when using tensor '
432
            'parallel and large models.')
433
434
435
        parser.add_argument(
            '--ray-workers-use-nsight',
            action='store_true',
436
            help='If specified, use nsight to profile Ray workers.')
437
        # KV cache arguments
438
439
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
440
                            default=EngineArgs.block_size,
441
                            choices=[8, 16, 32, 64, 128],
442
                            help='Token block size for contiguous chunks of '
443
                            'tokens. This is ignored on neuron devices and '
444
                            'set to ``--max-model-len``. On CUDA devices, '
445
446
                            'only block sizes up to 32 are supported. '
                            'On HPU devices, block size defaults to 128.')
447

448
449
450
451
452
        parser.add_argument(
            "--enable-prefix-caching",
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.enable_prefix_caching,
            help="Enables automatic prefix caching. "
453
            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
454
        )
455
456
457
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
458
                            'capping to sliding window size.')
459
460
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
461
                            default=True,
462
463
464
465
466
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
467
468
469
470
471
472
473
474
        parser.add_argument(
            '--num-lookahead-slots',
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help='Experimental scheduling config necessary for '
            'speculative decoding. This will be replaced by '
            'speculative config in the future; it is present '
            'to enable correctness tests until then.')
475

476
477
478
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
479
                            help='Random seed for operations.')
480
        parser.add_argument('--swap-space',
481
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
482
                            default=EngineArgs.swap_space,
483
                            help='CPU swap space size (GiB) per GPU.')
484
485
486
487
488
489
490
491
492
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
493
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
494
            'which requires at least 26GB GPU memory. Note that this '
495
            'requires fast CPU-GPU interconnect, as part of the model is '
496
497
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
498
499
500
501
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
502
503
504
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
505
506
507
508
509
510
            'will use the default value of 0.9. This is a per-instance '
            'limit, and only applies to the current vLLM instance.'
            'It does not matter if you have another vLLM instance running '
            'on the same GPU. For example, if you have two vLLM instances '
            'running on the same GPU, you can set the GPU memory utilization '
            'to 0.5 for each instance.')
511
        parser.add_argument(
512
            '--num-gpu-blocks-override',
513
514
515
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
516
            ' of GPU blocks. Used for testing preemption.')
517
518
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
519
                            default=EngineArgs.max_num_batched_tokens,
520
521
                            help='Maximum number of batched tokens per '
                            'iteration.')
522
523
        parser.add_argument('--max-num-seqs',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
524
                            default=EngineArgs.max_num_seqs,
525
                            help='Maximum number of sequences per iteration.')
526
527
528
529
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
530
531
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
532
533
        parser.add_argument('--disable-log-stats',
                            action='store_true',
534
                            help='Disable logging statistics.')
535
536
537
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
538
                            type=nullable_str,
539
                            choices=[*QUANTIZATION_METHODS, None],
540
                            default=EngineArgs.quantization,
541
542
543
544
545
546
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
547
548
549
550
551
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
552
            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
553
554
555
556
557
558
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
559
560
561
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
562
                            help='Extra arguments for the HuggingFace config. '
563
564
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
565
566
567
568
569
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
570
        parser.add_argument('--max-seq-len-to-capture',
571
572
573
574
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
575
576
577
578
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
579
580
581
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
582
                            help='See ParallelConfig.')
583
584
585
586
587
588
589
590
591
592
593
594
595
        parser.add_argument('--tokenizer-pool-size',
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help='Size of tokenizer pool to use for '
                            'asynchronous tokenization. If 0, will '
                            'use synchronous tokenization.')
        parser.add_argument('--tokenizer-pool-type',
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help='Type of tokenizer pool to use for '
                            'asynchronous tokenization. Ignored '
                            'if tokenizer_pool_size is 0.')
        parser.add_argument('--tokenizer-pool-extra-config',
596
                            type=nullable_str,
597
598
599
600
601
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help='Extra config for tokenizer pool. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary. Ignored if '
                            'tokenizer_pool_size is 0.')
602
603
604
605
606
607
608
609
610
611
612
613
614
615

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
616
617
618
619
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
620
            help=('Overrides for the multimodal input mapping/processing, '
621
                  'e.g., image processor. For example: ``{"num_crops": 4}``.'))
622
        parser.add_argument(
623
            '--disable-mm-preprocessor-cache',
624
            action='store_true',
625
626
            help='If true, then disables caching of the multi-modal '
            'preprocessor/mapper. (not recommended)')
627

628
629
630
631
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
632
633
634
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
635
636
637
638
639
640
641
642
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
643
644
645
646
647
648
649
650
        parser.add_argument('--merge-lora',
                            type=bool,
                            default=False,
                            help='If set to True, the weights of the base layer will be merged with the weights of Lora.')
        parser.add_argument('--lora-target-modules',
                            nargs='*',
                            default=None,
                            help='List of lora module name, If not specified, modules will be chosen according to the model architecture.')
651
652
653
654
655
656
657
658
659
660
661
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
662
            choices=['auto', 'float16', 'bfloat16'],
663
664
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
665
666
667
668
669
670
671
672
673
674
675
        parser.add_argument(
            '--long-lora-scaling-factors',
            type=nullable_str,
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
676
677
678
679
680
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
681
682
                  'Must be >= than max_loras. '
                  'Defaults to max_loras.'))
683
684
685
686
687
688
689
690
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
691
692
693
694
695
696
697
698
699
700
701
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
702
703
704
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
705
                            choices=DEVICE_OPTIONS,
706
                            help='Device type for vLLM execution.')
707
708
709
710
711
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
712

713
714
        parser.add_argument(
            '--multi-step-stream-outputs',
715
716
717
718
719
720
            action=StoreBoolean,
            default=EngineArgs.multi_step_stream_outputs,
            nargs="?",
            const="True",
            help='If False, then multi-step will stream outputs at the end '
            'of all steps')
721
722
723
724
        parser.add_argument(
            '--scheduler-delay-factor',
            type=float,
            default=EngineArgs.scheduler_delay_factor,
725
            help='Apply a delay (of delay factor multiplied by previous '
726
            'prompt latency) before scheduling next prompt.')
727
728
        parser.add_argument(
            '--enable-chunked-prefill',
729
730
731
732
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
733
            help='If set, the prefill requests can be chunked based on the '
734
            'max_num_batched_tokens.')
735
736
737

        parser.add_argument(
            '--speculative-model',
738
            type=nullable_str,
739
            default=EngineArgs.speculative_model,
740
741
            help=
            'The name of the draft model to be used in speculative decoding.')
742
743
744
745
746
747
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=nullable_str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
748
            help='Method used to quantize the weights of speculative model. '
749
750
751
752
753
            'If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
754
755
756
        parser.add_argument(
            '--num-speculative-tokens',
            type=int,
757
            default=EngineArgs.num_speculative_tokens,
758
            help='The number of speculative tokens to sample from '
759
            'the draft model in speculative decoding.')
760
761
762
763
764
765
        parser.add_argument(
            '--num-speculative-heads',
            type=int,
            default=EngineArgs.num_speculative_heads,
            help='The number of speculative heads to sample from '
                 'the draft model in speculative decoding.')
766
767
768
769
770
771
        parser.add_argument(
            '--speculative-disable-mqa-scorer',
            action='store_true',
            help=
            'If set to True, the MQA scorer will be disabled in speculative '
            ' and fall back to batch expansion')
772
773
774
775
776
777
778
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help='Number of tensor parallel replicas for '
            'the draft model in speculative decoding.')
779

780
781
        parser.add_argument(
            '--speculative-max-model-len',
782
            type=int,
783
784
785
786
787
            default=EngineArgs.speculative_max_model_len,
            help='The maximum sequence length supported by the '
            'draft model. Sequences over this length will skip '
            'speculation.')

788
789
790
791
792
793
794
        parser.add_argument(
            '--speculative-disable-by-batch-size',
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help='Disable speculative decoding for new incoming requests '
            'if the number of enqueue requests is larger than this value.')

795
796
797
798
799
800
801
802
803
804
805
806
807
808
        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')

        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')

841
842
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
843
            action=StoreBoolean,
844
            default=EngineArgs.disable_logprobs_during_spec_decoding,
845
846
            nargs="?",
            const="True",
847
848
849
850
851
852
853
854
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

855
        parser.add_argument('--model-loader-extra-config',
856
                            type=nullable_str,
857
858
859
860
861
862
                            default=EngineArgs.model_loader_extra_config,
                            help='Extra config for model loader. '
                            'This will be passed to the model loader '
                            'corresponding to the chosen load_format. '
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
863
864
865
866
867
868
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
869
            "Default to `original/**/*` to avoid repeated loading of llama's "
870
            "checkpoints.")
871
        parser.add_argument(
872
            '--preemption-mode',
873
874
            type=str,
            default=None,
875
876
877
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
878

879
880
881
882
883
884
885
886
887
888
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
889
            "same as the ``--model`` argument. Noted that this name(s) "
890
            "will also be used in `model_name` tag content of "
891
            "prometheus metrics, if multiple names provided, metrics "
892
            "tag will take the first one.")
893
894
895
896
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
897
898
899
900
901
902

        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
903
904
905
906
907
908
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
909
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
910
911
912
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
913

914
915
916
917
918
919
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
920

921
922
923
924
925
926
927
928
929
930
        parser.add_argument(
            '--scheduling-policy',
            choices=['fcfs', 'priority'],
            default="fcfs",
            help='The scheduling policy to use. "fcfs" (first come first served'
            ', i.e. requests are handled in order of arrival; default) '
            'or "priority" (requests are handled based on given '
            'priority (lower value means earlier handling) and time of '
            'arrival deciding any ties).')

931
932
        parser.add_argument(
            '--override-neuron-config',
933
            type=json.loads,
934
            default=None,
935
            help="Override or set neuron device configuration. "
936
            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
937
        parser.add_argument(
938
939
            '--override-pooler-config',
            type=PoolerConfig.from_json,
940
            default=None,
941
            help="Override or set the pooling method for pooling models. "
942
            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
943

944
945
946
947
948
949
950
951
952
953
954
955
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
956
957
958
959
                            'use a JSON string.\n'
                            'Following the convention of traditional '
                            'compilers, using -O without space is also '
                            'supported. -O3 is equivalent to -O 3.')
960

961
962
963
964
965
966
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

967
968
969
970
971
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')
972
973
974
975
976
        parser.add_argument(
            "--generation-config",
            type=nullable_str,
            default=None,
            help="The folder path to the generation config. "
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
            "Defaults to None, no generation config is loaded, vLLM defaults "
            "will be used. If set to 'auto', the generation config will be "
            "loaded from model path. If set to a folder path, the generation "
            "config will be loaded from the specified folder path. If "
            "`max_new_tokens` is specified in generation config, then "
            "it sets a server-wide limit on the number of output tokens "
            "for all requests.")

        parser.add_argument(
            "--override-generation-config",
            type=json.loads,
            default=None,
            help="Overrides or sets generation config in JSON format. "
            "e.g. ``{\"temperature\": 0.5}``. If used with "
            "--generation-config=auto, the override parameters will be merged "
            "with the default config from the model. If generation-config is "
            "None, only the override parameters are used.")
994

995
996
997
998
999
1000
        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
                            default=False,
                            help="Enable sleep mode for the engine. "
                            "(only cuda platform is supported)")

1001
1002
1003
1004
1005
1006
1007
1008
        parser.add_argument(
            '--calculate-kv-scales',
            action='store_true',
            help='This enables dynamic calculation of '
            'k_scale and v_scale when kv-cache-dtype is fp8. '
            'If calculate-kv-scales is false, the scales will '
            'be loaded from the model checkpoint if available. '
            'Otherwise, the scales will default to 1.0.')
1009

1010
        return parser
1011
1012

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance of ``cls`` from a parsed argparse namespace.

        Each dataclass field declared on ``cls`` is read from the attribute
        of the same name on ``args``. Attributes present on the namespace
        that are not dataclass fields are ignored.

        Args:
            args: The namespace produced by the argument parser.

        Returns:
            A new ``cls`` instance whose fields are taken from ``args``.

        Raises:
            AttributeError: If ``args`` has no attribute for one of the
                dataclass fields of ``cls``.
        """
        field_values = {}
        for fld in dataclasses.fields(cls):
            field_values[fld.name] = getattr(args, fld.name)
        return cls(**field_values)
1019

1020
1021
    def create_model_config(self) -> ModelConfig:
        """Translate the model-related engine arguments into a ModelConfig.

        Almost every field is forwarded unchanged from the attribute of the
        same name on ``self``. Two values are derived instead of copied:

        * ``tokenizer`` is narrowed from an optional value to ``str`` via
          ``cast``. ``__post_init__`` is expected to have filled it in; this
          is not re-checked at runtime.
        * ``use_async_output_proc`` is the negation of
          ``disable_async_output_proc``.

        Returns:
            A newly constructed ``ModelConfig``.
        """
        # Known to be non-None because __post_init__ sets it; the cast only
        # informs the type checker and performs no runtime conversion.
        tokenizer = cast(str, self.tokenizer)
        # The engine args carry an opt-out flag while ModelConfig takes the
        # positive form, so invert it here.
        use_async_output_proc = not self.disable_async_output_proc

        return ModelConfig(
            # Model / tokenizer identity and loading.
            model=self.model,
            task=self.task,
            tokenizer=tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            tokenizer_revision=self.tokenizer_revision,
            revision=self.revision,
            code_revision=self.code_revision,
            trust_remote_code=self.trust_remote_code,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            config_format=self.config_format,
            model_impl=self.model_impl,
            hf_overrides=self.hf_overrides,
            # Numerics and context length.
            dtype=self.dtype,
            seed=self.seed,
            quantization=self.quantization,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            max_model_len=self.max_model_len,
            disable_sliding_window=self.disable_sliding_window,
            # Execution behaviour.
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            use_async_output_proc=use_async_output_proc,
            enable_sleep_mode=self.enable_sleep_mode,
            logits_processor_pattern=self.logits_processor_pattern,
            # Multimodal inputs.
            allowed_local_media_path=self.allowed_local_media_path,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            # Override / generation settings.
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
        )
1058

1059
1060
1061
1062
1063
1064
1065
1066
    def create_load_config(self) -> LoadConfig:
        """Assemble the LoadConfig from this instance's weight-loading args.

        Forwards ``load_format``, ``download_dir``,
        ``model_loader_extra_config`` and ``ignore_patterns`` unchanged.

        Returns:
            A newly constructed ``LoadConfig``.
        """
        load_config = LoadConfig(
            ignore_patterns=self.ignore_patterns,
            model_loader_extra_config=self.model_loader_extra_config,
            download_dir=self.download_dir,
            load_format=self.load_format,
        )
        return load_config

1067
1068
1069
1070
1071
1072
    def create_engine_config(self,
                             usage_context: Optional[UsageContext] = None
                             ) -> VllmConfig:
        if envs.VLLM_USE_V1:
            self._override_v1_engine_args(usage_context)

1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
           self.qlora_adapter_name_or_path is not None) and \
           self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        if (self.load_format == "bitsandbytes" or
            self.qlora_adapter_name_or_path is not None) and \
            self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

1100
1101
1102
1103
1104
        if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
                and self.enable_prefix_caching):
            logger.warning("--enable-prefix-caching is currently not "
                           "supported for multimodal models in v0 and "
                           "has been disabled.")
1105
1106
            self.enable_prefix_caching = False

1107
        cache_config = CacheConfig(
1108
            block_size=self.block_size,
1109
1110
1111
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1112
            is_attention_free=model_config.is_attention_free,
1113
1114
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1115
1116
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
1117
            calculate_kv_scales=self.calculate_kv_scales,
1118
        )
1119
        parallel_config = ParallelConfig(
1120
1121
1122
1123
1124
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
1125
1126
1127
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
1128
            ),
1129
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1130
1131
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
王敏's avatar
王敏 committed
1132
            moe_ep_size=self.moe_ep_size,
1133
        )
1134

1135
1136
1137
1138
1139
1140
        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.
1141

1142
1143
1144
1145
1146
1147
            # For multimodal models, chunked prefill is disabled by default in
            # V0, but enabled by design in V1
            if model_config.is_multimodal_model:
                self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)

            elif use_long_context:
1148
1149
1150
1151
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
1152
                from vllm.platforms import current_platform
1153
1154
1155
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
1156
1157
                        and model_config.runner_type != "pooling"
                        and not current_platform.is_rocm()):
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
1174
1175
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
1176
            msg = "Chunked prefill is not supported for pooling models"
1177
            raise ValueError(msg)
1178

1179

1180
1181
1182
1183
1184
        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
1185
1186
            speculative_model_quantization = \
                self.speculative_model_quantization,
1187
1188
            speculative_draft_tensor_parallel_size = \
                self.speculative_draft_tensor_parallel_size,
1189
            num_speculative_tokens=self.num_speculative_tokens,
1190
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
1191
1192
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
1193
1194
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
1195
            disable_log_stats=self.disable_log_stats,
1196
1197
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
1198
1199
1200
1201
1202
1203
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
1204
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
1205
            num_speculative_heads=self.num_speculative_heads
1206
1207
        )

1208
        # Reminder: Please update docs/source/features/compatibility_matrix.md
1209
        # If the feature combo become valid
1210
1211
1212
1213
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1214
1215
1216
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1217
1218
1219
1220
1221
1222
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1223
1224
1225
1226
1227
1228
1229
1230
1231

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
        if not self.use_v2_block_manager:
            logger.warning(
                "[DEPRECATED] Block manager v1 has been removed, "
                "and setting --use-v2-block-manager to True or False has "
                "no effect on vLLM behavior. Please remove "
                "--use-v2-block-manager in your engine argument. "
                "If your use case is not supported by "
                "SelfAttnBlockSpaceManager (i.e. block manager v2),"
                " please file an issue with detailed information.")

1242
        scheduler_config = SchedulerConfig(
1243
            runner_type=model_config.runner_type,
1244
1245
1246
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1247
            num_lookahead_slots=num_lookahead_slots,
1248
1249
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1250
            is_multimodal_model=model_config.is_multimodal_model,
1251
            preemption_mode=self.preemption_mode,
1252
            num_scheduler_steps=self.num_scheduler_steps,
1253
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1254
1255
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1256
            policy=self.scheduling_policy)
1257
        lora_config = LoRAConfig(
1258
            bias_enabled=self.enable_lora_bias,
1259
1260
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1261
            fully_sharded_loras=self.fully_sharded_loras,
1262
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1263
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1264
1265
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
1266
1267
1268
            and self.max_cpu_loras > 0 else None,
            merge_lora=self.merge_lora,
            lora_target_modules=self.lora_target_modules) if self.enable_lora else None
1269

1270
1271
1272
1273
1274
1275
1276
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

1277
        load_config = self.create_load_config()
1278

1279
1280
1281
1282
1283
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1284
1285
1286
        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

1287
1288
1289
1290
1291
1292
1293
1294
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
1295
        observability_config = ObservabilityConfig(
1296
1297
1298
1299
1300
1301
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )
1302

1303
        config = VllmConfig(
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1314
            prompt_adapter_config=prompt_adapter_config,
1315
            compilation_config=self.compilation_config,
1316
            kv_transfer_config=self.kv_transfer_config,
1317
        )
1318

1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
        if envs.VLLM_USE_V1:
            self._override_v1_engine_config(config)
        return config

    def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
        """
        Override the EngineArgs's args based on the usage context for V1.

        Mutates ``self`` in place:

        * ``enable_chunked_prefill`` is always forced to ``True``.
        * ``max_num_batched_tokens`` is filled with a default only when it
          is still ``None`` and ``usage_context`` has an entry in the default
          table (``UsageContext.LLM_CLASS`` / ``UsageContext.OPENAI_API_SERVER``).
          Other usage contexts leave it as ``None``.

        Args:
            usage_context: Key into the default ``max_num_batched_tokens``
                table; its ``.value`` is included in the log message.

        Raises:
            AssertionError: If ``envs.VLLM_USE_V1`` is false (and assertions
                are enabled).
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"

        # V1 always uses chunked prefills.
        self.enable_chunked_prefill = True

        # When no user override, set the default values based on the usage
        # context. Use different default values for different hardware.
        # Imported locally, as in the rest of this module.
        from vllm.platforms import current_platform
        device_name = current_platform.get_device_name().lower()
        if "h100" in device_name or "h200" in device_name:
            # For H100 and H200, we use larger default values.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            default_max_num_batched_tokens = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }

        # Only fill in the default when the user did not set a value.
        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
            self.max_num_batched_tokens = default_max_num_batched_tokens[
                usage_context]
            logger.warning(
                "Setting max_num_batched_tokens to %d for %s usage context.",
                self.max_num_batched_tokens, usage_context.value)

    def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
        """
        Hook for adjusting an already-built VllmConfig on the V1 path.

        At present no fields of ``engine_config`` are modified. The method
        only verifies that the V1 engine is enabled.

        Raises:
            AssertionError: If ``envs.VLLM_USE_V1`` is false (and assertions
                are enabled).
        """
        # Guard: this hook is only meaningful when V1 is switched on.
        assert envs.VLLM_USE_V1, "V1 is not enabled"

1363

1364
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # Extra field on top of everything inherited from EngineArgs.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine CLI arguments on ``parser``.

        Args:
            parser: Parser to add the arguments to.
            async_args_only: When ``True``, only the async-specific
                arguments are added. Otherwise the full ``EngineArgs`` set is
                registered first via ``EngineArgs.add_cli_args``.

        Returns:
            The parser with the arguments registered.
        """
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        return parser


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a fresh parser populated with every ``EngineArgs`` CLI argument."""
    return EngineArgs.add_cli_args(FlexibleArgumentParser())


def _async_engine_args_parser():
    """Return a fresh parser holding only the async-specific CLI arguments.

    ``async_args_only=True`` skips the base ``EngineArgs`` arguments.
    """
    return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
                                        async_args_only=True)