config.py 89 KB
Newer Older
1
import enum
2
import json
3
from dataclasses import dataclass, field
4
5
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
                    Mapping, Optional, Set, Tuple, Type, Union)
6
7

import torch
8
from transformers import PretrainedConfig
9

10
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
13
from vllm.model_executor.models import ModelRegistry
14
from vllm.platforms import current_platform
15
from vllm.tracing import is_otel_available, otel_import_error_traceback
16
from vllm.transformers_utils.config import (ConfigFormat, get_config,
17
18
                                            get_hf_image_processor_config,
                                            get_hf_text_config)
19
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
20
                        print_warning_once)
21

22
23
24
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

25
    from vllm.executor.executor_base import ExecutorBase
26
27
    from vllm.model_executor.layers.quantization.base_config import (
        QuantizationConfig)
28
    from vllm.model_executor.model_loader.loader import BaseModelLoader
29
30
    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
        BaseTokenizerGroup)
31
32
else:
    QuantizationConfig = None
33

34
35
logger = init_logger(__name__)

36
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
37
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
38

39
40
41
42
TaskOption = Literal["auto", "generate", "embedding"]

# "draft" is only used internally for speculative decoding
_Task = Literal["generate", "embedding", "draft"]
43

44
45

class ModelConfig:
46
47
48
49
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
            It is also used as the content for `model_name` tag in metrics
            output when `served_model_name` is not specified.
        task: The task to use the model for. Each vLLM instance only supports
            one task, even if the same model can be used for multiple tasks.
            When the model only supports one task, "auto" can be used to select
            it; otherwise, you must specify explicitly which task to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, "slow" will always use the slow tokenizer, and
            "mistral" will always use the tokenizer from `mistral_common`.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allowing API requests to read local images or
            videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        rope_scaling: Dictionary containing the scaling configuration for the
            RoPE embeddings. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
            be used to load activation and weight scaling factors when the
            model dtype is FP8_E4M3 on ROCm.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If None, the user did not specify, so default to False.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_sliding_window: Whether to disable sliding window. If True,
            we will disable the sliding window functionality of the model.
            If the model does not support sliding window, this argument is
            ignored.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer.
        served_model_name: The model name used in metrics tag `model_name`,
            matches the model name exposed via the APIs. If multiple model
            names provided, the first name will be used. If not specified,
            the model name will be the same as `model`.
        limit_mm_per_prompt: Maximum number of data instances per modality
            per prompt. Only applicable for multimodal models.
        override_neuron_config: Initialize non default neuron config or
            override default neuron config that are specific to Neuron devices,
            this argument will be used to configure the neuron config that
            can not be gathered from the vllm arguments.
        config_format: The config format which shall be loaded.
            Defaults to 'auto' which defaults to 'hf'.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor.
        pooling_type: Used to configure the pooling method in the embedding
            model.
        pooling_norm: Used to determine whether to normalize the pooled
            data in the embedding model.
        pooling_softmax: Used to determine whether to softmax the pooled
            data in the embedding model.
        pooling_step_tag_id: When pooling_step_tag_id is not -1, it indicates
            that the score corresponding to the pooling_step_tag_id in the
            generated sentence should be returned. Otherwise, it returns
            the scores for all tokens.
        pooling_returned_token_ids: pooling_returned_token_ids represents a
            list of indices for the vocabulary dimensions to be extracted,
            such as the token IDs of good_token and bad_token in the
            math-shepherd-mistral-7b-prm model.
    """
135

136
137
138
139
140
141
142
143
144
    def __init__(
            self,
            model: str,
            task: Union[TaskOption, _Task],
            tokenizer: str,
            tokenizer_mode: str,
            trust_remote_code: bool,
            dtype: Union[str, torch.dtype],
            seed: int,
            allowed_local_media_path: str = "",
            revision: Optional[str] = None,
            code_revision: Optional[str] = None,
            rope_scaling: Optional[dict] = None,
            rope_theta: Optional[float] = None,
            tokenizer_revision: Optional[str] = None,
            max_model_len: Optional[int] = None,
            spec_target_max_model_len: Optional[int] = None,
            quantization: Optional[str] = None,
            quantization_param_path: Optional[str] = None,
            enforce_eager: Optional[bool] = None,
            max_seq_len_to_capture: Optional[int] = None,
            max_logprobs: int = 20,
            disable_sliding_window: bool = False,
            skip_tokenizer_init: bool = False,
            served_model_name: Optional[Union[str, List[str]]] = None,
            limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
            use_async_output_proc: bool = True,
            override_neuron_config: Optional[Dict[str, Any]] = None,
            config_format: ConfigFormat = ConfigFormat.AUTO,
            chat_template_text_format: str = "string",
            mm_processor_kwargs: Optional[Dict[str, Any]] = None,
            pooling_type: Optional[str] = None,
            pooling_norm: Optional[bool] = None,
            pooling_softmax: Optional[bool] = None,
            pooling_step_tag_id: Optional[int] = None,
            pooling_returned_token_ids: Optional[List[int]] = None) -> None:
        """Store the given options, load the HF config and validate the result.

        The arguments are described in the class docstring. Besides storing
        them, the constructor loads the HuggingFace config for `model`,
        derives `dtype`, `max_model_len`, the served model name, the
        multimodal and pooler configs and the task, and finally runs the
        quantization, CUDA-graph and bitsandbytes checks. The order of the
        steps below matters: several of them read attributes set earlier.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.allowed_local_media_path = allowed_local_media_path
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        # The tokenizer version is consistent with the model version by default.
        if tokenizer_revision is None:
            self.tokenizer_revision = revision
        else:
            self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        self.max_seq_len_to_capture = max_seq_len_to_capture
        self.max_logprobs = max_logprobs
        self.disable_sliding_window = disable_sliding_window
        self.skip_tokenizer_init = skip_tokenizer_init

        # Load the HF config objects; most derived values below read them.
        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision, rope_scaling, rope_theta,
                                    config_format)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.hf_image_processor_config = get_hf_image_processor_config(
            self.model, revision)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.use_async_output_proc = use_async_output_proc
        self.chat_template_text_format = chat_template_text_format
        self.mm_processor_kwargs = mm_processor_kwargs

        # Set enforce_eager to False if the value is unset.
        if self.enforce_eager is None:
            self.enforce_eager = False

        # A model counts as having interleaved attention when it defines a
        # sliding window that is either a list or belongs to a "gemma2" model.
        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
        has_interleaved_attention = (sliding_window is not None) and (
            isinstance(sliding_window, list) or
            (self.hf_text_config.model_type in ["gemma2"]))

        if (not self.disable_sliding_window and has_interleaved_attention):
            sliding_window_len_min = get_min_sliding_window(
                self.hf_text_config.sliding_window)

            print_warning_once(
                f"{self.hf_text_config.model_type} has interleaved attention, "
                "which is currently not supported by vLLM. Disabling sliding "
                "window and capping the max length to the sliding window size "
                f"({sliding_window_len_min}).")
            self.disable_sliding_window = True

        # Must run after the block above: it receives the (possibly just
        # forced) `disable_sliding_window` flag.
        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.hf_text_config,
            max_model_len=max_model_len,
            disable_sliding_window=self.disable_sliding_window,
            sliding_window_len=self.get_hf_config_sliding_window(),
            spec_target_max_model_len=spec_target_max_model_len)
        self.served_model_name = get_served_model_name(model,
                                                       served_model_name)
        self.multimodal_config = self._init_multimodal_config(
            limit_mm_per_prompt)
        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()

        self.is_attention_free = self._init_attention_free()
        self.has_inner_state = self._init_has_inner_state()

        # The neuron override is only kept when running on a Neuron platform.
        if current_platform.is_neuron():
            self.override_neuron_config = override_neuron_config
        else:
            self.override_neuron_config = None

        supported_tasks, task = self._resolve_task(task, self.hf_config)
        self.supported_tasks = supported_tasks
        self.task: Final = task

        # Reads `self.task`, so it has to come after the task is resolved.
        self.pooler_config = self._init_pooler_config(
            pooling_type,
            pooling_norm,
            pooling_softmax,
            pooling_step_tag_id,
            pooling_returned_token_ids,
        )

        # Final validation passes; these may adjust `quantization`,
        # `max_seq_len_to_capture` and `enforce_eager`.
        self._verify_quantization()
        self._verify_cuda_graph()
        self._verify_bnb_config()
