config.py 87.2 KB
Newer Older
1
import enum
2
import json
3
from dataclasses import dataclass, field
4
5
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
                    Mapping, Optional, Set, Tuple, Type, Union)
6
7

import torch
8
from transformers import PretrainedConfig
9

10
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
13
from vllm.model_executor.models import ModelRegistry
14
from vllm.platforms import current_platform
15
from vllm.tracing import is_otel_available, otel_import_error_traceback
16
from vllm.transformers_utils.config import (ConfigFormat, get_config,
17
18
                                            get_hf_image_processor_config,
                                            get_hf_text_config)
19
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
20
                        print_warning_once)
21

22
23
24
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

25
    from vllm.executor.executor_base import ExecutorBase
26
    from vllm.model_executor.model_loader.loader import BaseModelLoader
27
28
    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
        BaseTokenizerGroup)
29

30
31
logger = init_logger(__name__)

32
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
33
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
34

35
36
37
38
TaskOption = Literal["auto", "generate", "embedding"]

# "draft" is only used internally for speculative decoding
_Task = Literal["generate", "embedding", "draft"]
39

40
41

class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
            It is also used as the content for `model_name` tag in metrics
            output when `served_model_name` is not specified.
        task: The task to use the model for. Each vLLM instance only supports
            one task, even if the same model can be used for multiple tasks.
            When the model only supports one task, "auto" can be used to select
            it; otherwise, you must specify explicitly which task to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, "slow" will always use the slow tokenizer, and
            "mistral" will always use the tokenizer from `mistral_common`.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allowing API requests to read local images or
            videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        rope_scaling: Dictionary containing the scaling configuration for the
            RoPE embeddings. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        rope_theta: Forwarded, together with `rope_scaling`, to the HF config
            loader.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        spec_target_max_model_len: Forwarded to the max-length verification
            together with `max_model_len`.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
            be used to load activation and weight scaling factors when the
            model dtype is FP8_E4M3 on ROCm.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If None, the user did not specify, so default to False.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        max_logprobs: Stored unchanged on the config; defaults to 20.
        disable_sliding_window: Whether to disable sliding window. If True,
            we will disable the sliding window functionality of the model.
            If the model does not support sliding window, this argument is
            ignored.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer.
        served_model_name: The model name used in metrics tag `model_name`,
            matches the model name exposed via the APIs. If multiple model
            names provided, the first name will be used. If not specified,
            the model name will be the same as `model`.
        limit_mm_per_prompt: Maximum number of data instances per modality
            per prompt. Only applicable for multimodal models.
        use_async_output_proc: Whether async output processing is requested.
            It may later be switched off by `verify_async_output_proc` or
            `verify_with_parallel_config`.
        override_neuron_config: Initialize non default neuron config or
            override default neuron config that are specific to Neuron devices,
            this argument will be used to configure the neuron config that
            can not be gathered from the vllm arguments.
        config_format: The config format which shall be loaded.
            Defaults to 'auto' which defaults to 'hf'.
        chat_template_text_format: Stored unchanged on the config; defaults
            to "string".
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor.
        pooling_type: Used to configure the pooling method in the embedding
            model.
        pooling_norm: Used to determine whether to normalize the pooled
            data in the embedding model.
        pooling_softmax: Used to determine whether to softmax the pooled
            data in the embedding model.
        pooling_step_tag_id: When pooling_step_tag_id is not -1, it indicates
            that the score corresponding to the pooling_step_tag_id in the
            generated sentence should be returned. Otherwise, it returns
            the scores for all tokens.
        pooling_returned_token_ids: pooling_returned_token_ids represents a
            list of indices for the vocabulary dimensions to be extracted,
            such as the token IDs of good_token and bad_token in the
            math-shepherd-mistral-7b-prm model.
    """
131

132
133
134
135
136
137
138
139
140
    def __init__(
            self,
            model: str,
            task: Union[TaskOption, _Task],
            tokenizer: str,
            tokenizer_mode: str,
            trust_remote_code: bool,
            dtype: Union[str, torch.dtype],
            seed: int,
            allowed_local_media_path: str = "",
            revision: Optional[str] = None,
            code_revision: Optional[str] = None,
            rope_scaling: Optional[dict] = None,
            rope_theta: Optional[float] = None,
            tokenizer_revision: Optional[str] = None,
            max_model_len: Optional[int] = None,
            spec_target_max_model_len: Optional[int] = None,
            quantization: Optional[str] = None,
            quantization_param_path: Optional[str] = None,
            enforce_eager: Optional[bool] = None,
            max_seq_len_to_capture: Optional[int] = None,
            max_logprobs: int = 20,
            disable_sliding_window: bool = False,
            skip_tokenizer_init: bool = False,
            served_model_name: Optional[Union[str, List[str]]] = None,
            limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
            use_async_output_proc: bool = True,
            override_neuron_config: Optional[Dict[str, Any]] = None,
            config_format: ConfigFormat = ConfigFormat.AUTO,
            chat_template_text_format: str = "string",
            mm_processor_kwargs: Optional[Dict[str, Any]] = None,
            pooling_type: Optional[str] = None,
            pooling_norm: Optional[bool] = None,
            pooling_softmax: Optional[bool] = None,
            pooling_step_tag_id: Optional[int] = None,
            pooling_returned_token_ids: Optional[List[int]] = None) -> None:
        """Store the options, load the HF configs and validate the result.

        The steps below are order-dependent: the HF configs must be loaded
        before any derived value is computed, the task must be resolved
        before the pooler config is built, and `max_model_len` must be set
        before the CUDA graph check runs.
        """
        # Plain copies of user-supplied options.
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.allowed_local_media_path = allowed_local_media_path
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        # The tokenizer version is consistent with the model version by default.
        if tokenizer_revision is None:
            self.tokenizer_revision = revision
        else:
            self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        self.max_seq_len_to_capture = max_seq_len_to_capture
        self.max_logprobs = max_logprobs
        self.disable_sliding_window = disable_sliding_window
        self.skip_tokenizer_init = skip_tokenizer_init

        # Load the HF configs; everything below is derived from them.
        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision, rope_scaling, rope_theta,
                                    config_format)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.hf_image_processor_config = get_hf_image_processor_config(
            self.model, revision)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.use_async_output_proc = use_async_output_proc
        self.chat_template_text_format = chat_template_text_format
        self.mm_processor_kwargs = mm_processor_kwargs

        # Set enforce_eager to False if the value is unset.
        if self.enforce_eager is None:
            self.enforce_eager = False

        # A list-valued sliding window, or any sliding window on gemma2, is
        # treated as interleaved attention.
        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
        has_interleaved_attention = (sliding_window is not None) and (
            isinstance(sliding_window, list) or
            (self.hf_text_config.model_type in ["gemma2"]))

        if (not self.disable_sliding_window and has_interleaved_attention):
            # Only used for the warning text below.
            sliding_window_len_min = get_min_sliding_window(
                self.hf_text_config.sliding_window)

            print_warning_once(
                f"{self.hf_text_config.model_type} has interleaved attention, "
                "which is currently not supported by vLLM. Disabling sliding "
                "window and capping the max length to the sliding window size "
                f"({sliding_window_len_min}).")
            self.disable_sliding_window = True

        # Uses the (possibly just updated) `disable_sliding_window` flag.
        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.hf_text_config,
            max_model_len=max_model_len,
            disable_sliding_window=self.disable_sliding_window,
            sliding_window_len=self.get_hf_config_sliding_window(),
            spec_target_max_model_len=spec_target_max_model_len)
        self.served_model_name = get_served_model_name(model,
                                                       served_model_name)
        self.multimodal_config = self._init_multimodal_config(
            limit_mm_per_prompt)
        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()

        self.is_attention_free = self._init_attention_free()
        self.has_inner_state = self._init_has_inner_state()

        # The neuron override is kept only when running on Neuron.
        if current_platform.is_neuron():
            self.override_neuron_config = override_neuron_config
        else:
            self.override_neuron_config = None

        # `self.task` must be set before `_init_pooler_config` reads it.
        supported_tasks, task = self._resolve_task(task, self.hf_config)
        self.supported_tasks = supported_tasks
        self.task: Final = task
        self.pooler_config = self._init_pooler_config(
            pooling_type,
            pooling_norm,
            pooling_softmax,
            pooling_step_tag_id,
            pooling_returned_token_ids,
        )

        # Final validation passes; these may rewrite `quantization`,
        # `max_seq_len_to_capture` and `enforce_eager`.
        self._verify_quantization()
        self._verify_cuda_graph()
        self._verify_bnb_config()
257

258
259
260
261
    def _init_multimodal_config(
        self, limit_mm_per_prompt: Optional[Mapping[str, int]]
    ) -> Optional["MultiModalConfig"]:
        """Build the multimodal config, or return None for other models.

        Raises:
            ValueError: If `limit_mm_per_prompt` is non-empty but the model
                registry does not report the model as multimodal.
        """
        model_archs = getattr(self.hf_config, "architectures", [])

        if not ModelRegistry.is_multimodal_model(model_archs):
            # Per-prompt limits make no sense without multimodal inputs.
            if limit_mm_per_prompt:
                raise ValueError("`limit_mm_per_prompt` is only supported for "
                                 "multimodal models.")
            return None

        return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
270

271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
    def _init_pooler_config(
        self,
        pooling_type: Optional[str] = None,
        pooling_norm: Optional[bool] = None,
        pooling_softmax: Optional[bool] = None,
        pooling_step_tag_id: Optional[int] = None,
        pooling_returned_token_ids: Optional[List[int]] = None
    ) -> Optional["PoolerConfig"]:
        if self.task == "embedding":
            return PoolerConfig(
                pooling_type=pooling_type,
                pooling_norm=pooling_norm,
                pooling_softmax=pooling_softmax,
                pooling_step_tag_id=pooling_step_tag_id,
                pooling_returned_token_ids=pooling_returned_token_ids)
        return None

288
289
290
291
292
293
294
295
    def _init_attention_free(self) -> bool:
        """Ask the model registry whether the model is attention-free."""
        return ModelRegistry.is_attention_free_model(
            getattr(self.hf_config, "architectures", []))

    def _init_has_inner_state(self) -> bool:
        """Ask the model registry whether the model has inner state."""
        return ModelRegistry.model_has_inner_state(
            getattr(self.hf_config, "architectures", []))

296
297
    def _verify_tokenizer_mode(self) -> None:
        tokenizer_mode = self.tokenizer_mode.lower()
298
        if tokenizer_mode not in ["auto", "slow", "mistral"]:
299
300
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
301
                "either 'auto', 'slow' or 'mistral'.")
302
        self.tokenizer_mode = tokenizer_mode
303

304
305
    def _resolve_task(
        self,
        task_option: Union[TaskOption, _Task],
        hf_config: PretrainedConfig,
    ) -> Tuple[Set[_Task], _Task]:
        """Determine the supported tasks and pick the one to run.

        Args:
            task_option: The requested task. "draft" is returned as-is;
                "auto" selects the highest-priority supported task.
            hf_config: HF config whose `architectures` are looked up in the
                model registry.

        Returns:
            A tuple of (set of supported tasks, selected task).

        Raises:
            ValueError: If the requested task is not supported by the model,
                or if "auto" is requested and the model supports no task.
        """
        if task_option == "draft":
            return {"draft"}, "draft"

        architectures = getattr(hf_config, "architectures", [])

        task_support: Dict[_Task, bool] = {
            # NOTE: Listed from highest to lowest priority,
            # in case the model supports multiple of them
            "generate": ModelRegistry.is_text_generation_model(architectures),
            "embedding": ModelRegistry.is_embedding_model(architectures),
        }
        supported_tasks_lst: List[_Task] = [
            task for task, is_supported in task_support.items() if is_supported
        ]
        supported_tasks = set(supported_tasks_lst)

        if task_option == "auto":
            # Without this guard an empty list would surface as a bare
            # StopIteration, which gives the user no hint about the cause.
            if not supported_tasks_lst:
                raise ValueError(
                    "Cannot resolve task 'auto': this model does not support "
                    f"any of the known tasks {list(task_support)}.")

            selected_task = supported_tasks_lst[0]

            if len(supported_tasks) > 1:
                logger.info(
                    "This model supports multiple tasks: %s. "
                    "Defaulting to '%s'.", supported_tasks, selected_task)
        else:
            if task_option not in supported_tasks:
                msg = (
                    f"This model does not support the '{task_option}' task. "
                    f"Supported tasks: {supported_tasks}")
                raise ValueError(msg)

            selected_task = task_option

        return supported_tasks, selected_task
342

343
344
345
    def _parse_quant_hf_config(self):
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
346
            # compressed-tensors uses a "compression_config" key
347
            quant_cfg = getattr(self.hf_config, "compression_config", None)
348
349
        return quant_cfg

350
    def _verify_quantization(self) -> None:
        """Reconcile and validate the quantization method.

        Lower-cases `self.quantization`, merges it with the method declared
        in the HF config (if any, and after giving every registered method a
        chance to override it), and then checks the result against the known
        methods and the per-platform allow-lists.

        Raises:
            ValueError: If the argument and the HF config disagree, if the
                method is unknown, or if it is not supported on ROCm, TPU or
                Neuron.
        """
        # Kept as a list because it is interpolated into an error message.
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = (
            "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
            "fbgemm_fp8",
        )
        optimized_quantization_methods = (
            "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
            "awq_marlin", "fbgemm_fp8", "compressed_tensors",
            "compressed-tensors", "experts_int8",
        )
        tpu_supported_quantization = ("tpu_int8", )
        neuron_supported_quantization = ("neuron_quant", )

        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()

            # Detect which checkpoint is it: the first method that claims
            # an override wins.
            for method in QUANTIZATION_METHODS.values():
                override = method.override_quantization_method(
                    quant_cfg, self.quantization)
                if override:
                    quant_method = override
                    self.quantization = override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is None:
            return

        if self.quantization not in supported_quantization:
            raise ValueError(
                f"Unknown quantization method: {self.quantization}. Must "
                f"be one of {supported_quantization}.")

        if (current_platform.is_rocm()
                and self.quantization not in rocm_supported_quantization):
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in ROCm.")

        if (current_platform.is_tpu()
                and self.quantization not in tpu_supported_quantization):
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in TPU Backend.")

        if self.quantization not in optimized_quantization_methods:
            logger.warning(
                "%s quantization is not fully "
                "optimized yet. The speed can be slower than "
                "non-quantized models.", self.quantization)

        # AWQ on ROCm switches on the Triton AWQ flag when it is not set.
        if (self.quantization == "awq" and current_platform.is_rocm()
                and not envs.VLLM_USE_TRITON_AWQ):
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ.")
            envs.VLLM_USE_TRITON_AWQ = True

        if (current_platform.is_neuron()
                and self.quantization not in neuron_supported_quantization):
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in Neuron Backend.")
422

423
    def _verify_cuda_graph(self) -> None:
424
425
426
427
        if self.max_seq_len_to_capture is None:
            self.max_seq_len_to_capture = self.max_model_len
        self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                          self.max_model_len)
428

429
430
    def _verify_bnb_config(self) -> None:
        """
