arg_utils.py 80.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
# yapf: disable
4
import argparse
5
import dataclasses
6
import json
7
import re
8
import threading
9
from dataclasses import MISSING, dataclass, fields
10
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal,
11
12
                    Optional, Tuple, Type, TypeVar, Union, cast, get_args,
                    get_origin)
13

14
import torch
15
from typing_extensions import TypeIs
16

17
import vllm.envs as envs
18
from vllm import version
19
20
from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
                         DecodingConfig, Device, DeviceConfig,
21
                         DistributedExecutorBackend, HfOverrides,
22
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
23
24
25
26
27
28
                         ModelConfig, ModelImpl, MultiModalConfig,
                         ObservabilityConfig, ParallelConfig, PoolerConfig,
                         PoolType, PromptAdapterConfig, SchedulerConfig,
                         SchedulerPolicy, SpeculativeConfig, TaskOption,
                         TokenizerPoolConfig, VllmConfig, get_attr_docs,
                         get_field)
29
from vllm.executor.executor_base import ExecutorBase
30
from vllm.logger import init_logger
31
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
32
from vllm.plugins import load_general_plugins
33
from vllm.reasoning import ReasoningParserManager
34
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
35
from vllm.transformers_utils.utils import check_gguf_file
36
from vllm.usage.usage_lib import UsageContext
37
38
39
from vllm.utils import FlexibleArgumentParser, is_in_ray_actor

# yapf: enable
40

41
if TYPE_CHECKING:
42
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
43

44
45
logger = init_logger(__name__)

46
47
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

48
49
50
51
52
# Type-hint aliases for the argparse helpers in this module.
# `object` is part of each union so that special typing forms such as
# `Literal[...]`, `Union[...]` or `Optional[...]` -- which are not
# instances of `type` -- are still accepted where a type hint is expected.
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]

53

54
def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
    """Convert a CLI string with ``return_type``, treating blanks as None.

    Args:
        val: Raw string received from the command line.
        return_type: Callable that converts ``val`` into the target type.

    Returns:
        ``None`` when ``val`` is the empty string or the literal ``"None"``;
        otherwise the result of ``return_type(val)``.

    Raises:
        argparse.ArgumentTypeError: If ``return_type`` raises
            ``ValueError`` for ``val``.
    """
    if val in ("", "None"):
        return None

    try:
        converted = return_type(val)
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            f"Value {val} cannot be converted to {return_type}.") from err
    return converted


def optional_str(val: str) -> Optional[str]:
    """Convert ``val`` to ``str``, mapping ``""``/``"None"`` to ``None``.

    Thin wrapper around :func:`optional_arg` for use as an argparse
    ``type=`` callable.
    """
    return optional_arg(val=val, return_type=str)


def optional_int(val: str) -> Optional[int]:
    """Convert ``val`` to ``int``, mapping ``""``/``"None"`` to ``None``.

    Thin wrapper around :func:`optional_arg` for use as an argparse
    ``type=`` callable.
    """
    return optional_arg(val=val, return_type=int)


def optional_float(val: str) -> Optional[float]:
    """Convert ``val`` to ``float``, mapping ``""``/``"None"`` to ``None``.

    Thin wrapper around :func:`optional_arg` for use as an argparse
    ``type=`` callable.
    """
    return optional_arg(val=val, return_type=float)
74
75


76
77
78
79
80
def nullable_kvs(val: str) -> Optional[dict[str, int]]:
    """Parse comma-separated ``KEY=VALUE`` pairs into a dictionary.

    NOTE: This function is deprecated, args should be passed as JSON
    strings instead.

    Every key and value is stripped of surrounding whitespace and
    lower-cased, and each value must parse as an ``int``; e.g.
    ``"Image=2, video=1"`` becomes ``{"image": 2, "video": 1}``. A key may
    appear more than once only if every occurrence has the same value.

    Args:
        val: String value to be parsed.

    Returns:
        ``None`` when ``val`` is empty; otherwise the parsed dictionary.

    Raises:
        argparse.ArgumentTypeError: If an item is not exactly ``KEY=VALUE``,
            a value is not an integer, or a key is assigned two different
            values.
    """
    if len(val) == 0:
        return None

    parsed: Dict[str, int] = {}
    for entry in val.split(","):
        pieces = [piece.lower().strip() for piece in entry.split("=")]
        if len(pieces) != 2:
            raise argparse.ArgumentTypeError(
                "Each item should be in the form KEY=VALUE")
        key, raw_value = pieces

        try:
            number = int(raw_value)
        except ValueError as exc:
            raise argparse.ArgumentTypeError(
                f"Failed to parse value of item {key}={raw_value}") from exc

        # An unseen key trivially agrees; a repeated key must match.
        if parsed.get(key, number) != number:
            raise argparse.ArgumentTypeError(
                f"Conflicting values specified for key: {key}")
        parsed[key] = number

    return parsed


114
def optional_dict(val: str) -> Optional[dict[str, Any]]:
    """Parse a CLI string into a dictionary, or ``None``.

    Accepted forms:

    * ``""`` or ``"None"`` -> ``None``, consistent with the other
      ``optional_*`` converters.
    * A JSON object such as ``'{"image": 2}'`` (possibly spanning several
      lines or padded with whitespace) -> the decoded ``dict``.
    * Legacy comma-separated ``KEY=VALUE`` pairs such as ``"image=2"`` ->
      parsed by the deprecated :func:`nullable_kvs`, after logging a
      deprecation warning.

    Raises:
        argparse.ArgumentTypeError: If a brace-delimited value is not valid
            JSON, or a non-JSON value is not valid ``KEY=VALUE`` pairs.
    """
    # Handle the "no value" spellings up front. Previously "None" fell
    # through to nullable_kvs and was rejected as a malformed KEY=VALUE
    # item, and "" triggered a spurious JSON-parse warning.
    if val in ("", "None"):
        return None

    # DOTALL lets `.` span newlines, so multi-line JSON objects are not
    # misrouted to the legacy key=value parser; json.loads itself tolerates
    # surrounding whitespace, so match on the stripped value.
    if re.fullmatch(r"\{.*\}", val.strip(), flags=re.DOTALL):
        return optional_arg(val, json.loads)

    logger.warning(
        "Failed to parse JSON string. Attempting to parse as "
        "comma-separated key=value pairs. This will be deprecated in a "
        "future release.")
    return nullable_kvs(val)
123
124


125
@dataclass
Zhuohan Li's avatar
Zhuohan Li committed
126
class EngineArgs:
Woosuk Kwon's avatar
Woosuk Kwon committed
127
    """Arguments for vLLM engine."""
128
    model: str = 'facebook/opt-125m'
129
    served_model_name: Optional[Union[str, List[str]]] = None
130
    tokenizer: Optional[str] = None
131
    hf_config_path: Optional[str] = None
132
    task: TaskOption = "auto"
133
    skip_tokenizer_init: bool = False
134
    tokenizer_mode: str = 'auto'
135
    trust_remote_code: bool = False
136
    allowed_local_media_path: str = ""
137
138
    download_dir: Optional[str] = LoadConfig.download_dir
    load_format: str = LoadConfig.load_format
139
    config_format: ConfigFormat = ConfigFormat.AUTO
140
    dtype: str = 'auto'
141
    kv_cache_dtype: str = 'auto'
142
    seed: Optional[int] = None
143
    max_model_len: Optional[int] = None