261

262
263
264
265
    def _init_multimodal_config(
        self, limit_mm_per_prompt: Optional[Mapping[str, int]]
    ) -> Optional["MultiModalConfig"]:
        """Build the multimodal config, or return None for other models.

        Args:
            limit_mm_per_prompt: Per-modality limits supplied by the user.

        Returns:
            A `MultiModalConfig` when the model registry reports the
            architectures as multimodal, otherwise None.

        Raises:
            ValueError: If limits were given for a non-multimodal model.
        """
        archs = getattr(self.hf_config, "architectures", [])

        if not ModelRegistry.is_multimodal_model(archs):
            # Limits make no sense without multimodal support.
            if limit_mm_per_prompt:
                raise ValueError("`limit_mm_per_prompt` is only supported for "
                                 "multimodal models.")
            return None

        limits = limit_mm_per_prompt if limit_mm_per_prompt else {}
        return MultiModalConfig(limit_per_prompt=limits)
274

275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
    def _init_pooler_config(
        self,
        pooling_type: Optional[str] = None,
        pooling_norm: Optional[bool] = None,
        pooling_softmax: Optional[bool] = None,
        pooling_step_tag_id: Optional[int] = None,
        pooling_returned_token_ids: Optional[List[int]] = None
    ) -> Optional["PoolerConfig"]:
        if self.task == "embedding":
            return PoolerConfig(
                pooling_type=pooling_type,
                pooling_norm=pooling_norm,
                pooling_softmax=pooling_softmax,
                pooling_step_tag_id=pooling_step_tag_id,
                pooling_returned_token_ids=pooling_returned_token_ids)
        return None

292
293
294
295
296
297
298
299
    def _init_attention_free(self) -> bool:
        """Ask the model registry whether the architectures are attention-free."""
        return ModelRegistry.is_attention_free_model(
            getattr(self.hf_config, "architectures", []))

    def _init_has_inner_state(self) -> bool:
        """Ask the model registry whether the architectures keep inner state."""
        return ModelRegistry.model_has_inner_state(
            getattr(self.hf_config, "architectures", []))

300
301
    def _verify_tokenizer_mode(self) -> None:
        tokenizer_mode = self.tokenizer_mode.lower()
302
        if tokenizer_mode not in ["auto", "slow", "mistral"]:
303
304
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
305
                "either 'auto', 'slow' or 'mistral'.")
306
        self.tokenizer_mode = tokenizer_mode
307

308
309
    def _resolve_task(
        self,
310
        task_option: Union[TaskOption, _Task],
311
        hf_config: PretrainedConfig,
312
313
314
315
    ) -> Tuple[Set[_Task], _Task]:
        if task_option == "draft":
            return {"draft"}, "draft"

316
317
        architectures = getattr(hf_config, "architectures", [])

318
        task_support: Dict[_Task, bool] = {
319
320
321
322
323
            # NOTE: Listed from highest to lowest priority,
            # in case the model supports multiple of them
            "generate": ModelRegistry.is_text_generation_model(architectures),
            "embedding": ModelRegistry.is_embedding_model(architectures),
        }
324
        supported_tasks_lst: List[_Task] = [
325
326
327
328
329
330
            task for task, is_supported in task_support.items() if is_supported
        ]
        supported_tasks = set(supported_tasks_lst)

        if task_option == "auto":
            selected_task = next(iter(supported_tasks_lst))
331

332
333
334
335
            if len(supported_tasks) > 1:
                logger.info(
                    "This model supports multiple tasks: %s. "
                    "Defaulting to '%s'.", supported_tasks, selected_task)
336
        else:
337
338
339
340
341
342
343
            if task_option not in supported_tasks:
                msg = (
                    f"This model does not support the '{task_option}' task. "
                    f"Supported tasks: {supported_tasks}")
                raise ValueError(msg)

            selected_task = task_option
344

345
        return supported_tasks, selected_task
346

347
348
349
    def _parse_quant_hf_config(self):
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
350
            # compressed-tensors uses a "compression_config" key
351
            quant_cfg = getattr(self.hf_config, "compression_config", None)
352
353
        return quant_cfg

354
    def _verify_quantization(self) -> None:
        """Resolve the quantization method and check platform support.

        The user-supplied method is lower-cased and reconciled with the
        method found in the HF config (any registered method may override
        it). The resulting method is then checked against the lists of
        methods known in general and supported on ROCm, TPU and Neuron.

        Raises:
            ValueError: If the user's method conflicts with the HF config,
                is unknown, or is unsupported on the current platform.
        """
        all_methods = list(QUANTIZATION_METHODS)
        rocm_methods = ("awq", "gptq", "fp8", "compressed_tensors",
                        "compressed-tensors", "fbgemm_fp8")
        optimized_methods = ("fp8", "marlin", "modelopt", "gptq_marlin_24",
                             "gptq_marlin", "awq_marlin", "fbgemm_fp8",
                             "compressed_tensors", "compressed-tensors",
                             "experts_int8")
        tpu_methods = ("tpu_int8", )
        neuron_methods = ("neuron_quant", )

        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        hf_quant_cfg = self._parse_quant_hf_config()
        if hf_quant_cfg is not None:
            detected = hf_quant_cfg.get("quant_method", "").lower()

            # Detect which checkpoint is it: the first method that reports
            # an override wins.
            for method_cls in QUANTIZATION_METHODS.values():
                override = method_cls.override_quantization_method(
                    hf_quant_cfg, self.quantization)
                if override:
                    detected = override
                    self.quantization = override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = detected
            elif self.quantization != detected:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({detected}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is None:
            # Unquantized model: nothing left to validate.
            return

        if self.quantization not in all_methods:
            raise ValueError(
                f"Unknown quantization method: {self.quantization}. Must "
                f"be one of {all_methods}.")

        if (current_platform.is_rocm()
                and self.quantization not in rocm_methods):
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in ROCm.")

        if (current_platform.is_tpu()
                and self.quantization not in tpu_methods):
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in TPU Backend.")

        if self.quantization not in optimized_methods:
            logger.warning(
                "%s quantization is not fully "
                "optimized yet. The speed can be slower than "
                "non-quantized models.", self.quantization)

        # AWQ on ROCm turns on the VLLM_USE_TRITON_AWQ switch when it is
        # not already set.
        if (self.quantization == "awq" and current_platform.is_rocm()
                and not envs.VLLM_USE_TRITON_AWQ):
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ.")
            envs.VLLM_USE_TRITON_AWQ = True

        if (current_platform.is_neuron()
                and self.quantization not in neuron_methods):
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in Neuron Backend.")
426

427
    def _verify_cuda_graph(self) -> None:
428
429
430
431
        if self.max_seq_len_to_capture is None:
            self.max_seq_len_to_capture = self.max_model_len
        self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                          self.max_model_len)
432

433
434
    def _verify_bnb_config(self) -> None:
        """
435
        The current version of bitsandbytes (0.44.0) with 8-bit models does not
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
        yet support CUDA graph.
        """
        is_bitsandbytes = self.quantization == "bitsandbytes"
        has_quantization_config = (getattr(self.hf_config,
                                           "quantization_config", None)
                                   is not None)
        is_8bit = (self.hf_config.quantization_config.get(
            "load_in_8bit", False) if has_quantization_config else False)
        if all([
                is_bitsandbytes,
                has_quantization_config,
                is_8bit,
                not self.enforce_eager,
        ]):
            logger.warning(
                "CUDA graph is not supported on BitAndBytes 8bit yet, "
                "fallback to the eager mode.")
            self.enforce_eager = True

455
456
457
458
459
460
461
462
463
464
465
466
    def verify_async_output_proc(self, parallel_config, speculative_config,
                                 device_config) -> None:
        if not self.use_async_output_proc:
            # Nothing to check
            return

        if parallel_config.pipeline_parallel_size > 1:
            logger.warning("Async output processing can not be enabled "
                           "with pipeline parallel")
            self.use_async_output_proc = False
            return

467
468
        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
469
        if device_config.device_type not in ("cuda", "tpu", "xpu"):
470
            logger.warning(
471
                "Async output processing is only supported for CUDA, TPU, XPU. "
472
                "Disabling it for other platforms.")
473
474
475
476
477
478
479
480
481
            self.use_async_output_proc = False
            return

        if envs.VLLM_USE_RAY_SPMD_WORKER:
            logger.warning(
                "Async output processing can not be enabled with ray spmd")
            self.use_async_output_proc = False
            return