431
        The current version of bitsandbytes (0.44.0) with 8-bit models does not
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
        yet support CUDA graph.
        """
        is_bitsandbytes = self.quantization == "bitsandbytes"
        has_quantization_config = (getattr(self.hf_config,
                                           "quantization_config", None)
                                   is not None)
        is_8bit = (self.hf_config.quantization_config.get(
            "load_in_8bit", False) if has_quantization_config else False)
        if all([
                is_bitsandbytes,
                has_quantization_config,
                is_8bit,
                not self.enforce_eager,
        ]):
            logger.warning(
                "CUDA graph is not supported on BitAndBytes 8bit yet, "
                "fallback to the eager mode.")
            self.enforce_eager = True

451
452
453
454
455
456
457
458
459
460
461
462
    def verify_async_output_proc(self, parallel_config, speculative_config,
                                 device_config) -> None:
        if not self.use_async_output_proc:
            # Nothing to check
            return

        if parallel_config.pipeline_parallel_size > 1:
            logger.warning("Async output processing can not be enabled "
                           "with pipeline parallel")
            self.use_async_output_proc = False
            return

463
464
        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
465
        if device_config.device_type not in ("cuda", "tpu", "xpu"):
466
            logger.warning(
467
                "Async output processing is only supported for CUDA, TPU, XPU. "
468
                "Disabling it for other platforms.")
469
470
471
472
473
474
475
476
477
            self.use_async_output_proc = False
            return

        if envs.VLLM_USE_RAY_SPMD_WORKER:
            logger.warning(
                "Async output processing can not be enabled with ray spmd")
            self.use_async_output_proc = False
            return

478
479
        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
480
        if device_config.device_type == "cuda" and self.enforce_eager:
481
482
483
484
485
486
487
488
489
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            self.use_async_output_proc = not self.enforce_eager
            return

        # Async postprocessor is not necessary with embedding mode
        # since there is no token generation
490
        if self.task == "embedding":
491
492
            self.use_async_output_proc = False

493
494
        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
495
496
497
498
499
        if speculative_config:
            logger.warning("Async output processing is not supported with"
                           " speculative decoding currently.")
            self.use_async_output_proc = False

500
501
502
503
    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the model is compatible with the parallel settings.

        Raises:
            ValueError: If the attention head count is not divisible by the
                tensor parallel size.
            NotImplementedError: If pipeline parallelism is requested for a
                model the registry does not report as supporting it.
        """
        head_count = getattr(self.hf_text_config, "num_attention_heads", 0)
        tp_size = parallel_config.tensor_parallel_size
        if head_count % tp_size != 0:
            raise ValueError(
                f"Total number of attention heads ({head_count})"
                " must be divisible by tensor parallel size "
                f"({tp_size}).")

        # The remaining checks only apply to pipeline parallelism.
        if parallel_config.pipeline_parallel_size <= 1:
            return

        model_archs = getattr(self.hf_config, "architectures", [])
        if not ModelRegistry.is_pp_supported_model(model_archs):
            raise NotImplementedError(
                "Pipeline parallelism is not supported for this model. "
                "Supported models implement the `SupportsPP` interface.")

        if self.use_async_output_proc:
            logger.warning("Async output processor is not supported with "
                           "pipeline parallelism currently. Disabling it.")
            self.use_async_output_proc = False