144
145
146
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
147
    distributed_executor_backend: Optional[Union[
148
149
        DistributedExecutorBackend,
        Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
150
    # number of P/D disaggregation (or other disaggregation) workers
151
152
153
154
155
156
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
    data_parallel_size: int = ParallelConfig.data_parallel_size
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
    max_parallel_loading_workers: Optional[
        int] = ParallelConfig.max_parallel_loading_workers
157
    block_size: Optional[int] = None
158
    enable_prefix_caching: Optional[bool] = None
159
    prefix_caching_hash_algo: str = "builtin"
160
    disable_sliding_window: bool = False
161
    disable_cascade_attn: bool = False
162
    use_v2_block_manager: bool = True
163
164
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
165
    gpu_memory_utilization: float = 0.90
166
167
168
169
170
171
172
    max_num_batched_tokens: Optional[
        int] = SchedulerConfig.max_num_batched_tokens
    max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
    max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
    long_prefill_token_threshold: int = \
        SchedulerConfig.long_prefill_token_threshold
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
173
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
174
    disable_log_stats: bool = False
Jasmond L's avatar
Jasmond L committed
175
    revision: Optional[str] = None
176
    code_revision: Optional[str] = None
177
    rope_scaling: Optional[Dict[str, Any]] = None
178
    rope_theta: Optional[float] = None
179
    hf_token: Optional[Union[bool, str]] = None
180
    hf_overrides: Optional[HfOverrides] = None
181
    tokenizer_revision: Optional[str] = None
182
    quantization: Optional[str] = None
183
    enforce_eager: Optional[bool] = None
184
    max_seq_len_to_capture: int = 8192
185
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
186
    tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
187
188
189
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
190
191
192
193
    tokenizer_pool_type: Union[PoolType, Type["BaseTokenizerGroup"]] = \
        TokenizerPoolConfig.pool_type
    tokenizer_pool_extra_config: dict[str, Any] = \
        get_field(TokenizerPoolConfig, "extra_config")
194
    limit_mm_per_prompt: dict[str, int] = \
195
        get_field(MultiModalConfig, "limit_per_prompt")
196
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
197
    disable_mm_preprocessor_cache: bool = False
198
    enable_lora: bool = False
199
    enable_lora_bias: bool = False
200
201
    max_loras: int = 1
    max_lora_rank: int = 16
202
203
204
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
205
    fully_sharded_loras: bool = False
206
    lora_extra_vocab_size: int = 256
207
    long_lora_scaling_factors: Optional[Tuple[float]] = None
208
    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
209
    max_cpu_loras: Optional[int] = None
210
    device: Device = DeviceConfig.device
211
212
    num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
    multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
213
    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
214
    num_gpu_blocks_override: Optional[int] = None
215
    num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
216
217
    model_loader_extra_config: dict = \
        get_field(LoadConfig, "model_loader_extra_config")
218
219
    ignore_patterns: Optional[Union[str,
                                    List[str]]] = LoadConfig.ignore_patterns
220
    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
221

222
223
224
225
    scheduler_delay_factor: float = SchedulerConfig.delay_factor
    enable_chunked_prefill: Optional[
        bool] = SchedulerConfig.enable_chunked_prefill
    disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
226

227
    guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
228
    logits_processor_pattern: Optional[str] = None
229

230
    speculative_config: Optional[Dict[str, Any]] = None
231

232
    qlora_adapter_name_or_path: Optional[str] = None
233
    show_hidden_metrics_for_version: Optional[str] = None
234
    otlp_traces_endpoint: Optional[str] = None
235
    collect_detailed_traces: Optional[str] = None
236
    disable_async_output_proc: bool = False
237
238
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
239

240
241
    override_neuron_config: Optional[Dict[str, Any]] = None
    override_pooler_config: Optional[PoolerConfig] = None
242
    compilation_config: Optional[CompilationConfig] = None
243
244
    worker_cls: str = ParallelConfig.worker_cls
    worker_extension_cls: str = ParallelConfig.worker_extension_cls
245

246
247
    kv_transfer_config: Optional[KVTransferConfig] = None

248
    generation_config: Optional[str] = "auto"
249
    override_generation_config: Optional[Dict[str, Any]] = None
250
    enable_sleep_mode: bool = False
251
    model_impl: str = "auto"
252

253
254
    calculate_kv_scales: Optional[bool] = None

255
    additional_config: Optional[Dict[str, Any]] = None
256
    enable_reasoning: Optional[bool] = None
257
    reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend
258
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
259

260
    def __post_init__(self):
        """Normalize user-supplied fields once the dataclass is built.

        * An unset or empty ``tokenizer`` falls back to ``model``.
        * An ``int`` or ``dict`` ``compilation_config`` is converted with
          ``CompilationConfig.from_cli`` so callers may write
          ``EngineArgs(compilation_config={...})`` without constructing a
          ``CompilationConfig`` object themselves.
        * General plugins are loaded via ``load_general_plugins()``.
        """
        self.tokenizer = self.tokenizer or self.model

        raw_compilation_config = self.compilation_config
        if isinstance(raw_compilation_config, (int, dict)):
            # The raw value is stringified before being handed to from_cli.
            self.compilation_config = CompilationConfig.from_cli(
                str(raw_compilation_config))

        # Setup plugins
        from vllm.plugins import load_general_plugins
        load_general_plugins()
274
275

    @staticmethod
276
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
Woosuk Kwon's avatar
Woosuk Kwon committed
277
        """Shared CLI arguments for vLLM engine."""
278

279
        def is_type_in_union(cls: TypeHint, type: TypeHint) -> bool:
280
            """Check if the class is a type in a union type."""
281
282
283
284
285
286
287
288
289
290
291
292
            is_union = get_origin(cls) is Union
            type_in_union = type in [get_origin(a) or a for a in get_args(cls)]
            return is_union and type_in_union

        def get_type_from_union(cls: TypeHint, type: TypeHintT) -> TypeHintT:
            """Get the type in a union type."""
            for arg in get_args(cls):
                if (get_origin(arg) or arg) is type:
                    return arg
            raise ValueError(f"Type {type} not found in union type {cls}.")

        def is_optional(cls: TypeHint) -> TypeIs[Union[Any, None]]:
293
            """Check if the class is an optional type."""
294
            return is_type_in_union(cls, type(None))
295

296
297
298
299
300
301
302
303
304
        def can_be_type(cls: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
            """Check if the class can be of type."""
            return cls is type or get_origin(cls) is type or is_type_in_union(
                cls, type)

        def is_custom_type(cls: TypeHint) -> bool:
            """Check if the class is a custom type."""
            return cls.__module__ != "builtins"

305
        def get_kwargs(cls: type[Config]) -> dict[str, Any]:
306
307
308
309
            cls_docs = get_attr_docs(cls)
            kwargs = {}
            for field in fields(cls):
                name = field.name
310
311
312
313
                default = field.default
                # This will only be True if default is MISSING
                if field.default_factory is not MISSING:
                    default = field.default_factory()
314
                kwargs[name] = {"default": default, "help": cls_docs[name]}
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342

                # Make note of if the field is optional and get the actual
                # type of the field if it is
                optional = is_optional(field.type)
                field_type = get_args(
                    field.type)[0] if optional else field.type

                if can_be_type(field_type, bool):
                    # Creates --no-<name> and --<name> flags
                    kwargs[name]["action"] = argparse.BooleanOptionalAction
                    kwargs[name]["type"] = bool
                elif can_be_type(field_type, Literal):
                    # Creates choices from Literal arguments
                    if is_type_in_union(field_type, Literal):
                        field_type = get_type_from_union(field_type, Literal)
                    choices = get_args(field_type)
                    kwargs[name]["choices"] = choices
                    choice_type = type(choices[0])
                    assert all(type(c) is choice_type for c in choices), (
                        f"All choices must be of the same type. "
                        f"Got {choices} with types {[type(c) for c in choices]}"
                    )
                    kwargs[name]["type"] = choice_type
                elif can_be_type(field_type, int):
                    kwargs[name]["type"] = optional_int if optional else int
                elif can_be_type(field_type, float):
                    kwargs[name][
                        "type"] = optional_float if optional else float
343
344
                elif can_be_type(field_type, dict):
                    kwargs[name]["type"] = optional_dict
345
346
347
348
349
350
                elif (can_be_type(field_type, str)
                      or is_custom_type(field_type)):
                    kwargs[name]["type"] = optional_str if optional else str
                else:
                    raise ValueError(
                        f"Unsupported type {field.type} for argument {name}. ")
351
352
            return kwargs

353
        # Model arguments
354
355
356
        parser.add_argument(
            '--model',
            type=str,
357
            default=EngineArgs.model,
358
            help='Name or path of the huggingface model to use.')
359
360
361
362
363
364
        parser.add_argument(
            '--task',
            default=EngineArgs.task,
            choices=get_args(TaskOption),
            help='The task to use the model for. Each vLLM instance only '
            'supports one task, even if the same model can be used for '
365
            'multiple tasks. When the model only supports one task, ``"auto"`` '
366
367
            'can be used to select it; otherwise, you must specify explicitly '
            'which task to use.')
368
369
        parser.add_argument(
            '--tokenizer',
370
            type=optional_str,
371
            default=EngineArgs.tokenizer,
372
373
            help='Name or path of the huggingface tokenizer to use. '
            'If unspecified, model name or path will be used.')
374
375
        parser.add_argument(
            "--hf-config-path",
376
            type=optional_str,
377
378
379
            default=EngineArgs.hf_config_path,
            help='Name or path of the huggingface config to use. '
            'If unspecified, model name or path will be used.')
380
381
382
        parser.add_argument(
            '--skip-tokenizer-init',
            action='store_true',
383
384
385
            help='Skip initialization of tokenizer and detokenizer. '
            'Expects valid prompt_token_ids and None for prompt from '
            'the input. The generated output will contain token ids.')
Jasmond L's avatar
Jasmond L committed
386
387
        parser.add_argument(
            '--revision',
388
            type=optional_str,
Jasmond L's avatar
Jasmond L committed
389
            default=None,
390
            help='The specific model version to use. It can be a branch '
Jasmond L's avatar
Jasmond L committed
391
392
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
393
394
        parser.add_argument(
            '--code-revision',
395
            type=optional_str,
396
            default=None,
397
            help='The specific revision to use for the model code on '
398
399
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
400
401
        parser.add_argument(
            '--tokenizer-revision',
402
            type=optional_str,
403
            default=None,
404
405
406
            help='Revision of the huggingface tokenizer to use. '
            'It can be a branch name, a tag name, or a commit id. '
            'If unspecified, will use the default version.')
407
408
409
410
        parser.add_argument(
            '--tokenizer-mode',
            type=str,
            default=EngineArgs.tokenizer_mode,
411
            choices=['auto', 'slow', 'mistral', 'custom'],
412
413
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
414
            'always use the slow tokenizer. \n* '
415
416
417
            '"mistral" will always use the `mistral_common` tokenizer. \n* '
            '"custom" will use --tokenizer to select the '
            'preregistered tokenizer.')
418
419
        parser.add_argument('--trust-remote-code',
                            action='store_true',
420
                            help='Trust remote code from huggingface.')
421
422
423
        parser.add_argument(
            '--allowed-local-media-path',
            type=str,
424
425
426
427
            help="Allowing API requests to read local images or videos "
            "from directories specified by the server file system. "
            "This is a security risk. "
            "Should only be enabled in trusted environments.")
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
        load_group = parser.add_argument_group(
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
        load_group.add_argument('--load-format',
                                choices=[f.value for f in LoadFormat],
                                **load_kwargs["load_format"])
        load_group.add_argument('--download-dir',
                                **load_kwargs["download_dir"])
        load_group.add_argument('--model-loader-extra-config',
                                **load_kwargs["model_loader_extra_config"])
        load_group.add_argument('--use-tqdm-on-load',
                                **load_kwargs["use_tqdm_on_load"])

444
445
446
447
448
449
450
        parser.add_argument(
            '--config-format',
            default=EngineArgs.config_format,
            choices=[f.value for f in ConfigFormat],
            help='The format of the model config to load.\n\n'
            '* "auto" will try to load the config in hf format '
            'if available else it will try to load in mistral format ')
451
452
453
454
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
Woosuk Kwon's avatar
Woosuk Kwon committed
455
456
457
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
458
459
460
461
462
463
464
465
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
466
467
468
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
469
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
470
            default=EngineArgs.kv_cache_dtype,
471
            help='Data type for kv cache storage. If "auto", will use model '
472
473
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
474
        parser.add_argument('--max-model-len',
475
                            type=human_readable_int,
476
                            default=EngineArgs.max_model_len,
477
                            help='Model context length. If unspecified, will '
478
479
480
481
482
                            'be automatically derived from the model config. '
                            'Supports k/m/g/K/M/G in human-readable format.\n'
                            'Examples:\n'
                            '- 1k → 1000\n'
                            '- 1K → 1024\n')
483
484
485
486
487
488
489
490

        # Guided decoding arguments
        guided_decoding_kwargs = get_kwargs(DecodingConfig)
        guided_decoding_group = parser.add_argument_group(
            title="DecodingConfig",
            description=DecodingConfig.__doc__,
        )
        guided_decoding_group.add_argument(
491
            '--guided-decoding-backend',
492
493
494
495
496
497
498
            **guided_decoding_kwargs["guided_decoding_backend"])
        guided_decoding_group.add_argument(
            "--reasoning-parser",
            # This choices is a special case because it's not static
            choices=list(ReasoningParserManager.reasoning_parsers),
            **guided_decoding_kwargs["reasoning_backend"])

499
500
        parser.add_argument(
            '--logits-processor-pattern',
501
            type=optional_str,
502
503
504
505
506
            default=None,
            help='Optional regex pattern specifying valid logits processor '
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
507
508
509
510
511
512
513
514
515
516
517
518
        parser.add_argument(
            '--model-impl',
            type=str,
            default=EngineArgs.model_impl,
            choices=[f.value for f in ModelImpl],
            help='Which implementation of the model to use.\n\n'
            '* "auto" will try to use the vLLM implementation if it exists '
            'and fall back to the Transformers implementation if no vLLM '
            'implementation is available.\n'
            '* "vllm" will use the vLLM model implementation.\n'
            '* "transformers" will use the Transformers model '
            'implementation.\n')
519
        # Parallel arguments
520
521
522
523
524
525
        parallel_kwargs = get_kwargs(ParallelConfig)
        parallel_group = parser.add_argument_group(
            title="ParallelConfig",
            description=ParallelConfig.__doc__,
        )
        parallel_group.add_argument(
526
            '--distributed-executor-backend',
527
528
529
530
531
532
533
534
535
            **parallel_kwargs["distributed_executor_backend"])
        parallel_group.add_argument(
            '--pipeline-parallel-size', '-pp',
            **parallel_kwargs["pipeline_parallel_size"])
        parallel_group.add_argument('--tensor-parallel-size', '-tp',
                                    **parallel_kwargs["tensor_parallel_size"])
        parallel_group.add_argument('--data-parallel-size', '-dp',
                                    **parallel_kwargs["data_parallel_size"])
        parallel_group.add_argument(
536
            '--enable-expert-parallel',
537
538
            **parallel_kwargs["enable_expert_parallel"])
        parallel_group.add_argument(
539
            '--max-parallel-loading-workers',
540
541
            **parallel_kwargs["max_parallel_loading_workers"])
        parallel_group.add_argument(
542
            '--ray-workers-use-nsight',
543
544
545
546
            **parallel_kwargs["ray_workers_use_nsight"])
        parallel_group.add_argument(
            '--disable-custom-all-reduce',
            **parallel_kwargs["disable_custom_all_reduce"])
547
        # KV cache arguments
548
549
        parser.add_argument('--block-size',
                            type=int,
Zhuohan Li's avatar
Zhuohan Li committed
550
                            default=EngineArgs.block_size,
551
                            choices=[8, 16, 32, 64, 128],
552
                            help='Token block size for contiguous chunks of '
553
                            'tokens. This is ignored on neuron devices and '
554
                            'set to ``--max-model-len``. On CUDA devices, '
555
556
                            'only block sizes up to 32 are supported. '
                            'On HPU devices, block size defaults to 128.')
557

558
559
560
561
562
        parser.add_argument(
            "--enable-prefix-caching",
            action=argparse.BooleanOptionalAction,
            default=EngineArgs.enable_prefix_caching,
            help="Enables automatic prefix caching. "
563
            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
564
        )
565
566
567
568
569
570
571
        parser.add_argument(
            "--prefix-caching-hash-algo",
            type=str,
            choices=["builtin", "sha256"],
            default=EngineArgs.prefix_caching_hash_algo,
            help="Set the hash algorithm for prefix caching. "
            "Options are 'builtin' (Python's built-in hash) or 'sha256' "
572
            "(collision resistant but with certain overheads).",
573
        )
574
575
576
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
577
                            'capping to sliding window size.')
578
579
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
580
                            default=True,
581
582
583
584
585
                            help='[DEPRECATED] block manager v1 has been '
                            'removed and SelfAttnBlockSpaceManager (i.e. '
                            'block manager v2) is now the default. '
                            'Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')
586

587
588
589
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
590
                            help='Random seed for operations.')
591
        parser.add_argument('--swap-space',
592
                            type=float,
Zhuohan Li's avatar
Zhuohan Li committed
593
                            default=EngineArgs.swap_space,
594
                            help='CPU swap space size (GiB) per GPU.')
595
596
597
598
599
600
601
602
603
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
604
            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
605
            'which requires at least 26GB GPU memory. Note that this '
606
            'requires fast CPU-GPU interconnect, as part of the model is '
607
608
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
609
610
611
612
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
613
614
615
            help='The fraction of GPU memory to be used for the model '
            'executor, which can range from 0 to 1. For example, a value of '
            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
616
617
618
619
620
621
            'will use the default value of 0.9. This is a per-instance '
            'limit, and only applies to the current vLLM instance.'
            'It does not matter if you have another vLLM instance running '
            'on the same GPU. For example, if you have two vLLM instances '
            'running on the same GPU, you can set the GPU memory utilization '
            'to 0.5 for each instance.')
622
        parser.add_argument(
623
            '--num-gpu-blocks-override',
624
625
626
            type=int,
            default=None,
            help='If specified, ignore GPU profiling result and use this number'
627
            ' of GPU blocks. Used for testing preemption.')
628
629
630
631
        parser.add_argument(
            '--max-logprobs',
            type=int,
            default=EngineArgs.max_logprobs,
632
633
            help=('Max number of log probs to return logprobs is specified in'
                  ' SamplingParams.'))
634
635
        parser.add_argument('--disable-log-stats',
                            action='store_true',
636
                            help='Disable logging statistics.')
637
638
639
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
640
                            type=optional_str,
641
                            choices=[*QUANTIZATION_METHODS, None],
642
                            default=EngineArgs.quantization,
643
644
645
646
647
648
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
649
650
651
652
653
        parser.add_argument(
            '--rope-scaling',
            default=None,
            type=json.loads,
            help='RoPE scaling configuration in JSON format. '
654
            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
655
656
657
658
659
660
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
661
662
663
664
665
666
667
668
669
670
        parser.add_argument(
            '--hf-token',
            type=str,
            nargs='?',
            const=True,
            default=None,
            help='The token to use as HTTP bearer authorization'
            ' for remote files. If `True`, will use the token '
            'generated when running `huggingface-cli login` '
            '(stored in `~/.huggingface`).')
671
672
673
        parser.add_argument('--hf-overrides',
                            type=json.loads,
                            default=EngineArgs.hf_overrides,
674
                            help='Extra arguments for the HuggingFace config. '
675
676
                            'This should be a JSON string that will be '
                            'parsed into a dictionary.')
677
678
679
680
681
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
682
        parser.add_argument('--max-seq-len-to-capture',
683
684
685
686
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help='Maximum sequence length covered by CUDA '
                            'graphs. When a sequence has context length '
687
688
689
690
                            'larger than this, we fall back to eager mode. '
                            'Additionally for encoder-decoder models, if the '
                            'sequence length of the encoder input is larger '
                            'than this, we fall back to the eager mode.')
691
692
693
694
695
696
697
698
699
700
701
702
703

        # Tokenizer arguments
        tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
        tokenizer_group = parser.add_argument_group(
            title="TokenizerPoolConfig",
            description=TokenizerPoolConfig.__doc__,
        )
        tokenizer_group.add_argument('--tokenizer-pool-size',
                                     **tokenizer_kwargs["pool_size"])
        tokenizer_group.add_argument('--tokenizer-pool-type',
                                     **tokenizer_kwargs["pool_type"])
        tokenizer_group.add_argument('--tokenizer-pool-extra-config',
                                     **tokenizer_kwargs["extra_config"])
704
705

        # Multimodal related configs
706
707
708
709
710
711
712
713
        multimodal_kwargs = get_kwargs(MultiModalConfig)
        multimodal_group = parser.add_argument_group(
            title="MultiModalConfig",
            description=MultiModalConfig.__doc__,
        )
        multimodal_group.add_argument('--limit-mm-per-prompt',
                                      **multimodal_kwargs["limit_per_prompt"])

714
715
716
717
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
718
            help=('Overrides for the multimodal input mapping/processing, '
719
                  'e.g., image processor. For example: ``{"num_crops": 4}``.'))
720
        parser.add_argument(
721
            '--disable-mm-preprocessor-cache',
722
            action='store_true',
723
724
            help='If true, then disables caching of the multi-modal '
            'preprocessor/mapper. (not recommended)')
725

726
727
728
729
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
730
731
732
        parser.add_argument('--enable-lora-bias',
                            action='store_true',
                            help='If True, enable bias for LoRA adapters.')
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
752
            choices=['auto', 'float16', 'bfloat16'],
753
754
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
755
756
        parser.add_argument(
            '--long-lora-scaling-factors',
757
            type=optional_str,
758
759
760
761
762
763
764
765
            default=EngineArgs.long_lora_scaling_factors,
            help=('Specify multiple scaling factors (which can '
                  'be different from base model scaling factor '
                  '- see eg. Long LoRA) to allow for multiple '
                  'LoRA adapters trained with those scaling '
                  'factors to be used at the same time. If not '
                  'specified, only adapters trained with the '
                  'base model scaling factor are allowed.'))
766
767
768
769
770
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
771
                  'Must be >= than max_loras.'))
772
773
774
775
776
777
778
779
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
780
781
782
783
784
785
786
787
788
789
790
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
791
792
793
794
795
796
797
798
799

        # Device arguments
        device_kwargs = get_kwargs(DeviceConfig)
        device_group = parser.add_argument_group(
            title="DeviceConfig",
            description=DeviceConfig.__doc__,
        )
        device_group.add_argument("--device", **device_kwargs["device"])

800
801
802
803
804
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
805

806
        parser.add_argument('--speculative-config',
807
                            type=json.loads,
808
809
810
                            default=None,
                            help='The configurations for speculative decoding.'
                            ' Should be a JSON string.')
811
812
813
814
815
816
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="The pattern(s) to ignore when loading the model."
817
            "Default to `original/**/*` to avoid repeated loading of llama's "
818
            "checkpoints.")