482
483
        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
484
        if device_config.device_type == "cuda" and self.enforce_eager:
485
486
487
488
489
490
491
492
493
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            self.use_async_output_proc = not self.enforce_eager
            return

        # Async postprocessor is not necessary with embedding mode
        # since there is no token generation
494
        if self.task == "embedding":
495
496
            self.use_async_output_proc = False

497
498
        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
499
500
501
502
503
        if speculative_config:
            logger.warning("Async output processing is not supported with"
                           " speculative decoding currently.")
            self.use_async_output_proc = False

504
505
506
507
    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
508
509
        total_num_attention_heads = getattr(self.hf_text_config,
                                            "num_attention_heads", 0)
510
511
512
513
514
515
516
517
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        pipeline_parallel_size = parallel_config.pipeline_parallel_size
518
519
520
521
522
523
524
525
526
527
528
        if pipeline_parallel_size > 1:
            architectures = getattr(self.hf_config, "architectures", [])
            if not ModelRegistry.is_pp_supported_model(architectures):
                raise NotImplementedError(
                    "Pipeline parallelism is not supported for this model. "
                    "Supported models implement the `SupportsPP` interface.")

            if self.use_async_output_proc:
                logger.warning("Async output processor is not supported with "
                               "pipeline parallelism currently. Disabling it.")
                self.use_async_output_proc = False
529

530
531
    def get_hf_config_sliding_window(
            self) -> Union[Optional[int], List[Optional[int]]]:
Woosuk Kwon's avatar
Woosuk Kwon committed
532
        """Get the sliding window size, or None if disabled."""
533
534
535
536

        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
        # addition to sliding window size. We check if that field is present
        # and if it's False, return None.
537
538
        if (hasattr(self.hf_text_config, "use_sliding_window")
                and not self.hf_text_config.use_sliding_window):
539
            return None
540
        return getattr(self.hf_text_config, "sliding_window", None)
541

542
    def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
543
544
545
546
547
548
549
550
        """Get the sliding window size, or None if disabled.
        """
        # If user disables sliding window, return None.
        if self.disable_sliding_window:
            return None
        # Otherwise get the value from the hf config.
        return self.get_hf_config_sliding_window()

551
    def get_vocab_size(self) -> int:
552
        return self.hf_text_config.vocab_size
553

554
    def get_hidden_size(self) -> int:
555
        return self.hf_text_config.hidden_size
556
557

    def get_head_size(self) -> int:
wangding zeng's avatar
wangding zeng committed
558
559
560
561
562
563
        # TODO remove hard code
        if hasattr(self.hf_text_config, "model_type"
                   ) and self.hf_text_config.model_type == 'deepseek_v2':
            # FlashAttention supports only head_size 32, 64, 128, 256,
            # we need to pad head_size 192 to 256
            return 256
564
565
566
567

        if self.is_attention_free:
            return 0

568
569
        if hasattr(self.hf_text_config, "head_dim"):
            return self.hf_text_config.head_dim
570
        # FIXME(woosuk): This may not be true for all models.
571
572
        return (self.hf_text_config.hidden_size //
                self.hf_text_config.num_attention_heads)
573

574
575
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
Zhuohan Li's avatar
Zhuohan Li committed
576
        # For GPTBigCode & Falcon:
577
        # NOTE: for falcon, when new_decoder_architecture is True, the
Zhuohan Li's avatar
Zhuohan Li committed
578
579
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
580
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
581
        new_decoder_arch_falcon = (
582
            self.hf_config.model_type in falcon_model_types
583
            and getattr(self.hf_config, "new_decoder_architecture", False))
584
        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
585
                                                   "multi_query", False):
Zhuohan Li's avatar
Zhuohan Li committed
586
            # Multi-query attention, only one KV head.
Woosuk Kwon's avatar
Woosuk Kwon committed
587
            # Currently, tensor parallelism is not supported in this case.
Zhuohan Li's avatar
Zhuohan Li committed
588
            return 1
589

590
        # For DBRX and MPT
591
592
593
594
595
        if self.hf_config.model_type == "mpt":
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type == "dbrx":
596
597
598
            return getattr(self.hf_config.attn_config, "kv_n_heads",
                           self.hf_config.num_attention_heads)

599
600
601
        if self.is_attention_free:
            return 0

602
603
604
605
606
607
608
609
610
611
        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
612
            num_kv_heads = getattr(self.hf_text_config, attr, None)
613
614
615
616
617
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
618
        return self.hf_text_config.num_attention_heads
619
620
621
622
623
624
625
626
627
628

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)
629

630
631
    def get_num_attention_heads(self,
                                parallel_config: "ParallelConfig") -> int:
632
633
        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
        return num_heads // parallel_config.tensor_parallel_size
634

635
    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return how many hidden layers fall into this pipeline stage.

        The stage index is `rank // tensor_parallel_size`; `get_pp_indices`
        supplies the `[start, end)` layer range for that stage. A text
        config without `num_hidden_layers` is treated as having 0 layers.
        """
        from vllm.distributed.utils import get_pp_indices

        layer_total = getattr(self.hf_text_config, "num_hidden_layers", 0)
        stage = parallel_config.rank // parallel_config.tensor_parallel_size
        first, last = get_pp_indices(layer_total, stage,
                                     parallel_config.pipeline_parallel_size)
        return last - first
643

644
645
646
647
    def get_num_attention_layers(self,
                                 parallel_config: "ParallelConfig") -> int:
        if self.is_attention_free:
            return 0
Mor Zusman's avatar
Mor Zusman committed
648
649
650

        num_layers = self.get_num_layers(parallel_config)

651
652
653
654
        # Transformers supports layers_block_type @property
        layers = getattr(self.hf_config, "layers_block_type",
                         ["attention"] * num_layers)
        return len([t for t in layers if t == "attention"])
Mor Zusman's avatar
Mor Zusman committed
655

656
657
658
659
660
661
662
663
664
665
666
667
    def get_multimodal_config(self) -> "MultiModalConfig":
        """
        Get the multimodal configuration of the model.

        Raises:
            ValueError: If the model is not multimodal.
        """
        if self.multimodal_config is None:
            raise ValueError("The model is not multimodal.")

        return self.multimodal_config

668
669
670
    @property
    def is_encoder_decoder_model(self) -> bool:
        """Extract the HF encoder/decoder model flag."""
671
672
673
        return getattr(self.hf_config, "is_encoder_decoder", False) or (
            (hasattr(self.hf_config, "text_config") and getattr(
                self.hf_config.text_config, "is_encoder_decoder", False)))
674

675
676
677
678
    @property
    def is_multimodal_model(self) -> bool:
        return self.multimodal_config is not None

679
680

class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution. Values above 1.0 are rejected.
        swap_space: Size of the CPU swap space per GPU (in GiB). Stored on
            the instance in bytes as ``swap_space_bytes``.
        cache_dtype: Data type for kv cache storage. Accepted values are
            "auto", "fp8", "fp8_e4m3" and "fp8_e5m2".
        is_attention_free: Recorded on the config as-is; this class does
            not act on it.
        num_gpu_blocks_override: Number of GPU blocks to use. This overrides
            the profiled num_gpu_blocks if specified. Does nothing if None.
        sliding_window: Sliding window size, or None. Prefix caching is
            rejected while this is set.
        enable_prefix_caching: Whether prefix caching is requested.
        cpu_offload_gb: Recorded on the config as-is (presumably a size in
            GiB, going by the name).
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: float,
        cache_dtype: str,
        is_attention_free: bool = False,
        num_gpu_blocks_override: Optional[int] = None,
        sliding_window: Optional[int] = None,
        enable_prefix_caching: bool = False,
        cpu_offload_gb: float = 0,
    ) -> None:
        # Memory sizing.
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * GiB_bytes
        self.num_gpu_blocks_override = num_gpu_blocks_override

        # Cache behaviour.
        self.cache_dtype = cache_dtype
        self.is_attention_free = is_attention_free
        self.sliding_window = sliding_window
        self.enable_prefix_caching = enable_prefix_caching
        self.cpu_offload_gb = cpu_offload_gb

        # Validate in a fixed order so the first failing check is stable.
        self._verify_args()
        self._verify_cache_dtype()
        self._verify_prefix_caching()

        # Will be set after profiling.
        self.num_gpu_blocks: Optional[int] = None
        self.num_cpu_blocks: Optional[int] = None

    def metrics_info(self):
        """Return every instance attribute as a ``str -> str`` mapping.

        The values are stringified so the mapping can be used as prometheus
        metrics info.
        """
        info = {}
        for name, value in vars(self).items():
            info[name] = str(value)
        return info

    def _verify_args(self) -> None:
        """Reject a GPU memory utilization above 1.0."""
        utilization = self.gpu_memory_utilization
        if utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        """Validate ``cache_dtype`` and log a notice for fp8 variants."""
        dtype = self.cache_dtype
        if dtype == "auto":
            return

        if dtype not in ("fp8", "fp8_e4m3", "fp8_e5m2"):
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

        logger.info(
            "Using fp8 data type to store kv cache. It reduces the GPU "
            "memory footprint and boosts the performance. "
            "Meanwhile, it may cause accuracy drop without a proper "
            "scaling factor")

    def _verify_prefix_caching(self) -> None:
        """Reject prefix caching when a sliding window is configured."""
        if self.enable_prefix_caching and self.sliding_window is not None:
            raise NotImplementedError(
                "Prefix caching is not supported with sliding window. "
                "Run with --disable-sliding-window to use prefix caching.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check the requested swap space against the total CPU memory.

        Raises ``ValueError`` when the swap space of all GPUs assumed to be
        on this node exceeds 70% of CPU memory, and logs a warning when it
        exceeds 40%.
        """
        host_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        gpus_on_node = parallel_config.tensor_parallel_size
        requested = self.swap_space_bytes * gpus_on_node

        requested_gib = requested / GiB_bytes
        host_gib = host_memory / GiB_bytes
        msg = (f"{requested_gib:.2f} GiB out of the "
               f"{host_gib:.2f} GiB total CPU memory "
               "is allocated for the swap space.")

        if requested > 0.7 * host_memory:
            raise ValueError("Too large swap space. " + msg)
        if requested > 0.4 * host_memory:
            logger.warning("Possibly too large swap space. %s", msg)