525

526
527
    def get_hf_config_sliding_window(
            self) -> Union[Optional[int], List[Optional[int]]]:
        """Get the sliding window size, or None if disabled."""
        text_config = self.hf_text_config

        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
        # addition to sliding window size. A missing field counts as enabled;
        # a present but falsy one disables the window.
        if not getattr(text_config, "use_sliding_window", True):
            return None

        return getattr(text_config, "sliding_window", None)
537

538
    def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
        """Get the sliding window size, or None if disabled.

        The user-level `disable_sliding_window` flag takes precedence over
        whatever the HF config declares.
        """
        return (None if self.disable_sliding_window else
                self.get_hf_config_sliding_window())

547
    def get_vocab_size(self) -> int:
        """Return `vocab_size` from the HF text config."""
        text_config = self.hf_text_config
        return text_config.vocab_size
549

550
    def get_hidden_size(self) -> int:
        """Return `hidden_size` from the HF text config."""
        text_config = self.hf_text_config
        return text_config.hidden_size
552
553

    def get_head_size(self) -> int:
        """Return the size of a single attention head.

        Resolution order:
            1. `deepseek_v2` models always report 256.
            2. Attention-free models report 0.
            3. A `head_dim` declared in the HF text config, if it is not None.
            4. `hidden_size // num_attention_heads`.
        """
        text_config = self.hf_text_config

        # TODO remove hard code
        if getattr(text_config, "model_type", None) == 'deepseek_v2':
            # FlashAttention supports only head_size 32, 64, 128, 256,
            # we need to pad head_size 192 to 256
            return 256

        if self.is_attention_free:
            return 0

        # A config may declare `head_dim` but leave it as None; treat that
        # like a missing attribute instead of returning None to the caller.
        head_dim = getattr(text_config, "head_dim", None)
        if head_dim is not None:
            return head_dim

        # FIXME(woosuk): This may not be true for all models.
        return text_config.hidden_size // text_config.num_attention_heads
569

570
571
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        top_config = self.hf_config
        text_config = self.hf_text_config
        model_type = top_config.model_type

        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        uses_new_falcon_arch = (
            model_type in ("falcon", "RefinedWeb", "RefinedWebModel")
            and getattr(top_config, "new_decoder_architecture", False))
        if (not uses_new_falcon_arch
                and getattr(text_config, "multi_query", False)):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT: the KV head count lives in `attn_config`, which
        # is checked by key for MPT and by attribute for DBRX.
        if model_type == "mpt":
            attn_config = top_config.attn_config
            return (attn_config["kv_n_heads"] if "kv_n_heads" in attn_config
                    else top_config.num_attention_heads)
        if model_type == "dbrx":
            return getattr(top_config.attn_config, "kv_n_heads",
                           top_config.num_attention_heads)

        if self.is_attention_free:
            return 0

        kv_head_attrs = (
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        )
        # Lazily probe the attributes in order; the first non-None wins.
        declared = (getattr(text_config, name, None) for name in kv_head_attrs)
        kv_heads = next((value for value in declared if value is not None),
                        None)
        if kv_heads is not None:
            return kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return text_config.num_attention_heads
615
616
617
618
619
620
621
622
623
624

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        tp_size = parallel_config.tensor_parallel_size
        heads_per_rank = self.get_total_num_kv_heads() // tp_size
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, heads_per_rank)
625

626
627
    def get_num_attention_heads(self,
                                parallel_config: "ParallelConfig") -> int:
        """Return the attention head count divided by the TP size.

        A text config without `num_attention_heads` counts as 0 heads.
        """
        total_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
        tp_size = parallel_config.tensor_parallel_size
        return total_heads // tp_size
630

631
    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the size of the layer range `get_pp_indices` assigns.

        The pipeline rank is derived as `rank // tensor_parallel_size`. A
        text config without `num_hidden_layers` counts as 0 layers.
        """
        from vllm.distributed.utils import get_pp_indices

        layer_total = getattr(self.hf_text_config, "num_hidden_layers", 0)
        pipeline_rank = (parallel_config.rank //
                         parallel_config.tensor_parallel_size)
        first, last = get_pp_indices(layer_total, pipeline_rank,
                                     parallel_config.pipeline_parallel_size)
        return last - first
639

640
641
642
643
    def get_num_attention_layers(self,
                                 parallel_config: "ParallelConfig") -> int:
        """Count the layers whose block type is "attention".

        Attention-free models report 0. When the HF config has no
        `layers_block_type`, every layer returned by `get_num_layers` is
        counted as an attention layer.
        """
        if self.is_attention_free:
            return 0

        layer_count = self.get_num_layers(parallel_config)

        # Transformers supports layers_block_type @property
        block_types = getattr(self.hf_config, "layers_block_type",
                              ["attention"] * layer_count)
        return sum(1 for kind in block_types if kind == "attention")
Mor Zusman's avatar
Mor Zusman committed
651

652
653
654
655
656
657
658
659
660
661
662
663
    def get_multimodal_config(self) -> "MultiModalConfig":
        """
        Get the multimodal configuration of the model.

        Raises:
            ValueError: If the model is not multimodal.
        """
        mm_config = self.multimodal_config
        if mm_config is not None:
            return mm_config

        raise ValueError("The model is not multimodal.")

664
665
666
    @property
    def is_encoder_decoder_model(self) -> bool:
        """Extract the HF encoder/decoder model flag.

        The top-level HF config is consulted first; a nested `text_config`
        is used as the fallback when the top-level flag is unset or falsy.
        """
        top_level_flag = getattr(self.hf_config, "is_encoder_decoder", False)
        if top_level_flag:
            return top_level_flag

        if not hasattr(self.hf_config, "text_config"):
            return False

        return getattr(self.hf_config.text_config, "is_encoder_decoder",
                       False)
670

671
672
673
674
    @property
    def is_multimodal_model(self) -> bool:
        """Whether a multimodal config was created for this model."""
        has_mm_config = self.multimodal_config is not None
        return has_mm_config

675
676

class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution. Values above 1.0 are rejected.
        swap_space: Size of the CPU swap space per GPU (in GiB). Stored in
            bytes as ``swap_space_bytes``.
        cache_dtype: Data type for kv cache storage. One of "auto", "fp8",
            "fp8_e4m3" or "fp8_e5m2".
        is_attention_free: Stored as-is on the config.
        num_gpu_blocks_override: Number of GPU blocks to use. This overrides
            the profiled num_gpu_blocks if specified. Does nothing if None.
        sliding_window: Sliding window setting, or None. Prefix caching is
            rejected when this is set.
        enable_prefix_caching: Whether prefix caching is requested.
        cpu_offload_gb: Stored as-is on the config.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: float,
        cache_dtype: str,
        is_attention_free: bool = False,
        num_gpu_blocks_override: Optional[int] = None,
        sliding_window: Optional[int] = None,
        enable_prefix_caching: bool = False,
        cpu_offload_gb: float = 0,
    ) -> None:
        # NOTE: metrics_info() walks __dict__, so the order of these
        # assignments is the order in which the metrics are reported.
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * GiB_bytes
        self.num_gpu_blocks_override = num_gpu_blocks_override
        self.cache_dtype = cache_dtype
        self.is_attention_free = is_attention_free
        self.sliding_window = sliding_window
        self.enable_prefix_caching = enable_prefix_caching
        self.cpu_offload_gb = cpu_offload_gb

        # Validate the user-supplied values before adding derived fields.
        self._verify_args()
        self._verify_cache_dtype()
        self._verify_prefix_caching()

        # Will be set after profiling.
        self.num_gpu_blocks: Optional[int] = None
        self.num_cpu_blocks: Optional[int] = None

    def metrics_info(self):
        """Return every instance attribute as a ``{name: str(value)}`` dict.

        The string-only mapping is the shape expected for prometheus metrics
        info.
        """
        return {name: str(setting) for name, setting in vars(self).items()}

    def _verify_args(self) -> None:
        """Reject a GPU memory utilization above 1.0."""
        utilization = self.gpu_memory_utilization
        if utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        """Check ``cache_dtype`` and log a notice for the fp8 variants.

        Raises:
            ValueError: If the dtype is neither "auto" nor an fp8 variant.
        """
        dtype = self.cache_dtype
        if dtype == "auto":
            return

        if dtype not in ("fp8", "fp8_e4m3", "fp8_e5m2"):
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

        logger.info(
            "Using fp8 data type to store kv cache. It reduces the GPU "
            "memory footprint and boosts the performance. "
            "Meanwhile, it may cause accuracy drop without a proper "
            "scaling factor")

    def _verify_prefix_caching(self) -> None:
        """Reject prefix caching when a sliding window is configured.

        Raises:
            NotImplementedError: If both prefix caching and a sliding window
                are set.
        """
        if self.enable_prefix_caching and self.sliding_window is not None:
            raise NotImplementedError(
                "Prefix caching is not supported with sliding window. "
                "Run with --disable-sliding-window to use prefix caching.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check the requested swap space against the total CPU memory.

        Args:
            parallel_config: Supplies ``tensor_parallel_size``, used as the
                number of GPUs sharing this node's CPU memory.

        Raises:
            ValueError: If the swap space exceeds 70% of total CPU memory.
                Above 40% only a warning is logged.
        """
        host_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        gpus_on_node = parallel_config.tensor_parallel_size
        swap_total = self.swap_space_bytes * gpus_on_node

        msg = (f"{swap_total / GiB_bytes:.2f} GiB out of the "
               f"{host_memory / GiB_bytes:.2f} GiB total CPU memory "
               "is allocated for the swap space.")

        if swap_total > 0.7 * host_memory:
            raise ValueError("Too large swap space. " + msg)
        if swap_total > 0.4 * host_memory:
            logger.warning("Possibly too large swap space. %s", msg)