819
        parser.add_argument(
820
            '--preemption-mode',
821
822
            type=str,
            default=None,
823
824
825
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; If \'swap\', the engine performs preemption by '
            'block swapping.')
826

827
828
829
830
831
832
833
834
835
836
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
837
            "same as the ``--model`` argument. Noted that this name(s) "
838
            "will also be used in `model_name` tag content of "
839
            "prometheus metrics, if multiple names provided, metrics "
840
            "tag will take the first one.")
841
842
843
844
        parser.add_argument('--qlora-adapter-name-or-path',
                            type=str,
                            default=None,
                            help='Name or path of the QLoRA adapter.')
845

846
847
848
849
850
851
852
853
854
855
856
857
        parser.add_argument('--show-hidden-metrics-for-version',
                            type=str,
                            default=None,
                            help='Enable deprecated Prometheus metrics that '
                            'have been hidden since the specified version. '
                            'For example, if a previously deprecated metric '
                            'has been hidden since the v0.7.0 release, you '
                            'use --show-hidden-metrics-for-version=0.7 as a '
                            'temporary escape hatch while you migrate to new '
                            'metrics. The metric is likely to be removed '
                            'completely in an upcoming release.')

858
859
860
861
862
        parser.add_argument(
            '--otlp-traces-endpoint',
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
863
864
865
866
867
868
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
869
            ". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
870
871
872
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and or blocking "
            "operations and hence might have a performance impact.")