772

773

774
775
776
@dataclass
class TokenizerPoolConfig:
    """Configuration for the tokenizer pool.

    Args:
        pool_size: Number of tokenizer workers in the pool.
        pool_type: Type of the pool. Either the string "ray" or a class.
        extra_config: Additional config for the pool.
            The way the config will be used depends on the
            pool type. Must be a dictionary.
    """
    pool_size: int
    pool_type: Union[str, Type["BaseTokenizerGroup"]]
    extra_config: dict

    def __post_init__(self):
        """Validate the pool type and the extra config container."""
        named_pool = self.pool_type in ("ray", )
        class_pool = isinstance(self.pool_type, type)
        if not (named_pool or class_pool):
            raise ValueError(f"Unknown pool type: {self.pool_type}")

        if not isinstance(self.extra_config, dict):
            raise ValueError("extra_config must be a dictionary.")

    @classmethod
    def create_config(
        cls, tokenizer_pool_size: int,
        tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]],
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Build a TokenizerPoolConfig, or return None for an empty pool.

        A falsy ``tokenizer_pool_size`` (e.g. 0) means no pool is wanted and
        None is returned.

        Args:
            tokenizer_pool_size: Number of tokenizer workers in the pool.
            tokenizer_pool_type: Type of the pool.
            tokenizer_pool_extra_config: Additional config for the pool.
                The way the config will be used depends on the
                pool type. A string is parsed as JSON; a falsy non-string
                value becomes an empty dict.
        """
        if not tokenizer_pool_size:
            return None

        extra = tokenizer_pool_extra_config
        if isinstance(extra, str):
            parsed_extra = json.loads(extra)
        else:
            parsed_extra = extra or {}

        return cls(tokenizer_pool_size, tokenizer_pool_type, parsed_extra)


828
829
830
831
832
833
834
class LoadFormat(str, enum.Enum):
    """String-valued enumeration of the accepted model weight load formats.

    Because members are also ``str`` instances, a member compares equal to
    its lowercase string value.
    """
    # Try safetensors first, fall back to the pytorch bin format.
    AUTO = "auto"
    # Pytorch bin format.
    PT = "pt"
    # Safetensors format.
    SAFETENSORS = "safetensors"
    # Pytorch format plus a numpy cache to speed up loading.
    NPCACHE = "npcache"
    # Randomly initialized weights, mainly for profiling.
    DUMMY = "dummy"
    # CoreWeave's tensorizer library for fast weight loading.
    TENSORIZER = "tensorizer"
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    # Loads nf4 type weights.
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857


@dataclass
class LoadConfig:
    """Configuration for loading the model weights.

    Args:
        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
            "tensorizer" will use CoreWeave's tensorizer library for
                fast weight loading.
            "bitsandbytes" will load nf4 type weights.
            A string is matched case-insensitively and converted to a
            ``LoadFormat`` member; a non-string value is left untouched.
        model_loader_extra_config: Extra options for the model loader, as a
            dict or a JSON string. A JSON string is parsed, and a missing
            value (None or an empty string) becomes an empty dict.
        ignore_patterns: The list of patterns to ignore when loading the model.
            Default to "original/**/*" to avoid repeated loading of llama's
            checkpoints.

    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(
        default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        """Normalize the extra config, load format and ignore patterns."""
        extra_config = self.model_loader_extra_config
        if extra_config is None or extra_config == "":
            # Store the fallback on the instance; previously it was computed
            # into a local and dropped, leaving None behind.
            extra_config = {}
        elif isinstance(extra_config, str):
            extra_config = json.loads(extra_config)
        self.model_loader_extra_config = extra_config

        self._verify_load_format()

        if self.ignore_patterns:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns)
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Convert a string load format to ``LoadFormat`` and check ROCm.

        Raises:
            ValueError: If the string is not a valid ``LoadFormat`` value, or
                the format is listed as unsupported on ROCm.
        """
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        # Lowercase format values that cannot be used on ROCm.
        rocm_not_supported_load_format: List[str] = []
        if current_platform.is_rocm(
        ) and load_format in rocm_not_supported_load_format:
            # Compare by value: member names are upper-case and would never
            # match the lower-case entries of the unsupported list.
            rocm_supported_load_format = [
                fmt.value for fmt in LoadFormat
                if fmt.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")


905
class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Deprecated, use distributed_executor_backend instead.
        max_parallel_loading_workers: Maximum number of multiple batches
            when load model sequentially. To avoid RAM OOM when using tensor
            parallel and large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL.
        tokenizer_pool_config: Config for the tokenizer pool.
            If None, will use synchronous tokenization.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
        placement_group: ray distributed model workers placement group.
        distributed_executor_backend: Backend to use for distributed model
            workers, either "ray" or "mp" (multiprocessing). If either
            pipeline_parallel_size or tensor_parallel_size is greater than 1,
            will default to "ray" if Ray is installed or "mp" otherwise.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: Optional[bool] = None,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
        ray_workers_use_nsight: bool = False,
        placement_group: Optional["PlacementGroup"] = None,
        distributed_executor_backend: Optional[Union[
            str, Type["ExecutorBase"]]] = None,
    ) -> None:
        # Topology.
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.distributed_executor_backend = distributed_executor_backend

        # Worker behaviour.
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.tokenizer_pool_config = tokenizer_pool_config
        self.ray_workers_use_nsight = ray_workers_use_nsight
        self.placement_group = placement_group
        self.world_size = pipeline_parallel_size * self.tensor_parallel_size

        # The deprecated flag either selects Ray or must agree with the
        # explicitly requested backend.
        if worker_use_ray:
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "ray"
            elif not self.use_ray:
                raise ValueError(f"worker-use-ray can't be used with "
                                 f"distributed executor backend "
                                 f"'{self.distributed_executor_backend}'.")

        # Multi-worker TPU runs are Ray-only.
        if current_platform.is_tpu() and self.world_size > 1:
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "ray"
            if self.distributed_executor_backend != "ray":
                raise ValueError(
                    "TPU backend only supports Ray for distributed inference.")

        if self.distributed_executor_backend is None and self.world_size > 1:
            # We use multiprocessing by default if world_size fits on the
            # current node and we aren't in a ray placement group.

            from vllm.executor import ray_utils
            chosen_backend = "mp"
            ray_found = ray_utils.ray_is_available()
            if (current_platform.is_cuda()
                    and cuda_device_count_stateless() < self.world_size):
                # Not enough local GPUs: Ray is mandatory.
                if not ray_found:
                    raise ValueError("Unable to load Ray which is "
                                     "required for multi-node inference, "
                                     "please install Ray with `pip install "
                                     "ray`.") from ray_utils.ray_import_err
                chosen_backend = "ray"
            elif ray_found:
                inside_placement_group = bool(self.placement_group)
                if not inside_placement_group:
                    from ray import is_initialized as ray_is_initialized
                    if ray_is_initialized():
                        from ray.util import get_current_placement_group
                        inside_placement_group = bool(
                            get_current_placement_group())
                if inside_placement_group:
                    chosen_backend = "ray"

            self.distributed_executor_backend = chosen_backend
            logger.info("Defaulting to use %s for distributed inference",
                        chosen_backend)

        self._verify_args()
        self.rank: int = 0

    @property
    def use_ray(self) -> bool:
        """Whether the selected backend is Ray or a class that uses Ray."""
        backend = self.distributed_executor_backend
        if backend == "ray":
            return True
        return isinstance(backend, type) and backend.uses_ray

    def _verify_args(self) -> None:
        """Validate the backend choice and platform-dependent options."""
        # Lazy import to avoid circular import
        from vllm.executor.executor_base import ExecutorBase

        backend = self.distributed_executor_backend
        is_builtin_backend = backend in ("ray", "mp", None)
        is_custom_backend = (isinstance(backend, type)
                             and issubclass(backend, ExecutorBase))
        if not (is_builtin_backend or is_custom_backend):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
                "values are 'ray', 'mp' or custom ExecutorBase subclass.")

        if self.use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        if current_platform.is_rocm():
            self.disable_custom_all_reduce = True
            logger.info(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on AMD GPUs.")

        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")
1026

1027
1028

class SchedulerConfig:
    """Scheduler configuration.

    Args:
        task: The task to use the model for.
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration. If None, a default is derived from the other
            arguments.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        num_lookahead_slots: The number of slots to allocate per sequence per
            step, beyond the known token ids. This is used in speculative
            decoding to store KV activations of tokens which may or may not be
            accepted.
        delay_factor: Apply a delay (of delay factor multiplied by previous
            prompt latency) before scheduling next prompt.
        enable_chunked_prefill: If True, prefill requests can be chunked based
            on the remaining max_num_batched_tokens.
        is_multimodal_model: If True, the derived default for
            max_num_batched_tokens is raised to at least the multimodal
            minimum.
        preemption_mode: Whether to perform preemption by swapping or
            recomputation. If not specified, we determine the mode as follows:
            We use recomputation by default since it incurs lower overhead than
            swapping. However, when the sequence group has multiple sequences
            (e.g., beam search), recomputation is not currently supported. In
            such a case, we use swapping instead.
        num_scheduler_steps: Must be at least 1; a value above 1 makes
            ``is_multi_step`` True.
        multi_step_stream_outputs: Stored on the config as-is; this class
            does not act on it.
        send_delta_data: Private API. If used, scheduler sends delta data to
            workers instead of an entire data. It should be enabled only
            when SPMD worker architecture is enabled. I.e.,
            VLLM_USE_RAY_SPMD_WORKER=1
        policy: The scheduling policy to use. "fcfs" (default) or "priority".
    """

    def __init__(self,
                 task: _Task,
                 max_num_batched_tokens: Optional[int],
                 max_num_seqs: int,
                 max_model_len: int,
                 num_lookahead_slots: int = 0,
                 delay_factor: float = 0.0,
                 enable_chunked_prefill: bool = False,
                 is_multimodal_model: bool = False,
                 preemption_mode: Optional[str] = None,
                 num_scheduler_steps: int = 1,
                 multi_step_stream_outputs: bool = False,
                 send_delta_data: bool = False,
                 policy: str = "fcfs") -> None:
        if max_num_batched_tokens is None:
            if enable_chunked_prefill and num_scheduler_steps <= 1:
                # It is the values that have the best balance between ITL
                # and TTFT on A100. Note it is not optimized for throughput.
                max_num_batched_tokens = 512
            else:
                # Without chunked prefill: if max_model_len is too short, use
                # 2048 as the default value for higher throughput.
                # With multi-step chunked prefill: prompt-chunking isn't
                # allowed for now, so use max_model_len to avoid rejecting
                # sequences on account of a short max_num_batched_tokens.
                max_num_batched_tokens = max(max_model_len, 2048)

            # Task/model specific lower bounds on the derived default.
            lower_bounds = [max_num_batched_tokens]
            if task == "embedding":
                # For embedding, choose specific value for higher throughput
                lower_bounds.append(_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
            if is_multimodal_model:
                # The value needs to be at least the number of multimodal tokens
                lower_bounds.append(_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS)
            max_num_batched_tokens = max(lower_bounds)

        self.max_num_batched_tokens = max_num_batched_tokens

        if enable_chunked_prefill:
            logger.info(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                self.max_num_batched_tokens)

        self.task: Final = task
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
        self.chunked_prefill_enabled = enable_chunked_prefill
        self.preemption_mode = preemption_mode
        self.num_scheduler_steps = num_scheduler_steps
        self.multi_step_stream_outputs = multi_step_stream_outputs
        self.send_delta_data = send_delta_data
        self.policy = policy
        self._verify_args()

    def _verify_args(self) -> None:
        """Check the token, sequence, lookahead and step limits.

        Raises:
            ValueError: On the first limit that is violated.
        """
        batched_tokens = self.max_num_batched_tokens

        too_short_for_model = batched_tokens < self.max_model_len
        if too_short_for_model and not self.chunked_prefill_enabled:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")

        if batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")

        if self.num_lookahead_slots < 0:
            raise ValueError(
                "num_lookahead_slots "
                f"({self.num_lookahead_slots}) must be greater than or "
                "equal to 0.")

        if self.num_scheduler_steps < 1:
            raise ValueError(
                "num_scheduler_steps "
                f"({self.num_scheduler_steps}) must be greater than or "
                "equal to 1.")

    @property
    def is_multi_step(self) -> bool:
        """Whether more than one scheduler step is configured."""
        return self.num_scheduler_steps > 1

1157

1158
class DeviceConfig:
    """Resolve the device type and the torch device used for inputs.

    Args:
        device: Explicit device type, or "auto" to detect it from the
            current platform.

    Raises:
        RuntimeError: If "auto" is requested and no platform matches.
    """
    device: Optional[torch.device]

    def __init__(self, device: str = "auto") -> None:
        if device != "auto":
            # Device type is assigned explicitly
            self.device_type = device
        # Automated device type detection
        elif current_platform.is_cuda_alike():
            self.device_type = "cuda"
        elif current_platform.is_neuron():
            self.device_type = "neuron"
        elif current_platform.is_openvino():
            self.device_type = "openvino"
        elif current_platform.is_tpu():
            self.device_type = "tpu"
        elif current_platform.is_cpu():
            self.device_type = "cpu"
        elif current_platform.is_xpu():
            self.device_type = "xpu"
        else:
            raise RuntimeError("Failed to infer device type")

        if self.device_type == "tpu":
            self.device = None
        elif self.device_type in ("neuron", "openvino"):
            # Some device types require processing inputs on CPU
            self.device = torch.device("cpu")
        else:
            # Set device with device type
            self.device = torch.device(self.device_type)

1191

1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        speculative_model_quantization: Optional[str],
        speculative_draft_tensor_parallel_size: Optional[int],
        num_speculative_tokens: Optional[int],
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_max_model_len: Optional[int],
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: Optional[float],
        typical_acceptance_sampler_posterior_alpha: Optional[float],
        disable_logprobs: Optional[bool],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.

        This function attempts to create a SpeculativeConfig object based on the
        provided parameters. If the necessary conditions are met, it returns an
        instance of SpeculativeConfig. Otherwise, it returns None.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model.
            target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            target_dtype (str): The data type used for the target model.
                Currently not read by this function.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided. The special value "[ngram]" selects
                ngram prompt lookup and reuses the target model and parallel
                configs as the draft configs.
            speculative_model_quantization (Optional[str]): Quantization method
                that was used to quantize the speculative model weights. If
                None, we assume the model weights are not quantized.
            speculative_draft_tensor_parallel_size (Optional[int]): The degree
                of the tensor parallelism for the draft model.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided. Will default to the number in the draft
                model config if present, otherwise is required.
            speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
                scorer for the speculative model and fall back to batch
                expansion for scoring.
            speculative_max_model_len (Optional[int]): The maximum model len of
                the speculative model. Used when testing the ability to skip
                speculation for some sequences.
            enable_chunked_prefill (bool): Whether vLLM is configured to use
                chunked prefill or not. Used for raising an error since its not
                yet compatible with spec decode.
            disable_log_stats (bool): Whether to disable periodic printing of
                stage times in speculative decoding.
            speculative_disable_by_batch_size (Optional[int]): Disable
                speculative decoding for new incoming requests when the number
                of enqueue requests  is larger than this value, if provided.
            ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
                window, if provided.
            ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                window, if provided.
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (Optional[float]):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance. Defaults to
                0.09 when None.
            typical_acceptance_sampler_posterior_alpha (Optional[float]):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler. Defaults to 0.3 when None.
            disable_logprobs (Optional[bool]): If set to True, token log
                probabilities are not returned during speculative decoding.
                If set to False, token log probabilities are returned
                according to the log probability settings in SamplingParams.
                If not specified, it defaults to True.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                the necessary conditions are met, else None.

        Raises:
            ValueError: If the arguments are inconsistent, e.g.
                num_speculative_tokens is given without speculative_model,
                chunked prefill is enabled, the ngram window bounds are
                invalid, or num_speculative_tokens cannot be determined.
        """
        if speculative_model is None:
            if num_speculative_tokens is not None:
                raise ValueError("num_speculative_tokens was provided without "
                                 "speculative_model.")
            return None

        if (speculative_disable_by_batch_size is not None
                and speculative_disable_by_batch_size < 2):
            raise ValueError("Expect the batch size threshold of disabling "
                             "speculative decoding is > 1, but got "
                             f"{speculative_disable_by_batch_size=}")

        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if enable_chunked_prefill:
            raise ValueError(
                "Speculative decoding and chunked prefill are "
                f"currently mutually exclusive ({enable_chunked_prefill=}).")

        # TODO: The user should be able to specify revision/max model len
        # for the draft model. It is not currently supported.
        draft_revision = None
        draft_code_revision = None
        draft_quantization = speculative_model_quantization

        if speculative_model == "[ngram]":
            if ngram_prompt_lookup_min is None:
                ngram_prompt_lookup_min = 1
            if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
                raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
            if ngram_prompt_lookup_min < 1:
                raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
            if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
                raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
                                 f"larger than {ngram_prompt_lookup_max=}")

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            draft_model_config = target_model_config
            draft_parallel_config = target_parallel_config
        else:
            # A real draft model is used, so ngram lookup is switched off.
            ngram_prompt_lookup_max = 0
            ngram_prompt_lookup_min = 0
            draft_model_config = ModelConfig(
                model=speculative_model,
                task="draft",
                tokenizer=target_model_config.tokenizer,
                tokenizer_mode=target_model_config.tokenizer_mode,
                trust_remote_code=target_model_config.trust_remote_code,
                allowed_local_media_path=target_model_config.
                allowed_local_media_path,
                dtype=target_model_config.dtype,
                seed=target_model_config.seed,
                revision=draft_revision,
                code_revision=draft_code_revision,
                tokenizer_revision=target_model_config.tokenizer_revision,
                max_model_len=None,
                spec_target_max_model_len=target_model_config.max_model_len,
                quantization=draft_quantization,
                enforce_eager=target_model_config.enforce_eager,
                max_seq_len_to_capture=target_model_config.
                max_seq_len_to_capture,
                max_logprobs=target_model_config.max_logprobs,
            )

            draft_hf_config = draft_model_config.hf_config

            # Propagate an explicit token count into draft configs that
            # expose a num_lookahead_tokens attribute.
            if (num_speculative_tokens is not None
                    and hasattr(draft_hf_config, "num_lookahead_tokens")):
                draft_hf_config.num_lookahead_tokens = num_speculative_tokens

            n_predict = getattr(draft_hf_config, "n_predict", None)
            if n_predict is not None:
                if num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    num_speculative_tokens = n_predict
                elif num_speculative_tokens > n_predict:
                    # Verify provided value doesn't exceed the maximum
                    # supported by the draft model.
                    raise ValueError(
                        "This speculative model supports a maximum of "
                        f"num_speculative_tokens={n_predict}, but "
                        f"{num_speculative_tokens=} was provided.")

            draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    speculative_max_model_len,
                    draft_model_config.max_model_len,
                    target_model_config.max_model_len,
                ))

            draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    target_parallel_config,
                    speculative_draft_tensor_parallel_size, draft_hf_config))

        if num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative_model unless the draft model config contains an "
                "n_predict parameter.")

        # Fill in the defaults for the optional acceptance/logprob settings.
        if typical_acceptance_sampler_posterior_threshold is None:
            typical_acceptance_sampler_posterior_threshold = 0.09
        if typical_acceptance_sampler_posterior_alpha is None:
            typical_acceptance_sampler_posterior_alpha = 0.3
        if disable_logprobs is None:
            disable_logprobs = True

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
            speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size,
            ngram_prompt_lookup_max,
            ngram_prompt_lookup_min,
            draft_token_acceptance_method=draft_token_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=\
                typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=\
                typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
        )

    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: Optional[int],
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.

        Raises:
            ValueError: If speculative_max_model_len is given and exceeds
                either the draft or the target max model len.
        """

        if speculative_max_model_len is not None:

            if speculative_max_model_len > draft_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {draft_max_model_len=}")

            if speculative_max_model_len > target_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {target_max_model_len=}")

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: Optional[int],
        draft_hf_config: PretrainedConfig,
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config, except the tp_size.

        Args:
            target_parallel_config: Parallel config of the target model.
            speculative_draft_tensor_parallel_size: Requested tensor parallel
                size for the draft model. When None it defaults to 1 for
                "mlp_speculator" draft models and to the target's tensor
                parallel size otherwise.
            draft_hf_config: HF config of the draft model; only its
                ``model_type`` is read, and only when no size is requested.

        Raises:
            ValueError: If an explicit size is neither 1 nor the target's
                tensor parallel size.
        """
        if speculative_draft_tensor_parallel_size is None:
            if draft_hf_config.model_type == "mlp_speculator":
                speculative_draft_tensor_parallel_size = 1
                if target_parallel_config.tensor_parallel_size > 1:
                    logger.warning(
                        "MLPSpeculator cannot currently be run with tp>1; "
                        "setting speculative_draft_tensor_parallel_size=1")
            else:
                speculative_draft_tensor_parallel_size = \
                    target_parallel_config.tensor_parallel_size
        elif speculative_draft_tensor_parallel_size not in (
                1, target_parallel_config.tensor_parallel_size):
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
                "other value than 1 or target model tensor_parallel_size")

        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.
            pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=target_parallel_config.
            distributed_executor_backend,
            max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.
            disable_custom_all_reduce,
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=target_parallel_config.
            ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ):
        """Create a SpeculativeConfig object.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.
            speculative_disable_mqa_scorer: Disable the MQA scorer for the
                speculative model and fall back to batch expansion for
                scoring.
            speculative_disable_by_batch_size: Disable speculative
                decoding for new incoming requests when the number of
                enqueue requests is larger than this value.
            ngram_prompt_lookup_max: Max size of ngram token window.
                Stored as 0 when None.
            ngram_prompt_lookup_min: Min size of ngram token window.
                Stored as 0 when None.
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (float):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance.
            typical_acceptance_sampler_posterior_alpha (float):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler.
            disable_logprobs: If set to True, token log probabilities will not
                be returned even if requested by sampling parameters. This
                reduces latency by skipping logprob calculation in proposal
                sampling, target sampling, and after accepted tokens are
                determined. If set to False, log probabilities will be
                returned.
            disable_log_stats: Whether to disable periodic printing of stage
                times in speculative decoding.

        Raises:
            ValueError: If validation in ``_verify_args`` fails.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens
        self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
        self.speculative_disable_by_batch_size = \
            speculative_disable_by_batch_size
        # None is normalized to 0, which __repr__ treats as "ngram disabled".
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
        self.draft_token_acceptance_method = draft_token_acceptance_method
        self.typical_acceptance_sampler_posterior_threshold = \
            typical_acceptance_sampler_posterior_threshold
        self.typical_acceptance_sampler_posterior_alpha = \
            typical_acceptance_sampler_posterior_alpha
        self.disable_logprobs = disable_logprobs
        self.disable_log_stats = disable_log_stats

        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the stored settings.

        Raises:
            ValueError: If num_speculative_tokens is not positive, the draft
                token acceptance method is unset or unknown, or either
                typical acceptance sampler parameter is negative.
        """
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)

        # Validate draft token acceptance related settings.
        if self.draft_token_acceptance_method is None:
            raise ValueError("draft_token_acceptance_method is not set. "
                             "Expected values are rejection_sampler or "
                             "typical_acceptance_sampler.")

        if self.draft_token_acceptance_method not in (
                'rejection_sampler', 'typical_acceptance_sampler'):
            raise ValueError(
                "Expected draft_token_acceptance_method to be either "
                "rejection_sampler or typical_acceptance_sampler. Instead it "
                f"is {self.draft_token_acceptance_method}")

        # NOTE(review): the message below says "> 0" but a value of exactly
        # 0 passes this check; confirm which bound is intended.
        if (self.typical_acceptance_sampler_posterior_threshold < 0
                or self.typical_acceptance_sampler_posterior_alpha < 0):
            raise ValueError(
                "Expected typical_acceptance_sampler_posterior_threshold "
                "and typical_acceptance_sampler_posterior_alpha to be > 0. "
                "Instead found "
                "typical_acceptance_sampler_posterior_threshold = "
                f"{self.typical_acceptance_sampler_posterior_threshold} and "
                "typical_acceptance_sampler_posterior_alpha = "
                f"{self.typical_acceptance_sampler_posterior_alpha}")

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        """Return a short description naming the draft model and the number
        of speculative tokens."""
        # A positive ngram window means ngram prompt lookup is in use.
        if self.ngram_prompt_lookup_max > 0:
            draft_model = "[ngram]"
        else:
            draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"