768

769

770
771
772
@dataclass
class TokenizerPoolConfig:
    """Configuration for the tokenizer pool.

    Args:
        pool_size: Number of tokenizer workers in the pool.
        pool_type: Type of the pool. Either the string "ray" or a class.
        extra_config: Additional config for the pool.
            The way the config will be used depends on the
            pool type.
    """
    pool_size: int
    pool_type: Union[str, Type["BaseTokenizerGroup"]]
    extra_config: dict

    def __post_init__(self):
        """Validate ``pool_type`` and ``extra_config`` after construction."""
        is_named_pool = self.pool_type in ("ray", )
        is_pool_class = isinstance(self.pool_type, type)
        if not (is_named_pool or is_pool_class):
            raise ValueError(f"Unknown pool type: {self.pool_type}")

        if not isinstance(self.extra_config, dict):
            raise ValueError("extra_config must be a dictionary.")

    @classmethod
    def create_config(
        cls, tokenizer_pool_size: int,
        tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]],
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Create a TokenizerPoolConfig from the given parameters.

        If tokenizer_pool_size is 0, return None.

        Args:
            tokenizer_pool_size: Number of tokenizer workers in the pool.
            tokenizer_pool_type: Type of the pool.
            tokenizer_pool_extra_config: Additional config for the pool.
                The way the config will be used depends on the
                pool type. This can be a JSON string (will be parsed).
        """
        # A pool size of zero means no tokenizer pool at all.
        if not tokenizer_pool_size:
            return None

        if isinstance(tokenizer_pool_extra_config, str):
            extra_config = json.loads(tokenizer_pool_extra_config)
        else:
            # None (or any other falsy value) becomes an empty dict.
            extra_config = tokenizer_pool_extra_config or {}

        return cls(tokenizer_pool_size, tokenizer_pool_type, extra_config)


824
825
826
827
828
829
830
class LoadFormat(str, enum.Enum):
    """String-valued names of the weight formats a model can be loaded from.

    Members compare equal to their lower-case string value, so
    ``LoadFormat("auto")`` and ``LoadFormat.AUTO == "auto"`` both work.
    """

    # Try safetensors first, fall back to the pytorch bin format.
    AUTO = "auto"
    # Pytorch bin format.
    PT = "pt"
    # Safetensors format.
    SAFETENSORS = "safetensors"
    # Pytorch format plus a numpy cache to speed up loading.
    NPCACHE = "npcache"
    # Randomly initialized weights, mainly for profiling.
    DUMMY = "dummy"
    # CoreWeave's tensorizer library.
    TENSORIZER = "tensorizer"

    SHARDED_STATE = "sharded_state"

    GGUF = "gguf"
    # nf4 type weights.
    BITSANDBYTES = "bitsandbytes"

    MISTRAL = "mistral"
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853


@dataclass
class LoadConfig:
    """Configuration for loading the model weights.

    Args:
        load_format: The format of the model weights to load. A string is
            lower-cased and converted to a ``LoadFormat`` member; any
            non-string value is kept unchanged. Documented formats:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
            "tensorizer" will use CoreWeave's tensorizer library for
                fast weight loading.
            "bitsandbytes" will load nf4 type weights.
            See ``LoadFormat`` for the full set of accepted values.
        download_dir: Directory to download and load the weights, default to
            the default cache directory of huggingface.
        model_loader_extra_config: Extra config for the model loader. A
            non-empty JSON string is parsed into a Python object.
        ignore_patterns: The list of patterns to ignore when loading the model.
            Default to "original/**/*" to avoid repeated loading of llama's
            checkpoints.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(
        default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        """Normalize the extra config, load format and ignore patterns."""
        extra_config = self.model_loader_extra_config
        # Only a non-empty string is parsed; an empty string, None or a dict
        # is left exactly as it was passed in.
        if extra_config and isinstance(extra_config, str):
            self.model_loader_extra_config = json.loads(extra_config)
        self._verify_load_format()

        if self.ignore_patterns:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns)
        else:
            # Covers both None and an empty list/string.
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Convert a string ``load_format`` to ``LoadFormat`` and validate it.

        Raises:
            ValueError: If the string is not a valid ``LoadFormat`` value, or
                if the format is listed as unsupported on ROCm.
        """
        # Non-string values (e.g. a loader object) are accepted unchanged.
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        # Lower-case LoadFormat values that cannot be used on ROCm.
        rocm_not_supported_load_format: List[str] = []
        if current_platform.is_rocm(
        ) and load_format in rocm_not_supported_load_format:
            # Compare by member value: the unsupported list holds lower-case
            # values, whereas LoadFormat.__members__ yields upper-case names
            # and would never match.
            rocm_supported_load_format = [
                fmt.value for fmt in LoadFormat
                if fmt.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")


901
class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Deprecated, use distributed_executor_backend instead.
        max_parallel_loading_workers: Maximum number of multiple batches
            when load model sequentially. To avoid RAM OOM when using tensor
            parallel and large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL.
        tokenizer_pool_config: Config for the tokenizer pool.
            If None, will use synchronous tokenization.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
        placement_group: ray distributed model workers placement group.
        distributed_executor_backend: Backend to use for distributed model
            workers, either "ray" or "mp" (multiprocessing). If either
            pipeline_parallel_size or tensor_parallel_size is greater than 1,
            will default to "ray" if Ray is installed or "mp" otherwise.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: Optional[bool] = None,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
        ray_workers_use_nsight: bool = False,
        placement_group: Optional["PlacementGroup"] = None,
        distributed_executor_backend: Optional[Union[
            str, Type["ExecutorBase"]]] = None,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.distributed_executor_backend = distributed_executor_backend
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.tokenizer_pool_config = tokenizer_pool_config
        self.ray_workers_use_nsight = ray_workers_use_nsight
        self.placement_group = placement_group
        self.world_size = pipeline_parallel_size * self.tensor_parallel_size

        # The deprecated flag may pick Ray when no backend was given, but it
        # must not contradict an explicitly chosen non-Ray backend.
        if worker_use_ray and self.distributed_executor_backend is None:
            self.distributed_executor_backend = "ray"
        elif worker_use_ray and not self.use_ray:
            raise ValueError(f"worker-use-ray can't be used with "
                             f"distributed executor backend "
                             f"'{self.distributed_executor_backend}'.")

        # Multi-worker TPU runs are Ray-only.
        if current_platform.is_tpu() and self.world_size > 1:
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "ray"
            elif self.distributed_executor_backend != "ray":
                raise ValueError(
                    "TPU backend only supports Ray for distributed inference.")

        if self.distributed_executor_backend is None and self.world_size > 1:
            chosen = self._select_default_backend()
            self.distributed_executor_backend = chosen
            logger.info("Defaulting to use %s for distributed inference",
                        chosen)

        self._verify_args()
        self.rank: int = 0

    def _select_default_backend(self) -> str:
        """Choose "ray" or "mp" when no backend was requested.

        We use multiprocessing by default if world_size fits on the
        current node and we aren't in a ray placement group.

        Raises:
            ValueError: If the world size exceeds the local CUDA device count
                and Ray is not available.
        """
        from vllm.executor import ray_utils

        ray_found = ray_utils.ray_is_available()

        if (current_platform.is_cuda()
                and cuda_device_count_stateless() < self.world_size):
            # More workers than local GPUs: Ray is mandatory.
            if not ray_found:
                raise ValueError("Unable to load Ray which is "
                                 "required for multi-node inference, "
                                 "please install Ray with `pip install "
                                 "ray`.") from ray_utils.ray_import_err
            return "ray"

        if not ray_found:
            return "mp"

        if self.placement_group:
            return "ray"

        # Without an explicit placement group, use Ray only when we are
        # already running inside one.
        from ray import is_initialized as ray_is_initialized
        if not ray_is_initialized():
            return "mp"

        from ray.util import get_current_placement_group
        return "ray" if get_current_placement_group() else "mp"

    @property
    def use_ray(self) -> bool:
        """Whether the selected backend is Ray or a class with ``uses_ray``."""
        backend = self.distributed_executor_backend
        return backend == "ray" or (isinstance(backend, type)
                                    and backend.uses_ray)

    def _verify_args(self) -> None:
        """Validate the backend choice and platform-specific constraints.

        Raises:
            ValueError: If the backend is not "ray", "mp", None or an
                ExecutorBase subclass, or if nsight profiling is requested
                without Ray.
        """
        # Lazy import to avoid circular import
        from vllm.executor.executor_base import ExecutorBase

        backend = self.distributed_executor_backend
        is_builtin = backend in ("ray", "mp", None)
        is_custom_executor = (isinstance(backend, type)
                              and issubclass(backend, ExecutorBase))
        if not (is_builtin or is_custom_executor):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
                "values are 'ray', 'mp' or custom ExecutorBase subclass.")

        if self.use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        if current_platform.is_rocm():
            self.disable_custom_all_reduce = True
            logger.info(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on AMD GPUs.")

        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")
1022

1023
1024

class SchedulerConfig:
    """Scheduler configuration.

    Args:
        task: The task to use the model for.
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration. When None, a default is derived from the
            other arguments.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        num_lookahead_slots: The number of slots to allocate per sequence per
            step, beyond the known token ids. This is used in speculative
            decoding to store KV activations of tokens which may or may not be
            accepted.
        delay_factor: Apply a delay (of delay factor multiplied by previous
            prompt latency) before scheduling next prompt.
        enable_chunked_prefill: If True, prefill requests can be chunked based
            on the remaining max_num_batched_tokens.
        is_multimodal_model: If True, the derived default for
            max_num_batched_tokens is raised to at least
            _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS.
        preemption_mode: Whether to perform preemption by swapping or
            recomputation. If not specified, we determine the mode as follows:
            We use recomputation by default since it incurs lower overhead than
            swapping. However, when the sequence group has multiple sequences
            (e.g., beam search), recomputation is not currently supported. In
            such a case, we use swapping instead.
        num_scheduler_steps: Must be at least 1; a value above 1 makes
            ``is_multi_step`` True.
        multi_step_stream_outputs: Stored as-is on the config.
        send_delta_data: Private API. If used, scheduler sends delta data to
            workers instead of an entire data. It should be enabled only
            when SPMD worker architecture is enabled. I.e.,
            VLLM_USE_RAY_SPMD_WORKER=1
        policy: The scheduling policy to use. "fcfs" (default) or "priority".
    """

    def __init__(self,
                 task: _Task,
                 max_num_batched_tokens: Optional[int],
                 max_num_seqs: int,
                 max_model_len: int,
                 num_lookahead_slots: int = 0,
                 delay_factor: float = 0.0,
                 enable_chunked_prefill: bool = False,
                 is_multimodal_model: bool = False,
                 preemption_mode: Optional[str] = None,
                 num_scheduler_steps: int = 1,
                 multi_step_stream_outputs: bool = False,
                 send_delta_data: bool = False,
                 policy: str = "fcfs") -> None:
        if max_num_batched_tokens is None:
            max_num_batched_tokens = self._default_max_num_batched_tokens(
                task=task,
                max_model_len=max_model_len,
                enable_chunked_prefill=enable_chunked_prefill,
                is_multimodal_model=is_multimodal_model,
                num_scheduler_steps=num_scheduler_steps,
            )

        self.max_num_batched_tokens = max_num_batched_tokens

        if enable_chunked_prefill:
            logger.info(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                self.max_num_batched_tokens)

        self.task: Final = task
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
        self.chunked_prefill_enabled = enable_chunked_prefill
        self.preemption_mode = preemption_mode
        self.num_scheduler_steps = num_scheduler_steps
        self.multi_step_stream_outputs = multi_step_stream_outputs
        self.send_delta_data = send_delta_data
        self.policy = policy
        self._verify_args()

    @staticmethod
    def _default_max_num_batched_tokens(task, max_model_len,
                                        enable_chunked_prefill,
                                        is_multimodal_model,
                                        num_scheduler_steps) -> int:
        """Derive max_num_batched_tokens when the caller passed None."""
        if enable_chunked_prefill and not num_scheduler_steps > 1:
            # It is the values that have the best balance between ITL
            # and TTFT on A100. Note it is not optimized for throughput.
            budget = 512
        else:
            # Two cases land here:
            # * Multi-step Chunked-Prefill doesn't allow prompt-chunking
            #   for now, so the budget is set to max_model_len so we don't
            #   reject sequences on account of a short budget.
            # * Without chunked prefill, if max_model_len is too short,
            #   2048 is used as the default value for higher throughput.
            budget = max(max_model_len, 2048)

        if task == "embedding":
            # For embedding, choose specific value for higher throughput
            budget = max(budget, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
        if is_multimodal_model:
            # The value needs to be at least the number of multimodal tokens
            budget = max(budget, _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS)
        return budget

    def _verify_args(self) -> None:
        """Validate the token, sequence, lookahead and step settings.

        Raises:
            ValueError: If any of the checked limits is violated.
        """
        token_budget = self.max_num_batched_tokens

        # Without chunked prefill a whole prompt must fit in one iteration.
        if (not self.chunked_prefill_enabled
                and token_budget < self.max_model_len):
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")

        if token_budget < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")

        if self.num_lookahead_slots < 0:
            raise ValueError(
                "num_lookahead_slots "
                f"({self.num_lookahead_slots}) must be greater than or "
                "equal to 0.")

        if self.num_scheduler_steps < 1:
            raise ValueError(
                "num_scheduler_steps "
                f"({self.num_scheduler_steps}) must be greater than or "
                "equal to 1.")

    @property
    def is_multi_step(self) -> bool:
        """Whether more than one scheduler step is configured."""
        steps = self.num_scheduler_steps
        return steps > 1

1153

1154
class DeviceConfig:
    """Holds the device type and the matching ``torch.device``.

    Args:
        device: A device type name, or "auto" to detect the type from the
            current platform.
    """
    device: Optional[torch.device]

    def __init__(self, device: str = "auto") -> None:
        if device == "auto":
            # Automated device type detection
            self.device_type = self._detect_device_type()
        else:
            # Device type is assigned explicitly
            self.device_type = device

        # Some device types require processing inputs on CPU
        if self.device_type in ("neuron", "openvino"):
            self.device = torch.device("cpu")
        elif self.device_type == "tpu":
            self.device = None
        else:
            # Set device with device type
            self.device = torch.device(self.device_type)

    @staticmethod
    def _detect_device_type() -> str:
        """Return the device type of the first matching platform probe.

        Raises:
            RuntimeError: If no probe on ``current_platform`` reports True.
        """
        # Checked in this order; each probe is looked up and called only when
        # all earlier ones returned False.
        probes = (
            ("is_cuda_alike", "cuda"),
            ("is_neuron", "neuron"),
            ("is_openvino", "openvino"),
            ("is_tpu", "tpu"),
            ("is_cpu", "cpu"),
            ("is_xpu", "xpu"),
        )
        for probe_name, device_type in probes:
            if getattr(current_platform, probe_name)():
                return device_type
        raise RuntimeError("Failed to infer device type")

1187

1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        speculative_model_quantization: Optional[str],
        speculative_draft_tensor_parallel_size: Optional[int],
        num_speculative_tokens: Optional[int],
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_max_model_len: Optional[int],
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: Optional[float],
        typical_acceptance_sampler_posterior_alpha: Optional[float],
        disable_logprobs: Optional[bool],
    ) -> Optional["SpeculativeConfig"]:
        """Build a SpeculativeConfig from the given options, or return None.

        When no ``speculative_model`` is given, speculative decoding is
        treated as disabled and None is returned. Otherwise the options are
        validated, the draft model/parallel configs are derived (the target
        configs are reused as-is for the ``"[ngram]"`` model), unset optional
        values receive their defaults and a SpeculativeConfig is constructed.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model.
            target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            target_dtype (str): The data type used for the target model.
                Not read by this function.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided. The special value ``"[ngram]"`` selects
                ngram prompt lookup.
            speculative_model_quantization (Optional[str]): Quantization method
                that was used to quantize the speculative model weights. If
                None, we assume the model weights are not quantized.
            speculative_draft_tensor_parallel_size (Optional[int]): The degree
                of the tensor parallelism for the draft model.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided. Will default to ``n_predict`` from the
                draft model config if present, otherwise is required.
            speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
                scorer for the speculative model and fall back to batch
                expansion for scoring.
            speculative_max_model_len (Optional[int]): The maximum model len of
                the speculative model. Used when testing the ability to skip
                speculation for some sequences.
            enable_chunked_prefill (bool): Whether vLLM is configured to use
                chunked prefill or not. Used for raising an error since it's
                not yet compatible with spec decode.
            disable_log_stats (bool): Whether to disable periodic printing of
                stage times in speculative decoding.
            speculative_disable_by_batch_size (Optional[int]): Disable
                speculative decoding for new incoming requests when the number
                of enqueue requests is larger than this value, if provided.
            ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
                window, if provided.
            ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                window, if provided. Defaults to 1 for the ngram model.
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (Optional[float]):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance. Defaults to
                0.09.
            typical_acceptance_sampler_posterior_alpha (Optional[float]):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler. Defaults to 0.3.
            disable_logprobs (Optional[bool]): If set to True, token log
                probabilities are not returned during speculative decoding.
                If set to False, token log probabilities are returned
                according to the log probability settings in SamplingParams.
                If not specified, it defaults to True.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                the necessary conditions are met, else None.

        Raises:
            ValueError: If ``num_speculative_tokens`` is given without a
                speculative model, if the batch-size threshold is below 2,
                if chunked prefill is enabled, if the ngram window bounds are
                invalid, if ``num_speculative_tokens`` exceeds the draft
                model's ``n_predict``, or if ``num_speculative_tokens`` cannot
                be determined.
        """

        # Without a speculative model there is nothing to configure.
        if speculative_model is None:
            if num_speculative_tokens is None:
                return None
            raise ValueError("num_speculative_tokens was provided without "
                             "speculative_model.")

        threshold_too_small = (speculative_disable_by_batch_size is not None
                               and speculative_disable_by_batch_size < 2)
        if threshold_too_small:
            raise ValueError("Expect the batch size threshold of disabling "
                             "speculative decoding is > 1, but got "
                             f"{speculative_disable_by_batch_size=}")

        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if enable_chunked_prefill:
            raise ValueError(
                "Speculative decoding and chunked prefill are "
                f"currently mutually exclusive ({enable_chunked_prefill=}).")

        if speculative_model != "[ngram]":
            # A real draft model: the ngram window is unused.
            ngram_prompt_lookup_max = 0
            ngram_prompt_lookup_min = 0

            # TODO: The user should be able to specify revision/max model len
            # for the draft model. It is not currently supported.
            draft_model_config = ModelConfig(
                model=speculative_model,
                task="draft",
                tokenizer=target_model_config.tokenizer,
                tokenizer_mode=target_model_config.tokenizer_mode,
                trust_remote_code=target_model_config.trust_remote_code,
                allowed_local_media_path=(
                    target_model_config.allowed_local_media_path),
                dtype=target_model_config.dtype,
                seed=target_model_config.seed,
                revision=None,
                code_revision=None,
                tokenizer_revision=target_model_config.tokenizer_revision,
                max_model_len=None,
                spec_target_max_model_len=target_model_config.max_model_len,
                quantization=speculative_model_quantization,
                enforce_eager=target_model_config.enforce_eager,
                max_seq_len_to_capture=(
                    target_model_config.max_seq_len_to_capture),
                max_logprobs=target_model_config.max_logprobs,
            )

            draft_hf_config = draft_model_config.hf_config

            # Propagate an explicit token count to draft configs that carry
            # a ``num_lookahead_tokens`` attribute.
            if (hasattr(draft_hf_config, "num_lookahead_tokens")
                    and num_speculative_tokens is not None):
                draft_hf_config.num_lookahead_tokens = num_speculative_tokens

            max_supported_tokens = getattr(draft_hf_config, "n_predict", None)
            if max_supported_tokens is not None:
                if num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    num_speculative_tokens = max_supported_tokens
                elif num_speculative_tokens > max_supported_tokens:
                    # Verify provided value doesn't exceed the maximum
                    # supported by the draft model.
                    raise ValueError(
                        "This speculative model supports a maximum of "
                        f"num_speculative_tokens={max_supported_tokens}, but "
                        f"{num_speculative_tokens=} was provided.")

            draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    speculative_max_model_len,
                    draft_model_config.max_model_len,
                    target_model_config.max_model_len,
                ))

            draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    target_parallel_config,
                    speculative_draft_tensor_parallel_size,
                    draft_hf_config,
                ))
        else:
            # Ngram prompt lookup: validate the token window bounds.
            if ngram_prompt_lookup_min is None:
                ngram_prompt_lookup_min = 1
            if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
                raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
            if ngram_prompt_lookup_min < 1:
                raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
            if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
                raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
                                 f"larger than {ngram_prompt_lookup_max=}")

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            draft_model_config = target_model_config
            draft_parallel_config = target_parallel_config

        if num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative_model unless the draft model config contains an "
                "n_predict parameter.")

        # Fill in defaults for the options the caller left unset.
        if typical_acceptance_sampler_posterior_threshold is None:
            typical_acceptance_sampler_posterior_threshold = 0.09
        if typical_acceptance_sampler_posterior_alpha is None:
            typical_acceptance_sampler_posterior_alpha = 0.3
        if disable_logprobs is None:
            disable_logprobs = True

        return SpeculativeConfig(
            draft_model_config=draft_model_config,
            draft_parallel_config=draft_parallel_config,
            num_speculative_tokens=num_speculative_tokens,
            speculative_disable_mqa_scorer=speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size=(
                speculative_disable_by_batch_size),
            ngram_prompt_lookup_max=ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=ngram_prompt_lookup_min,
            draft_token_acceptance_method=draft_token_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=(
                typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                typical_acceptance_sampler_posterior_alpha),
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
        )

1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: Optional[int],
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.
        """

        if speculative_max_model_len is not None:

            if speculative_max_model_len > draft_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {draft_max_model_len=}")

            if speculative_max_model_len > target_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {target_max_model_len=}")

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