873

874
875
876
877
878
879
        parser.add_argument(
            '--disable-async-output-proc',
            action='store_true',
            default=EngineArgs.disable_async_output_proc,
            help="Disable async output processing. This may result in "
            "lower performance.")
880

881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
        scheduler_group = parser.add_argument_group(
            title="SchedulerConfig",
            description=SchedulerConfig.__doc__,
        )
        scheduler_group.add_argument(
            '--max-num-batched-tokens',
            **scheduler_kwargs["max_num_batched_tokens"])
        scheduler_group.add_argument('--max-num-seqs',
                                     **scheduler_kwargs["max_num_seqs"])
        scheduler_group.add_argument(
            "--max-num-partial-prefills",
            **scheduler_kwargs["max_num_partial_prefills"])
        scheduler_group.add_argument(
            "--max-long-partial-prefills",
            **scheduler_kwargs["max_long_partial_prefills"])
        scheduler_group.add_argument(
            "--long-prefill-token-threshold",
            **scheduler_kwargs["long_prefill_token_threshold"])
        scheduler_group.add_argument('--num-lookahead-slots',
                                     **scheduler_kwargs["num_lookahead_slots"])
        scheduler_group.add_argument('--scheduler-delay-factor',
                                     **scheduler_kwargs["delay_factor"])
        scheduler_group.add_argument(
            '--enable-chunked-prefill',
            **scheduler_kwargs["enable_chunked_prefill"])
        scheduler_group.add_argument(
            '--multi-step-stream-outputs',
            **scheduler_kwargs["multi_step_stream_outputs"])
        scheduler_group.add_argument('--scheduling-policy',
                                     **scheduler_kwargs["policy"])
        scheduler_group.add_argument(
            "--disable-chunked-mm-input",
            **scheduler_kwargs["disable_chunked_mm_input"])
        parser.add_argument('--scheduler-cls',
                            **scheduler_kwargs["scheduler_cls"])