1608
1609
1610
1611
@dataclass
class LoRAConfig:
    """Configuration for LoRA adapters.

    ``__post_init__`` validates the numeric limits and defaults
    ``max_cpu_loras`` to ``max_loras`` when it is not given.
    """

    max_lora_rank: int
    max_loras: int
    fully_sharded_loras: bool = False
    max_cpu_loras: Optional[int] = None
    # None or "auto" is resolved to the model dtype by
    # verify_with_model_config; any other string names a torch dtype.
    lora_dtype: Optional[Union[torch.dtype, str]] = None
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None

    def __post_init__(self):
        """Validate the configured limits.

        Raises:
            ValueError: If max_lora_rank or lora_extra_vocab_size is not one
                of the supported values, max_loras is below 1, or
                max_cpu_loras is smaller than max_loras.
        """
        # Setting the maximum rank to 256 should be able to satisfy the vast
        # majority of applications.
        possible_max_ranks = (8, 16, 32, 64, 128, 256)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``lora_dtype`` against the model config and warn when the
        model uses a quantization other than "awq" or "gptq"."""
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            # Look the dtype up by name on the torch module.
            self.lora_dtype = getattr(torch, self.lora_dtype)
        if model_config.quantization and model_config.quantization not in [
                "awq", "gptq"
        ]:
            # TODO support marlin
            logger.warning("%s quantization is not tested with LoRA yet.",
                           model_config.quantization)

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        """Reject scheduler configs that enable chunked prefill.

        Raises:
            ValueError: If ``scheduler_config.chunked_prefill_enabled`` is
                truthy.
        """
        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if scheduler_config.chunked_prefill_enabled:
            raise ValueError("LoRA is not supported with chunked prefill yet.")
1659
1660


1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
@dataclass
class PromptAdapterConfig:
    """Limits and dtype settings for prompt adapters.

    ``max_cpu_prompt_adapters`` falls back to ``max_prompt_adapters`` when it
    is left as None.
    """

    max_prompt_adapters: int
    max_prompt_adapter_token: int
    max_cpu_prompt_adapters: Optional[int] = None
    prompt_adapter_dtype: Optional[torch.dtype] = None

    def __post_init__(self):
        """Check the limits and fill in the CPU adapter default.

        Raises:
            ValueError: If max_prompt_adapters is below 1 or
                max_prompt_adapter_token is 0.
        """
        adapter_limit = self.max_prompt_adapters
        if adapter_limit < 1:
            raise ValueError(
                f"max_prompt_adapters ({adapter_limit}) must be >= 1.")

        if self.max_prompt_adapter_token == 0:
            raise ValueError("max_prompt_adapter_token must be set.")

        if self.max_cpu_prompt_adapters is None:
            self.max_cpu_prompt_adapters = adapter_limit

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``prompt_adapter_dtype``: None or "auto" takes the model
        dtype, any other string is looked up by name on ``torch``."""
        requested = self.prompt_adapter_dtype
        if requested in (None, "auto"):
            self.prompt_adapter_dtype = model_config.dtype
            return
        if isinstance(requested, str):
            self.prompt_adapter_dtype = getattr(torch, requested)