1440
1441
    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: Optional[int],
        draft_hf_config: PretrainedConfig,
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        The result copies the target parallel config except for the tensor
        parallel size. When no draft size is given, it defaults to the
        target's size, or to 1 when the draft model type is
        "mlp_speculator" (a warning is logged if the target uses tp>1).

        Raises:
            ValueError: If an explicit draft tensor parallel size is neither
                1 nor the target's tensor parallel size.
        """
        target_tp_size = target_parallel_config.tensor_parallel_size

        if speculative_draft_tensor_parallel_size is not None:
            # An explicit size must be either 1 or match the target.
            if speculative_draft_tensor_parallel_size not in (1,
                                                              target_tp_size):
                raise ValueError(
                    f"{speculative_draft_tensor_parallel_size=} cannot be "
                    f"other value than 1 or target model tensor_parallel_size")
            draft_tp_size = speculative_draft_tensor_parallel_size
        elif draft_hf_config.model_type != "mlp_speculator":
            draft_tp_size = target_tp_size
        else:
            draft_tp_size = 1
            if target_tp_size > 1:
                logger.warning(
                    "MLPSpeculator cannot currently be run with tp>1; "
                    "setting speculative_draft_tensor_parallel_size=1")

        return ParallelConfig(
            pipeline_parallel_size=(
                target_parallel_config.pipeline_parallel_size),
            tensor_parallel_size=draft_tp_size,
            distributed_executor_backend=(
                target_parallel_config.distributed_executor_backend),
            max_parallel_loading_workers=(
                target_parallel_config.max_parallel_loading_workers),
            disable_custom_all_reduce=(
                target_parallel_config.disable_custom_all_reduce),
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=(
                target_parallel_config.ray_workers_use_nsight),
            placement_group=target_parallel_config.placement_group,
        )

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ):
        """Store the speculative decoding settings and validate them.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.
            speculative_disable_mqa_scorer: Disable the MQA scorer for the
                speculative model and fall back to batch expansion for
                scoring.
            speculative_disable_by_batch_size: Disable speculative
                decoding for new incoming requests when the number of
                enqueue requests is larger than this value.
            ngram_prompt_lookup_max: Max size of ngram token window. A falsy
                value (None or 0) is stored as 0.
            ngram_prompt_lookup_min: Min size of ngram token window. A falsy
                value (None or 0) is stored as 0.
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (float):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance.
            typical_acceptance_sampler_posterior_alpha (float):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler.
            disable_logprobs: If set to True, token log probabilities will not
                be returned even if requested by sampling parameters. This
                reduces latency by skipping logprob calculation in proposal
                sampling, target sampling, and after accepted tokens are
                determined. If set to False, log probabilities will be
                returned.
            disable_log_stats: Whether to disable periodic printing of stage
                times in speculative decoding.

        Raises:
            ValueError: If ``_verify_args`` rejects the stored values.
        """
        # Draft model and proposal settings.
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens
        self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
        self.speculative_disable_by_batch_size = (
            speculative_disable_by_batch_size)

        # Ngram window bounds are always stored as ints (None becomes 0).
        self.ngram_prompt_lookup_max = (
            ngram_prompt_lookup_max if ngram_prompt_lookup_max else 0)
        self.ngram_prompt_lookup_min = (
            ngram_prompt_lookup_min if ngram_prompt_lookup_min else 0)

        # Draft token acceptance settings.
        self.draft_token_acceptance_method = draft_token_acceptance_method
        self.typical_acceptance_sampler_posterior_threshold = (
            typical_acceptance_sampler_posterior_threshold)
        self.typical_acceptance_sampler_posterior_alpha = (
            typical_acceptance_sampler_posterior_alpha)

        # Logging / logprob switches.
        self.disable_logprobs = disable_logprobs
        self.disable_log_stats = disable_log_stats

        self._verify_args()

    def _verify_args(self) -> None:
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
            # Validate and set draft token acceptance related settings.

        if (self.draft_token_acceptance_method is None):
            raise ValueError("draft_token_acceptance_method is not set. "
                             "Expected values are rejection_sampler or "
                             "typical_acceptance_sampler.")

        if (self.draft_token_acceptance_method != 'rejection_sampler'
                and self.draft_token_acceptance_method !=
                'typical_acceptance_sampler'):
            raise ValueError(
                "Expected draft_token_acceptance_method to be either "
                "rejection_sampler or typical_acceptance_sampler. Instead it "
                f"is {self.draft_token_acceptance_method}")

        if (self.typical_acceptance_sampler_posterior_threshold < 0
                or self.typical_acceptance_sampler_posterior_alpha < 0):
            raise ValueError(
                "Expected typical_acceptance_sampler_posterior_threshold "
                "and typical_acceptance_sampler_posterior_alpha to be > 0. "
                "Instead found "
                f"typical_acceptance_sampler_posterior_threshold = "
                f"{self.typical_acceptance_sampler_posterior_threshold} and "
                f"typical_acceptance_sampler_posterior_alpha = "
                f"{self.typical_acceptance_sampler_posterior_alpha}")
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
1596
1597
1598
1599
        if self.ngram_prompt_lookup_max > 0:
            draft_model = "[ngram]"
        else:
            draft_model = self.draft_model_config.model
1600
1601
1602
1603
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"


1604
1605
1606
1607
@dataclass
class LoRAConfig:
    """Settings for LoRA adapters, validated on construction.

    ``__post_init__`` checks the rank, extra vocab size and adapter counts,
    and defaults ``max_cpu_loras`` to ``max_loras`` when it is not given.
    ``lora_dtype`` may be left as None/"auto" or given as a dtype name; it is
    resolved to a ``torch.dtype`` by ``verify_with_model_config``.
    """
    max_lora_rank: int
    max_loras: int
    fully_sharded_loras: bool = False
    max_cpu_loras: Optional[int] = None
    lora_dtype: Optional[Union[torch.dtype, str]] = None
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256
    # Variable-length tuple of factors (``Tuple[float]`` would type a 1-tuple).
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None

    def __post_init__(self):
        """Validate the field values and default ``max_cpu_loras``.

        Raises:
            ValueError: If ``max_lora_rank`` or ``lora_extra_vocab_size`` is
                not one of the supported values, if ``max_loras`` is below 1,
                or if ``max_cpu_loras`` is smaller than ``max_loras``.
        """
        # Setting the maximum rank to 256 should be able to satisfy the vast
        # majority of applications.
        possible_max_ranks = (8, 16, 32, 64, 128, 256)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``lora_dtype`` against the model config.

        None or "auto" takes the model's dtype; any other string is looked
        up as an attribute of ``torch``. A warning is logged when the model
        uses a quantization method other than "awq" or "gptq".
        """
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)
        if model_config.quantization and model_config.quantization not in [
                "awq", "gptq"
        ]:
            # TODO support marlin
            logger.warning("%s quantization is not tested with LoRA yet.",
                           model_config.quantization)

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        """Reject scheduler configs that enable chunked prefill.

        Raises:
            ValueError: If ``scheduler_config.chunked_prefill_enabled`` is
                truthy.
        """
        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if scheduler_config.chunked_prefill_enabled:
            raise ValueError("LoRA is not supported with chunked prefill yet.")