918

919
        parser.add_argument(
920
921
            '--override-neuron-config',
            type=json.loads,
922
            default=None,
923
            help="Override or set neuron device configuration. "
924
            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
925
        parser.add_argument(
926
927
            '--override-pooler-config',
            type=PoolerConfig.from_json,
928
            default=None,
929
            help="Override or set the pooling method for pooling models. "
930
            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
931

932
933
934
935
936
937
938
939
940
941
942
943
        parser.add_argument('--compilation-config',
                            '-O',
                            type=CompilationConfig.from_cli,
                            default=None,
                            help='torch.compile configuration for the model.'
                            'When it is a number (0, 1, 2, 3), it will be '
                            'interpreted as the optimization level.\n'
                            'NOTE: level 0 is the default level without '
                            'any optimization. level 1 and 2 are for internal '
                            'testing only. level 3 is the recommended level '
                            'for production.\n'
                            'To specify the full compilation config, '
944
945
                            'use a JSON string, e.g. ``{"level": 3, '
                            '"cudagraph_capture_sizes": [1, 2, 4, 8]}``\n'
946
                            'Following the convention of traditional '
947
948
                            'compilers, using ``-O`` without space is also '
                            'supported. ``-O3`` is equivalent to ``-O 3``.')
949

950
951
952
953
954
955
        parser.add_argument('--kv-transfer-config',
                            type=KVTransferConfig.from_cli,
                            default=None,
                            help='The configurations for distributed KV cache '
                            'transfer. Should be a JSON string.')

956
957
958
959
960
        parser.add_argument(
            '--worker-cls',
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')
961
962
963
964
965
966
967
        parser.add_argument(
            '--worker-extension-cls',
            type=str,
            default="",
            help='The worker extension class on top of the worker cls, '
            'it is useful if you just want to add new functions to the worker '
            'class without changing the existing functions.')
968
969
        parser.add_argument(
            "--generation-config",
970
            type=optional_str,
971
            default="auto",
972
            help="The folder path to the generation config. "
973
974
975
976
977
            "Defaults to 'auto', the generation config will be loaded from "
            "model path. If set to 'vllm', no generation config is loaded, "
            "vLLM defaults will be used. If set to a folder path, the "
            "generation config will be loaded from the specified folder path. "
            "If `max_new_tokens` is specified in generation config, then "
978
979
980
981
982
983
984
985
986
987
988
989
            "it sets a server-wide limit on the number of output tokens "
            "for all requests.")

        parser.add_argument(
            "--override-generation-config",
            type=json.loads,
            default=None,
            help="Overrides or sets generation config in JSON format. "
            "e.g. ``{\"temperature\": 0.5}``. If used with "
            "--generation-config=auto, the override parameters will be merged "
            "with the default config from the model. If generation-config is "
            "None, only the override parameters are used.")
990

991
992
993
994
995
996
        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
                            default=False,
                            help="Enable sleep mode for the engine. "
                            "(only cuda platform is supported)")

997
998
999
1000
1001
1002
1003
1004
1005
        parser.add_argument(
            '--calculate-kv-scales',
            action='store_true',
            help='This enables dynamic calculation of '
            'k_scale and v_scale when kv-cache-dtype is fp8. '
            'If calculate-kv-scales is false, the scales will '
            'be loaded from the model checkpoint if available. '
            'Otherwise, the scales will default to 1.0.')

1006
1007
1008
1009
1010
1011
1012
1013
        parser.add_argument(
            "--additional-config",
            type=json.loads,
            default=None,
            help="Additional config for specified platform in JSON format. "
            "Different platforms may support different configs. Make sure the "
            "configs are valid for the platform you are using. The input format"
            " is like '{\"config_key\":\"config_value\"}'")
1014
1015
1016
1017
1018
1019
1020
1021
1022

        parser.add_argument(
            "--enable-reasoning",
            action="store_true",
            default=False,
            help="Whether to enable reasoning_content for the model. "
            "If enabled, the model will be able to generate reasoning content."
        )

1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
        parser.add_argument(
            "--disable-cascade-attn",
            action="store_true",
            default=False,
            help="Disable cascade attention for V1. While cascade attention "
            "does not change the mathematical correctness, disabling it "
            "could be useful for preventing potential numerical issues. "
            "Note that even if this is set to False, cascade attention will be "
            "only used when the heuristic tells that it's beneficial.")

1033
        return parser
1034
1035

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct an instance of ``cls`` from a parsed argparse namespace.

        Each dataclass field of ``cls`` is read from ``args`` by the same
        name. Attributes on ``args`` that are not fields are ignored. A field
        with no matching attribute on ``args`` raises ``AttributeError``
        (from ``getattr``).
        """
        # Collect the argparse value for every dataclass field, in field order.
        field_values = {}
        for dc_field in dataclasses.fields(cls):
            field_values[dc_field.name] = getattr(args, dc_field.name)
        return cls(**field_values)
1042

1043
    def create_model_config(self) -> ModelConfig:
        """Translate these engine arguments into a ``ModelConfig``.

        Before building the config, this may rewrite fields on ``self``:

        * If ``check_gguf_file(self.model)`` is true, ``quantization`` and
          ``load_format`` are both set to ``"gguf"``.
        * The fields are also rewritten when all of these hold:

          - the instance is not an ``AsyncEngineArgs``;
          - ``envs.VLLM_CI_USE_S3`` is set;
          - the model is listed in ``MODELS_ON_S3``;
          - ``load_format`` is ``LoadFormat.AUTO``.

          In that case ``model`` is prefixed with
          ``MODEL_WEIGHTS_S3_BUCKET`` and ``load_format`` becomes
          ``LoadFormat.RUNAI_STREAMER`` (CI loads weights from S3).

        Returns:
            A new ``ModelConfig`` built from the (possibly updated) fields.
        """
        # GGUF checkpoints need a dedicated loader and bypass the HF repo path.
        if check_gguf_file(self.model):
            self.quantization = "gguf"
            self.load_format = "gguf"

        # NOTE: this exists so CI can load model weights from S3.
        load_from_ci_s3 = (not isinstance(self, AsyncEngineArgs)
                           and envs.VLLM_CI_USE_S3
                           and self.model in MODELS_ON_S3
                           and self.load_format == LoadFormat.AUTO)
        if load_from_ci_s3:
            self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
            self.load_format = LoadFormat.RUNAI_STREAMER

        # ``tokenizer`` is expected to have been filled in by __post_init__;
        # ``cast`` only narrows its static type and does no runtime check.
        tokenizer = cast(str, self.tokenizer)
        use_async_output_proc = not self.disable_async_output_proc

        return ModelConfig(
            # Where the model comes from and how it is identified.
            model=self.model,
            hf_config_path=self.hf_config_path,
            task=self.task,
            revision=self.revision,
            code_revision=self.code_revision,
            config_format=self.config_format,
            model_impl=self.model_impl,
            served_model_name=self.served_model_name,
            hf_token=self.hf_token,
            hf_overrides=self.hf_overrides,
            trust_remote_code=self.trust_remote_code,
            allowed_local_media_path=self.allowed_local_media_path,
            seed=self.seed,
            # Tokenizer.
            tokenizer=tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            tokenizer_revision=self.tokenizer_revision,
            skip_tokenizer_init=self.skip_tokenizer_init,
            # Numerics and architecture knobs.
            dtype=self.dtype,
            quantization=self.quantization,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            max_model_len=self.max_model_len,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            max_logprobs=self.max_logprobs,
            # Execution mode.
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            use_async_output_proc=use_async_output_proc,
            enable_sleep_mode=self.enable_sleep_mode,
            # Multimodal inputs.
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            mm_processor_kwargs=self.mm_processor_kwargs,
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            # User overrides and generation defaults.
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
        )
1095

1096
1097
    def create_load_config(self) -> LoadConfig:
        """Build the ``LoadConfig`` for these engine arguments.

        Side effect: when ``quantization`` is ``"bitsandbytes"``,
        ``self.load_format`` is overwritten with ``"bitsandbytes"`` before
        the config is built.

        Raises:
            ValueError: if ``qlora_adapter_name_or_path`` is set while
                ``quantization`` is anything other than ``"bitsandbytes"``.
        """
        uses_bitsandbytes = self.quantization == "bitsandbytes"

        # A QLoRA adapter is only usable with bitsandbytes quantization.
        has_qlora_adapter = self.qlora_adapter_name_or_path is not None
        if has_qlora_adapter and not uses_bitsandbytes:
            raise ValueError(
                "QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        # bitsandbytes-quantized weights are loaded by the matching loader.
        if uses_bitsandbytes:
            self.load_format = "bitsandbytes"

        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            ignore_patterns=self.ignore_patterns,
            model_loader_extra_config=self.model_loader_extra_config,
            use_tqdm_on_load=self.use_tqdm_on_load,
        )
1113

1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
    def create_speculative_config(
        self,
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
    ) -> Optional["SpeculativeConfig"]:
        """Build a ``SpeculativeConfig`` from ``self.speculative_config``.

        ``self.speculative_config`` is a dict. It either comes from the
        ``--speculative-config`` JSON string on the CLI or is passed directly
        by the engine's caller. Four runtime-only values that the CLI cannot
        provide are merged into a *copy* of that dict. The copy is then passed
        to ``SpeculativeConfig.from_dict``, and ``self.speculative_config``
        itself is left untouched.

        Args:
            target_model_config: Model config of the target model.
            target_parallel_config: Parallel config of the target model.
            enable_chunked_prefill: Whether chunked prefill is enabled.
            disable_log_stats: Whether stats logging is disabled.

        Returns:
            The constructed ``SpeculativeConfig``, or ``None`` when
            ``self.speculative_config`` is ``None``.
        """
        if self.speculative_config is None:
            return None

        # Note(Shangming): These parameters are not obtained from the cli arg
        # '--speculative-config' and must be passed in when creating the engine
        # config.
        # Merge into a fresh dict instead of calling ``.update`` on
        # ``self.speculative_config``. Mutating it would write these engine
        # objects into the caller-supplied dict (which may be reused, e.g. to
        # build another engine) and keep them alive on ``self``. Later keys
        # win, so these values still override any same-named user keys, as
        # before.
        spec_config_dict = {
            **self.speculative_config,
            "target_model_config": target_model_config,
            "target_parallel_config": target_parallel_config,
            "enable_chunked_prefill": enable_chunked_prefill,
            "disable_log_stats": disable_log_stats,
        }
        return SpeculativeConfig.from_dict(spec_config_dict)

1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.
Simon Mo's avatar
Simon Mo committed
1156

1157
1158
1159
1160
1161
1162
        This function set VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
1163
1164
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()
1165

1166
        device_config = DeviceConfig(device=self.device)
1167
1168
        model_config = self.create_model_config()

1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
            self._set_default_args_v1(usage_context)
        else:
            self._set_default_args_v0(model_config)
1191

1192
1193
        assert self.enable_chunked_prefill is not None

1194
        cache_config = CacheConfig(
1195
            block_size=self.block_size,
1196
1197
1198
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
1199
            is_attention_free=model_config.is_attention_free,
1200
1201
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
1202
            enable_prefix_caching=self.enable_prefix_caching,
1203
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1204
            cpu_offload_gb=self.cpu_offload_gb,
1205
            calculate_kv_scales=self.calculate_kv_scales,
1206
        )
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218

        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
        placement_group = None
        if is_in_ray_actor():
            import ray

            # This call initializes Ray automatically if it is not initialized,
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

1219
        parallel_config = ParallelConfig(
1220
1221
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
1222
            data_parallel_size=self.data_parallel_size,
1223
            enable_expert_parallel=self.enable_expert_parallel,
1224
1225
1226
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
1227
1228
1229
                self.tokenizer_pool_size,
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
1230
            ),
1231
            ray_workers_use_nsight=self.ray_workers_use_nsight,
1232
            placement_group=placement_group,
1233
1234
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
1235
            worker_extension_cls=self.worker_extension_cls,
1236
        )
1237

1238
        speculative_config = self.create_speculative_config(
1239
1240
            target_model_config=model_config,
            target_parallel_config=parallel_config,
1241
            enable_chunked_prefill=self.enable_chunked_prefill,
1242
            disable_log_stats=self.disable_log_stats,
1243
1244
        )

1245
        # Reminder: Please update docs/source/features/compatibility_matrix.md
1246
        # If the feature combo become valid
1247
1248
1249
1250
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
1251
1252
1253
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
1254
1255
1256
1257
1258
1259
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
                               "disabled.")
                self.num_scheduler_steps = 1
1260
1261
1262
1263
1264
1265
1266
1267
1268

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

1269
        scheduler_config = SchedulerConfig(
1270
            runner_type=model_config.runner_type,
1271
1272
1273
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
1274
            num_lookahead_slots=num_lookahead_slots,
1275
1276
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
1277
            disable_chunked_mm_input=self.disable_chunked_mm_input,
1278
            is_multimodal_model=model_config.is_multimodal_model,
1279
            preemption_mode=self.preemption_mode,
1280
            num_scheduler_steps=self.num_scheduler_steps,
1281
            multi_step_stream_outputs=self.multi_step_stream_outputs,
1282
1283
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
1284
            policy=self.scheduling_policy,
1285
            scheduler_cls=self.scheduler_cls,
1286
1287
1288
1289
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
        )
1290

1291
        lora_config = LoRAConfig(
1292
            bias_enabled=self.enable_lora_bias,
1293
1294
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
1295
            fully_sharded_loras=self.fully_sharded_loras,
1296
            lora_extra_vocab_size=self.lora_extra_vocab_size,
1297
            long_lora_scaling_factors=self.long_lora_scaling_factors,
1298
1299
1300
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
1301

1302
1303
1304
1305
1306
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

1307
1308
1309
1310
        # bitsandbytes pre-quantized model need a specific model loader
        if model_config.quantization == "bitsandbytes":
            self.quantization = self.load_format = "bitsandbytes"

1311
        load_config = self.create_load_config()
1312

1313
1314
1315
1316
1317
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

1318
        decoding_config = DecodingConfig(
1319
1320
1321
1322
            guided_decoding_backend=self.guided_decoding_backend,
            reasoning_backend=self.reasoning_parser
            if self.enable_reasoning else None,
        )
1323

1324
1325
1326
1327
1328
        show_hidden_metrics = False
        if self.show_hidden_metrics_for_version is not None:
            show_hidden_metrics = version._prev_minor_version_was(
                self.show_hidden_metrics_for_version)

1329
1330
1331
1332
1333
1334
1335
1336
        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
1337
        observability_config = ObservabilityConfig(
1338
            show_hidden_metrics=show_hidden_metrics,
1339
1340
1341
1342
1343
1344
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_model_forward_time="model" in detailed_trace_modules
            or "all" in detailed_trace_modules,
            collect_model_execute_time="worker" in detailed_trace_modules
            or "all" in detailed_trace_modules,
        )
1345

1346
        config = VllmConfig(
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
1357
            prompt_adapter_config=prompt_adapter_config,
1358
            compilation_config=self.compilation_config,
1359
            kv_transfer_config=self.kv_transfer_config,
1360
            additional_config=self.additional_config,
1361
        )
1362

1363
1364
        return config

1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default.

        Checks the engine arguments and ``model_config`` against features
        the V1 engine does not support. On the first unsupported feature it
        calls ``_raise_or_fallback`` and returns False. That helper raises
        when ``VLLM_USE_V1`` is explicitly enabled and otherwise logs a
        fallback warning. Experimental features are gated through
        ``_warn_or_fallback``, which only requests a fallback when
        ``VLLM_USE_V1`` was not explicitly enabled.

        Args:
            model_config: The resolved model configuration.

        Returns:
            True if V1 can be used by default, False to fall back to V0.

        Raises:
            NotImplementedError: If ``VLLM_USE_V1`` is explicitly enabled
                together with an unsupported feature.
        """

        #############################################################
        # Unsupported Feature Flags on V1.

        if (self.load_format == LoadFormat.TENSORIZER.value
                or self.load_format == LoadFormat.SHARDED_STATE.value):
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

        if self.preemption_mode != SchedulerConfig.preemption_mode:
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

        if self.scheduling_policy != SchedulerConfig.policy:
            _raise_or_fallback(feature_name="--scheduling-policy",
                               recommend_to_remove=False)
            return False

        if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
            _raise_or_fallback(feature_name="--num-scheduler-steps",
                               recommend_to_remove=True)
            return False

        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

        if self.additional_config != EngineArgs.additional_config:
            _raise_or_fallback(feature_name="--additional-config",
                               recommend_to_remove=False)
            return False

        # Xgrammar and Guidance are supported.
        SUPPORTED_GUIDED_DECODING = [
            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
            "guidance:disable-any-whitespace", "auto"
        ]
        if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
            _raise_or_fallback(feature_name="--guided-decoding-backend",
                               recommend_to_remove=False)
            return False

        # Need at least Ampere for now (FA support required).
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
        from vllm.platforms import current_platform
        if current_platform.is_cuda():
            # Query once; a falsy result means the capability is unknown.
            capability = current_platform.get_device_capability()
            if capability and capability.major < 8:
                _raise_or_fallback(feature_name="Compute Capability < 8.0",
                                   recommend_to_remove=False)
                return False

        # No Fp8 KV cache so far.
        if self.kv_cache_dtype != "auto":
            fp8_attention = self.kv_cache_dtype.startswith("fp8")
            will_use_fa = (
                current_platform.is_cuda()
                and not envs.is_set("VLLM_ATTENTION_BACKEND")
            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
            supported = False
            # FP8 KV cache is accepted only when FlashAttention will be used
            # and reports FP8 support.
            if fp8_attention and will_use_fa:
                from vllm.vllm_flash_attn.fa_utils import (
                    flash_attn_supports_fp8)
                supported = flash_attn_supports_fp8()
            if not supported:
                _raise_or_fallback(feature_name="--kv-cache-dtype",
                                   recommend_to_remove=False)
                return False

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

        # Only Fp16 and Bf16 dtypes since we only support FA.
        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
        if model_config.dtype not in V1_SUPPORTED_DTYPES:
            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
                               recommend_to_remove=False)
            return False

        # Some quantization is not compatible with torch.compile.
        V1_UNSUPPORTED_QUANT = ["gguf"]
        if model_config.quantization in V1_UNSUPPORTED_QUANT:
            _raise_or_fallback(
                feature_name=f"--quantization {model_config.quantization}",
                recommend_to_remove=False)
            return False

        # No Embedding Models so far.
        if model_config.task not in ["generate"]:
            _raise_or_fallback(feature_name=f"--task {model_config.task}",
                               recommend_to_remove=False)
            return False

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            _raise_or_fallback(feature_name=model_config.architectures,
                               recommend_to_remove=False)
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
                != SchedulerConfig.max_num_partial_prefills
                or self.max_long_partial_prefills
                != SchedulerConfig.max_long_partial_prefills):
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # No OTLP observability so far.
        if self.otlp_traces_endpoint or self.collect_detailed_traces:
            # Report the flag that actually triggered the fallback, so the
            # message is accurate when only --collect-detailed-traces is set.
            otlp_flag = ("--otlp-traces-endpoint"
                         if self.otlp_traces_endpoint else
                         "--collect-detailed-traces")
            _raise_or_fallback(feature_name=otlp_flag,
                               recommend_to_remove=False)
            return False

        # Only ngram and EAGLE speculative decoding so far.
        is_ngram_enabled = False
        is_eagle_enabled = False
        if self.speculative_config is not None:
            # This is supported but experimental (handled below).
            speculative_method = self.speculative_config.get("method")
            if speculative_method:
                if speculative_method in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
                elif speculative_method == "eagle":
                    is_eagle_enabled = True
            else:
                # Without an explicit method, ngram may be selected through
                # the "model" entry.
                speculative_model = self.speculative_config.get("model")
                if speculative_model in ("ngram", "[ngram]"):
                    is_ngram_enabled = True
            if not (is_ngram_enabled or is_eagle_enabled):
                # Other speculative decoding methods are not supported yet.
                _raise_or_fallback(feature_name="Speculative Decoding",
                                   recommend_to_remove=False)
                return False

        # No FlashInfer or XFormers so far.
        V1_BACKENDS = [
            "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
            "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        # Platforms must decide if they can support v1 for this model
        if not current_platform.supports_v1(model_config=model_config):
            _raise_or_fallback(
                feature_name=f"device type={current_platform.device_type}",
                recommend_to_remove=False)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

        # Signal Handlers requires running in main thread.
        if (threading.current_thread() != threading.main_thread()
                and _warn_or_fallback("Engine in background thread")):
            return False

        # PP is supported on V1 with Ray distributed executor,
        # but off for MP distributed executor for now.
        if (self.pipeline_parallel_size > 1
                and self.distributed_executor_backend != "ray"):
            name = "Pipeline Parallelism without Ray distributed executor"
            _raise_or_fallback(feature_name=name, recommend_to_remove=False)
            return False

        # ngram is supported on V1, but off by default for now.
        if is_ngram_enabled and _warn_or_fallback("ngram"):
            return False

        # Eagle is under development, so we don't support it yet.
        if is_eagle_enabled and _warn_or_fallback("Eagle"):
            return False

        # Non-CUDA is supported on V1, but off by default for now.
        not_cuda = not current_platform.is_cuda()
        if not_cuda and _warn_or_fallback(  # noqa: SIM103
                current_platform.device_name):
            return False
        #############################################################

        return True

    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine.

        Resolves ``enable_chunked_prefill``, the prefix-caching settings and
        ``max_num_seqs`` for the V0 engine when they were left unset, and
        rejects combinations V0 does not support.

        Args:
            model_config: The resolved model configuration.

        Raises:
            ValueError: If chunked prefill is enabled for a pooling model,
                or if prefix caching was requested with
                ``prefix_caching_hash_algo == "sha256"``.
        """
        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                from vllm.platforms import current_platform
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_config is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            # Anything not explicitly enabled above defaults to off.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            # Note the trailing space after "cause": the adjacent literals
            # are concatenated, so without it the message read "causeOOM".
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # if using prefix caching, we must set a hash algo
        if self.enable_prefix_caching:
            # Disable prefix caching for multimodal models for VLLM_V0.
            if model_config.is_multimodal_model:
                logger.warning(
                    "--enable-prefix-caching is not supported for multimodal "
                    "models in V0 and has been disabled.")
                self.enable_prefix_caching = False

            # VLLM_V0 only supports builtin hash algo for prefix caching.
            if self.prefix_caching_hash_algo is None:
                self.prefix_caching_hash_algo = "builtin"
            elif self.prefix_caching_hash_algo == "sha256":
                raise ValueError(
                    "sha256 is not supported for prefix caching in V0 engine. "
                    "Please use 'builtin'.")

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

    def _set_default_args_v1(self, usage_context: UsageContext) -> None:
        """Set Default Arguments for V1 Engine.

        The method does the following:

        - Always turns chunked prefill on.
        - Turns prefix caching on unless it was explicitly set, and picks the
          builtin hash algorithm when prefix caching is on and no algorithm
          was chosen.
        - Switches to the V1 scheduler when ``scheduler_cls`` still holds the
          class default.
        - Fills ``max_num_batched_tokens`` and ``max_num_seqs`` from
          device- and usage-context-dependent defaults when they are unset.

        Args:
            usage_context: Selects the ``max_num_batched_tokens`` default;
                contexts without an entry leave it unset.
        """
        # V1 always uses chunked prefills.
        self.enable_chunked_prefill = True

        # Prefix caching is on unless the user decided otherwise.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = True

        # With prefix caching on, a hash algorithm must be chosen.
        needs_hash_algo = (self.enable_prefix_caching
                           and self.prefix_caching_hash_algo is None)
        if needs_hash_algo:
            self.prefix_caching_hash_algo = "builtin"

        # Only replace the scheduler when it still holds the V0 default, so
        # an explicit user choice is preserved.
        if self.scheduler_cls == EngineArgs.scheduler_cls:
            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        # Querying the device name can fail when the process importing vLLM
        # is not the one running it (e.g. scaling vLLM with Ray from a node
        # without GPUs). In that case the non-H100/H200 defaults apply.
        try:
            from vllm.platforms import current_platform
            device_name = current_platform.get_device_name().lower()
        except Exception:
            # Only used to choose the defaults below.
            device_name = "no-device"

        on_hopper = any(tag in device_name for tag in ("h100", "h200"))
        if on_hopper:
            # For H100 and H200, we use larger default values.
            token_budget_by_context = {
                UsageContext.LLM_CLASS: 16384,
                UsageContext.OPENAI_API_SERVER: 8192,
            }
            fallback_max_num_seqs = 1024
        else:
            # TODO(woosuk): Tune the default values for other hardware.
            token_budget_by_context = {
                UsageContext.LLM_CLASS: 8192,
                UsageContext.OPENAI_API_SERVER: 2048,
            }
            fallback_max_num_seqs = 256

        context_label = usage_context.value if usage_context else None

        if self.max_num_batched_tokens is None:
            token_budget = token_budget_by_context.get(usage_context)
            if token_budget is not None:
                self.max_num_batched_tokens = token_budget
                logger.debug(
                    "Setting max_num_batched_tokens to %d for %s usage context.",
                    self.max_num_batched_tokens, context_label)

        if self.max_num_seqs is None:
            self.max_num_seqs = fallback_max_num_seqs
            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, context_label)
1703

1704

1705
@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    # Exposed on the command line as --disable-log-requests.
    disable_log_requests: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        """Register the async-engine command-line arguments on ``parser``.

        Args:
            parser: The parser to extend.
            async_args_only: When True, skip ``EngineArgs.add_cli_args`` and
                add only the async-specific options.

        Returns:
            The parser with the arguments registered.
        """
        # Load plugins first so they can update the parser, for example by
        # adding a new quantization method to the --quantization argument or
        # a new device to the --device argument.
        load_general_plugins()

        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)

        parser.add_argument("--disable-log-requests",
                            action="store_true",
                            help="Disable logging requests.")

        # Hand the finished parser to the current platform's
        # pre_register_and_update hook.
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update(parser)
        return parser
1725
1726


1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Reject ``feature_name`` for V1, or warn that the engine falls back.

    Args:
        feature_name: Human-readable name of the unsupported feature.
        recommend_to_remove: Whether the warning should advise removing the
            feature in favor of the V1 Engine.

    Raises:
        NotImplementedError: If ``VLLM_USE_V1`` is set and truthy.
    """
    v1_forced = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if v1_forced:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")

    pieces = [
        f"{feature_name} is not supported by the V1 Engine. ",
        "Falling back to V0. ",
    ]
    if recommend_to_remove:
        pieces.extend((
            f"We recommend to remove {feature_name} from your config ",
            "in favor of the V1 Engine.",
        ))
    logger.warning("".join(pieces))


def _warn_or_fallback(feature_name: str) -> bool:
    """Decide whether an experimental feature forces a fallback to V0.

    Args:
        feature_name: Human-readable name of the experimental feature.

    Returns:
        False (after logging a warning) when ``VLLM_USE_V1`` is set and
        truthy, meaning the caller keeps V1. Otherwise True (after logging
        at info level), meaning the caller should fall back to V0.
    """
    v1_forced = envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1
    if not v1_forced:
        logger.info(
            "%s is experimental on VLLM_USE_V1=1. "
            "Falling back to V0 Engine.", feature_name)
        return True

    logger.warning(
        "Detected VLLM_USE_V1=1 with %s. Usage should "
        "be considered experimental. Please report any "
        "issues on Github.", feature_name)
    return False


1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.

    Lowercase suffixes (k, m, g, t) are decimal multipliers (powers of 1000)
    and accept decimal numbers. Uppercase suffixes (K, M, G, T) are binary
    multipliers (powers of 1024) and require an integer mantissa. Input
    without a suffix is parsed with ``int``.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600
    - '8.2m' -> 8,200,000
    - '1T' -> 1,099,511,627,776

    Raises:
        argparse.ArgumentTypeError: If a decimal number is combined with a
            binary suffix (e.g. '1.5K').
        ValueError: If ``value`` has no recognized suffix and is not a
            valid integer.
    """
    from decimal import Decimal

    value = value.strip()
    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
    if match:
        # Every suffix accepted by the regex must appear in one of these
        # tables; 't'/'T' were previously missing and fell through to
        # int(value), which raised ValueError.
        decimal_multiplier = {
            'k': 10**3,
            'm': 10**6,
            'g': 10**9,
            't': 10**12,
        }
        binary_multiplier = {
            'K': 2**10,
            'M': 2**20,
            'G': 2**30,
            'T': 2**40,
        }

        number, suffix = match.groups()
        if suffix in decimal_multiplier:
            mult = decimal_multiplier[suffix]
            # Scale exactly with Decimal: with floats, 8.2 * 10**6 is
            # 8199999.999999999, which int() truncates to 8199999.
            return int(Decimal(number) * mult)
        elif suffix in binary_multiplier:
            mult = binary_multiplier[suffix]
            # Do not allow decimals with binary multipliers
            try:
                return int(number) * mult
            except ValueError as e:
                raise argparse.ArgumentTypeError("Decimals are not allowed " \
                f"with binary suffixes like {suffix}. Did you mean to use " \
                f"{number}{suffix.lower()} instead?") from e

    # Regular plain number.
    return int(value)


1795
1796
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Build a parser populated with all ``EngineArgs`` options."""
    fresh_parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(fresh_parser)
1798
1799
1800


def _async_engine_args_parser():
    """Build a parser holding only the async-engine-specific options."""
    fresh_parser = FlexibleArgumentParser()
    return AsyncEngineArgs.add_cli_args(fresh_parser, async_args_only=True)