1686
@dataclass
class MultiModalConfig:
    """Controls the behavior of multimodal models."""

    limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
    """
    The maximum number of multi-modal input instances allowed per prompt
    for each :class:`~vllm.multimodal.MultiModalPlugin`.
    """

    # TODO: Add configs to init vision tower or not.
1697

1698

1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
@dataclass
class PoolerConfig:
    """Controls the behavior of pooler in embedding model.

    Every field is optional and defaults to ``None``.
    NOTE(review): presumably ``None`` means "use the model's own default"
    for that setting -- confirm against the code that consumes this config.
    """

    pooling_type: Optional[str] = None
    pooling_norm: Optional[bool] = None
    pooling_softmax: Optional[bool] = None
    pooling_step_tag_id: Optional[int] = None
    pooling_returned_token_ids: Optional[List[int]] = None


1710
1711
1712
1713
1714
1715
1716
1717
# Maps the dtype names accepted as strings to their torch dtypes.
# "half" and "float" are aliases for "float16" and "float32".
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

1718
_ROCM_NOT_SUPPORTED_DTYPE: List[str] = []  # currently empty
1719

1720
1721
1722

def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested dtype against the model config's dtype.

    Args:
        config: The HF model config. Its ``torch_dtype`` attribute is read
            (a missing or None value is treated as float32), and its
            ``model_type`` is read when ``dtype`` is "auto" and the config
            dtype is float32.
        dtype: Either a ``torch.dtype``, the string "auto", or one of the
            keys of ``_STR_DTYPE_TO_TORCH_DTYPE`` (case-insensitive).

    Returns:
        The resolved ``torch.dtype``. With "auto", float32 configs are
        downcast to float16 (bfloat16 for "gemma2"); other config dtypes
        are used unchanged.

    Raises:
        ValueError: If ``dtype`` is an unknown string or is neither a string
            nor a ``torch.dtype``.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type == "gemma2":
                    logger.info(
                        "For Gemma 2, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16.")
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype. Every mismatch is permitted; only the log level
    # differs between the cases.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