1655
1656


1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
@dataclass
class PromptAdapterConfig:
    """Settings for prompt adapters, validated on construction.

    ``prompt_adapter_dtype`` may be left as None/"auto" or given as a dtype
    name; it is resolved to a ``torch.dtype`` by
    ``verify_with_model_config``.
    """
    max_prompt_adapters: int
    max_prompt_adapter_token: int
    max_cpu_prompt_adapters: Optional[int] = None
    # Strings ("auto" or a torch dtype name) are accepted here as well, see
    # verify_with_model_config.
    prompt_adapter_dtype: Optional[Union[torch.dtype, str]] = None

    def __post_init__(self):
        """Validate the field values and default ``max_cpu_prompt_adapters``.

        Raises:
            ValueError: If ``max_prompt_adapters`` is below 1 or
                ``max_prompt_adapter_token`` is 0.
        """
        if self.max_prompt_adapters < 1:
            raise ValueError(f"max_prompt_adapters "
                             f"({self.max_prompt_adapters}) must be >= 1.")
        if self.max_prompt_adapter_token == 0:
            raise ValueError("max_prompt_adapter_token must be set.")
        if self.max_cpu_prompt_adapters is None:
            self.max_cpu_prompt_adapters = self.max_prompt_adapters

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``prompt_adapter_dtype`` against the model config.

        None or "auto" takes the model's dtype; any other string is looked
        up as an attribute of ``torch``.
        """
        if self.prompt_adapter_dtype in (None, "auto"):
            self.prompt_adapter_dtype = model_config.dtype
        elif isinstance(self.prompt_adapter_dtype, str):
            self.prompt_adapter_dtype = getattr(torch,
                                                self.prompt_adapter_dtype)


1682
@dataclass
class MultiModalConfig:
    """Settings that control how multimodal models behave."""

    # Each instance gets its own empty mapping by default.
    limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
    """
    The maximum number of multi-modal input instances allowed per prompt
    for each :class:`~vllm.multimodal.MultiModalPlugin`.
    """

    # TODO: Add configs to init vision tower or not.
1693

1694

1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
@dataclass
class PoolerConfig:
    """Controls the behavior of pooler in embedding model.

    Every field is optional and defaults to None. No validation happens
    here; how each value is interpreted is decided by the code that reads
    this config.
    """

    # Pooling selection and post-processing switches.
    pooling_type: Optional[str] = None
    pooling_norm: Optional[bool] = None
    pooling_softmax: Optional[bool] = None
    # Token-id based options.
    pooling_step_tag_id: Optional[int] = None
    pooling_returned_token_ids: Optional[List[int]] = None


1706
1707
1708
1709
1710
1711
1712
1713
# Accepted dtype strings and the torch dtype each one resolves to. "half" and
# "float" are aliases of "float16" and "float32" respectively.
_STR_DTYPE_TO_TORCH_DTYPE = {
    # 16-bit floating point
    "half": torch.float16,
    "float16": torch.float16,
    # 32-bit floating point
    "float": torch.float32,
    "float32": torch.float32,
    # bfloat16
    "bfloat16": torch.bfloat16,
}

# NOTE(review): presumably dtype names that are rejected on ROCm; the list is
# currently empty - confirm against its users.
_ROCM_NOT_SUPPORTED_DTYPE: List[str] = []
1715

1716
1717
1718

def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested ``dtype`` to a ``torch.dtype``.

    Args:
        config: Model config; its ``torch_dtype`` attribute (float32 when
            missing or None) is the reference dtype.
        dtype: A ``torch.dtype``, one of the names in
            ``_STR_DTYPE_TO_TORCH_DTYPE`` (case-insensitive), or "auto".
            With "auto" the config dtype is used, except that float32 is
            downcast to float16 (bfloat16 when ``config.model_type`` is
            "gemma2").

    Returns:
        The resolved torch dtype. When it differs from the config dtype the
        cast is logged: info level if either side is float32, warning level
        otherwise.

    Raises:
        ValueError: If ``dtype`` is an unknown name or is neither a string
            nor a ``torch.dtype``.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    elif not isinstance(dtype, str):
        raise ValueError(f"Unknown dtype: {dtype}")
    else:
        dtype = dtype.lower()
        if dtype != "auto":
            # Explicit dtype name: must be one of the known strings.
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(dtype)
            if torch_dtype is None:
                raise ValueError(f"Unknown dtype: {dtype}")
        elif config_dtype != torch.float32:
            torch_dtype = config_dtype
        elif config.model_type == "gemma2":
            logger.info(
                "For Gemma 2, we downcast float32 to bfloat16 instead "
                "of float16 by default. Please specify `dtype` if you "
                "want to use float16.")
            torch_dtype = torch.bfloat16
        else:
            # Following the common practice, we use float16 for float32
            # models.
            torch_dtype = torch.float16

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
1767
1768
1769
1770
1771


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
    disable_sliding_window: bool,
    sliding_window_len: Optional[Union[int, List[Optional[int]]]],
    spec_target_max_model_len: Optional[int] = None,
) -> int:
    """Derive the model's maximum length from its config and validate it.

    The derived length is the smallest value found under a set of known
    config attributes, optionally capped by the sliding window and adjusted
    for RoPE scaling. A user-specified `max_model_len` is then checked
    against that derived length.

    Args:
        hf_config: Model config object; attributes are read via `getattr`.
        max_model_len: User-specified maximum length, or None to use the
            derived one.
        disable_sliding_window: Whether sliding window is manually disabled.
        sliding_window_len: Sliding window size(s) from the model config.
        spec_target_max_model_len: Maximum length of the target model, used
            as a fallback when the config provides no length (speculative
            draft models).

    Returns:
        The maximum model length to use.

    Raises:
        NotImplementedError: If sliding window is disabled for a model
            whose RoPE scaling type is not "su", "longrope" or "llama3",
            or for a model whose `model_max_length` permits the override.
        ValueError: If `max_model_len` exceeds the derived length and
            `VLLM_ALLOW_LONG_MAX_MODEL_LEN` is not set.
    """
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Command-R
        "model_max_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]

    # Track the smallest "max length" in the config together with the key
    # that supplied it (the key is only used for error reporting).
    derived_max_model_len = float("inf")
    max_len_key = None
    for candidate_key in possible_keys:
        candidate_len = getattr(hf_config, candidate_key, None)
        if candidate_len is None:
            continue
        if candidate_len < derived_max_model_len:
            max_len_key = candidate_key
            derived_max_model_len = candidate_len

    # With sliding window manually disabled, the usable length may not
    # exceed the (smallest) sliding window length in the model config.
    if disable_sliding_window and sliding_window_len is not None:
        sliding_window_len_min = get_min_sliding_window(sliding_window_len)
        if sliding_window_len_min < derived_max_model_len:
            max_len_key = "sliding_window"
            derived_max_model_len = sliding_window_len_min

    # None of the keys were present in the config: fall back, in order, to
    # the user value, the speculative target's value, then a default.
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        if spec_target_max_model_len is not None:
            # If this is a speculative draft model, we use the max model len
            # from the target model.
            return spec_target_max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            "%s. Assuming the model's maximum length is %d.", possible_keys,
            default_max_len)
        derived_max_model_len = default_max_len

    # No need to consider the "type" key because of patch_rope_scaling when
    # loading the HF config.
    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if (rope_scaling is not None and rope_scaling["rope_type"]
            not in ("su", "longrope", "llama3")):
        if disable_sliding_window:
            # TODO(robertgshaw): Find a model that supports rope_scaling
            # with sliding window to see if this case should be allowed.
            raise NotImplementedError(
                "Disabling sliding window is not supported for models "
                "with rope_scaling. Please raise an issue so we can "
                "investigate.")

        if rope_scaling["rope_type"] == "yarn":
            # YaRN scales from the original (pre-scaling) position count.
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]

        # NOTE: rope_type == "default" does not define factor
        # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
        derived_max_model_len *= rope_scaling.get("factor", 1.0)

    # Without a user-specified length, the derived one is the answer.
    if max_model_len is None:
        return int(derived_max_model_len)

    # A user-specified length must not exceed the derived length, unless
    # the config's model_max_length explicitly allows it.
    if max_model_len > derived_max_model_len:
        # Some models might have a separate key for specifying model_max_length
        # that will be bigger than derived_max_model_len. We compare user input
        # with model_max_length and allow this override when it's smaller.
        model_max_length = getattr(hf_config, "model_max_length", None)
        override_allowed = (model_max_length is not None
                            and max_model_len <= model_max_length)
        if override_allowed:
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that has model_max_length
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "model_max_length in the config. Please raise an issue "
                    "so we can investigate.")
        else:
            msg = (
                f"User-specified max_model_len ({max_model_len}) is greater "
                f"than the derived max_model_len ({max_len_key}="
                f"{derived_max_model_len} or model_max_length="
                f"{model_max_length} in model's config.json). This may lead "
                "to incorrect model outputs or CUDA errors.")
            if not envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
                raise ValueError(
                    f"{msg} To allow overriding this maximum, set "
                    "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
            logger.warning(
                "%s Make sure the value is correct and within the "
                "model context size.", msg)

    return int(max_model_len)
1890
1891


1892
1893
1894
1895
1896
1897
1898
1899
def get_min_sliding_window(
        sliding_window: Union[int, List[Optional[int]]]) -> int:
    """Return the smallest sliding window size.

    A non-list value is returned unchanged. For a list, `None` entries are
    ignored and the minimum of the remaining entries is returned.
    """
    if not isinstance(sliding_window, list):
        return sliding_window

    window_sizes = [size for size in sliding_window if size is not None]
    return min(window_sizes)


1900
1901
1902
def get_served_model_name(model: str,
                          served_model_name: Optional[Union[str, List[str]]]):
    """Pick the served model name, falling back to `model`.

    - A non-empty list yields its first entry.
    - A non-empty string is used directly.
    - An empty string, an empty list or None falls back to `model`.
    """
    if isinstance(served_model_name, list) and served_model_name:
        return served_model_name[0]
    # Anything falsy at this point (None, "", []) resolves to `model`.
    return served_model_name or model


1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
@dataclass
class DecodingConfig:
    """Dataclass which contains the decoding strategy of the engine.

    Raises:
        ValueError: On construction, if `guided_decoding_backend` is not one
            of the supported backends.
    """

    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
    guided_decoding_backend: str = 'outlines'

    def __post_init__(self):
        """Reject guided decoding backends that are not supported."""
        valid_guided_backends = ['outlines', 'lm-format-enforcer']
        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
            # The backend name is fully quoted and followed by ", " so the
            # two clauses of the message do not run together.
            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")


1931
1932
1933
1934
1935
@dataclass
class ObservabilityConfig:
    """Configuration for observability."""
    otlp_traces_endpoint: Optional[str] = None

    # Collecting detailed timing information for each request can be
    # expensive; both flags below default to False.

    # If set, collects the model forward time for the request.
    collect_model_forward_time: bool = False

    # If set, collects the model execute time for the request.
    collect_model_execute_time: bool = False

    def __post_init__(self):
        """Ensure a traces endpoint is only set when OpenTelemetry is usable.

        Raises:
            ValueError: If `otlp_traces_endpoint` is set while OpenTelemetry
                is reported as unavailable.
        """
        # Nothing to validate when OpenTelemetry is available or when no
        # endpoint was requested.
        if is_otel_available() or self.otlp_traces_endpoint is None:
            return

        raise ValueError(
            "OpenTelemetry is not available. Unable to configure "
            "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
            f"installed. Original error:\n{otel_import_error_traceback}")
1950
1951


1952
1953
1954
@dataclass
class VllmConfig:
    """Dataclass which contains all vllm-related configuration. This
1955
1956
1957
1958
1959
1960
1961
1962
    simplifies passing around the distinct configurations in the codebase.
    """

    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
1963
    load_config: LoadConfig
1964
1965
1966
1967
1968
    lora_config: Optional[LoRAConfig] = None
    speculative_config: Optional[SpeculativeConfig] = None
    decoding_config: Optional[DecodingConfig] = None
    observability_config: Optional[ObservabilityConfig] = None
    prompt_adapter_config: Optional[PromptAdapterConfig] = None
1969
1970
1971
1972

    def __post_init__(self):
        """Verify configs are valid & consistent with each other.
        """
1973
1974
1975
        self.model_config.verify_async_output_proc(self.parallel_config,
                                                   self.speculative_config,
                                                   self.device_config)
1976
1977
1978
1979
1980
1981
1982
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
1983
1984
1985
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)