1771
1772
1773
1774
1775


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
    disable_sliding_window: bool,
    sliding_window_len: Optional[Union[int, List[Optional[int]]]],
    spec_target_max_model_len: Optional[int] = None,
) -> int:
    """Get and verify the model's maximum length.

    The maximum length is derived from ``hf_config`` as the smallest value
    found under a list of well-known attribute names. It is then capped by
    the sliding window (only when the sliding window is disabled) and
    adjusted for RoPE scaling. A user-specified ``max_model_len`` is finally
    validated against the derived value.

    Args:
        hf_config: HF model config. The length keys, ``rope_scaling`` and
            ``model_max_length`` are read from it with ``getattr``.
        max_model_len: User-specified maximum length, or ``None`` to use the
            derived one.
        disable_sliding_window: If true and ``sliding_window_len`` is given,
            the derived length is capped at the smallest sliding window.
        sliding_window_len: A single sliding-window size or a list of sizes
            (``None`` entries allowed), or ``None``.
        spec_target_max_model_len: Returned when the config holds none of the
            length keys and ``max_model_len`` is ``None`` (used for
            speculative draft models, which follow the target model).

    Returns:
        The maximum model length as an ``int``.

    Raises:
        NotImplementedError: If the sliding window is disabled for a model
            whose ``rope_scaling`` type is not one of ``su``, ``longrope`` or
            ``llama3``, or for a model whose ``model_max_length`` would
            otherwise permit the user-specified length.
        ValueError: If ``max_model_len`` exceeds the derived length, is not
            permitted by ``model_max_length`` and
            ``VLLM_ALLOW_LONG_MAX_MODEL_LEN`` is not set.
    """
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Command-R
        "model_max_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    # Choose the smallest "max_length" from the possible keys and remember
    # which key it came from (only used for the error message below).
    max_len_key = None
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None and max_len < derived_max_model_len:
            max_len_key = key
            derived_max_model_len = max_len

    # If sliding window is manually disabled, max_length should be less
    # than the sliding window length in the model config.
    if disable_sliding_window and sliding_window_len is not None:
        sliding_window_len_min = get_min_sliding_window(sliding_window_len)
        if sliding_window_len_min < derived_max_model_len:
            max_len_key = "sliding_window"
            derived_max_model_len = sliding_window_len_min

    # If none of the keys were found in the config, use a default and
    # log a warning.
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        if spec_target_max_model_len is not None:
            # If this is a speculative draft model, we use the max model len
            # from the target model.
            return spec_target_max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            "%s. Assuming the model's maximum length is %d.", possible_keys,
            default_max_len)
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        # No need to consider "type" key because of patch_rope_scaling when
        # loading HF config
        rope_type = rope_scaling["rope_type"]

        if rope_type not in ("su", "longrope", "llama3"):
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that supports rope_scaling
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "with rope_scaling. Please raise an issue so we can "
                    "investigate.")

            # NOTE: rope_type == "default" does not define factor
            # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
            scaling_factor = rope_scaling.get("factor", 1.0)

            if rope_type == "yarn":
                # For yarn the factor applies to the original context length.
                derived_max_model_len = rope_scaling[
                    "original_max_position_embeddings"]
            derived_max_model_len *= scaling_factor

    # If the user specified a max length, make sure it is smaller than the
    # derived length from the HF model config.
    if max_model_len is None:
        max_model_len = int(derived_max_model_len)
    elif max_model_len > derived_max_model_len:
        # Some models might have a separate key for specifying model_max_length
        # that will be bigger than derived_max_model_len. We compare user input
        # with model_max_length and allow this override when it's smaller.
        model_max_length = getattr(hf_config, "model_max_length", None)
        if model_max_length is not None and max_model_len <= model_max_length:
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that has model_max_length
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "with model_max_length in the config. Please raise an "
                    "issue so we can investigate.")
        else:
            msg = (
                f"User-specified max_model_len ({max_model_len}) is greater "
                f"than the derived max_model_len ({max_len_key}="
                f"{derived_max_model_len} or model_max_length="
                f"{model_max_length} in model's config.json). This may lead "
                "to incorrect model outputs or CUDA errors.")
            if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
                logger.warning(
                    "%s Make sure the value is correct and within the "
                    "model context size.", msg)
            else:
                raise ValueError(
                    f"{msg} To allow overriding this maximum, set "
                    "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
    return int(max_model_len)
1894
1895


1896
1897
1898
1899
1900
1901
1902
1903
def get_min_sliding_window(
        sliding_window: Union[int, List[Optional[int]]]) -> int:
    """Return the smallest sliding-window size described by the argument.

    Args:
        sliding_window: Either a single window size, or a list of window
            sizes in which ``None`` entries are ignored.

    Returns:
        The value itself when it is not a list, otherwise the minimum of the
        non-``None`` entries of the list.

    Raises:
        ValueError: If a list is given that has no non-``None`` entry.
    """
    if not isinstance(sliding_window, list):
        return sliding_window

    window_sizes = [size for size in sliding_window if size is not None]
    if not window_sizes:
        # min() on an empty sequence fails with an opaque message; report
        # the offending value instead.
        raise ValueError(
            "sliding_window must contain at least one non-None entry, "
            f"got {sliding_window}")
    return min(window_sizes)


1904
1905
1906
def get_served_model_name(model: str,
                          served_model_name: Optional[Union[str, List[str]]]):
    """Pick the single name under which the model is served.

    * A non-empty list yields its first entry.
    * A non-empty string is returned unchanged.
    * ``None``, an empty string or an empty list fall back to ``model``.
    """
    if served_model_name:
        if isinstance(served_model_name, list):
            return served_model_name[0]
        return served_model_name
    # Nothing usable was supplied, so the model name/path itself is used.
    return model


1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
@dataclass
class DecodingConfig:
    """Dataclass which contains the decoding strategy of the engine."""

    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
    guided_decoding_backend: str = 'outlines'

    def __post_init__(self):
        """Validate ``guided_decoding_backend``.

        Raises:
            ValueError: If the backend is not one of the supported names.
        """
        valid_guided_backends = ['outlines', 'lm-format-enforcer']
        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
            # The closing quote and the space after the comma keep the
            # offending value readable in the message.
            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")


1935
1936
1937
1938
1939
@dataclass
class ObservabilityConfig:
    """Configuration for observability."""
    # OTLP traces endpoint. Setting it requires OpenTelemetry to be
    # available; this is checked in __post_init__.
    otlp_traces_endpoint: Optional[str] = None

    # Collecting detailed timing information for each request can be expensive.

    # If set, collects the model forward time for the request.
    collect_model_forward_time: bool = False

    # If set, collects the model execute time for the request.
    collect_model_execute_time: bool = False

    def __post_init__(self):
        """Reject a traces endpoint when OpenTelemetry is unavailable.

        Raises:
            ValueError: If ``otlp_traces_endpoint`` is set while
                ``is_otel_available()`` reports false.
        """
        # Nothing to validate when OpenTelemetry is usable or when no
        # endpoint was requested.
        if is_otel_available() or self.otlp_traces_endpoint is None:
            return

        raise ValueError(
            "OpenTelemetry is not available. Unable to configure "
            "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
            f"installed. Original error:\n{otel_import_error_traceback}")
1954
1955


1956
1957
1958
@dataclass
class VllmConfig:
    """Aggregate of all vllm-related configuration objects.

    Bundling the distinct configurations into one dataclass simplifies
    passing them around in the codebase.
    """

    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    load_config: LoadConfig
    lora_config: Optional[LoRAConfig] = None
    speculative_config: Optional[SpeculativeConfig] = None
    decoding_config: Optional[DecodingConfig] = None
    observability_config: Optional[ObservabilityConfig] = None
    prompt_adapter_config: Optional[PromptAdapterConfig] = None
    # When left as None, __post_init__ fills this in from model_config and
    # load_config.
    quant_config: Optional[QuantizationConfig] = None

    @staticmethod
    def _get_quantization_config(
            model_config: ModelConfig,
            load_config: LoadConfig) -> Optional[QuantizationConfig]:
        """Get the quantization config.

        Returns ``None`` when ``model_config.quantization`` is ``None``.
        Otherwise the config obtained from ``get_quant_config`` is checked
        against the device capability reported by the current platform (if
        any) and against the model dtype before being returned.

        Raises:
            ValueError: If the device capability is below the method's
                minimum, or the model dtype is not a supported activation
                dtype for the method.
        """
        if model_config.quantization is None:
            return None

        # Imported lazily, only when quantization is actually requested.
        from vllm.model_executor.model_loader.weight_utils import (
            get_quant_config)
        quant_config = get_quant_config(model_config, load_config)

        device_capability = current_platform.get_device_capability()
        if device_capability is not None:
            capability = device_capability.to_int()
            if capability < quant_config.get_min_capability():
                raise ValueError(
                    f"The quantization method {model_config.quantization} "
                    "is not supported for the current GPU. Minimum "
                    f"capability: {quant_config.get_min_capability()}. "
                    f"Current capability: {capability}.")

        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")
        return quant_config

    def __post_init__(self):
        """Verify configs are valid & consistent with each other.

        Also derives ``quant_config`` when it was not supplied.
        """
        model_cfg = self.model_config
        parallel_cfg = self.parallel_config

        model_cfg.verify_async_output_proc(parallel_cfg,
                                           self.speculative_config,
                                           self.device_config)
        model_cfg.verify_with_parallel_config(parallel_cfg)
        self.cache_config.verify_with_parallel_config(parallel_cfg)

        if self.lora_config:
            self.lora_config.verify_with_model_config(model_cfg)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(model_cfg)

        # Only build a quantization config when none was passed in.
        needs_quant_config = (self.quant_config is None
                              and self.model_config is not None
                              and self.load_config is not None)
        if needs_quant_config:
            self.quant_config = VllmConfig._get_quantization_config(
                self.model_config, self.load_config)