import enum
import json
from dataclasses import dataclass, field, fields
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
                    Optional, Tuple, Type, Union)

import torch
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                            get_hf_image_processor_config,
                                            get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
                        is_hip, is_neuron, is_openvino, is_xpu,
                        print_warning_once)

if TYPE_CHECKING:
    # Imported for type annotations only, so they add no runtime
    # dependency (e.g. on ray) and cannot create import cycles.
    from ray.util.placement_group import PlacementGroup

    from vllm.executor.executor_base import ExecutorBase
    from vllm.model_executor.model_loader.loader import BaseModelLoader
    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
        BaseTokenizerGroup)

logger = init_logger(__name__)

# Default batched-token budgets for embedding / multimodal models.
# NOTE(review): the consumers of these constants are outside this chunk;
# confirm their exact role there.
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120

class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
            It is also used as the content for `model_name` tag in metrics
            output when `served_model_name` is not specified.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, "slow" will always use the slow tokenizer, and
            "mistral" will always use the tokenizer from `mistral_common`.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        rope_scaling: Dictionary containing the scaling configuration for the
            RoPE embeddings. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        rope_theta: Optional RoPE theta value; forwarded to `get_config`
            together with `rope_scaling`.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the model `revision`.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        spec_target_max_model_len: Forwarded to `_get_and_verify_max_len`
            when deriving `max_model_len` (presumably the target model's limit
            when this config describes a speculative draft model — confirm).
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
            be used to load activation and weight scaling factors when the
            model dtype is FP8_E4M3 on ROCm.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If None, the user did not specify, so default to False.
        max_context_len_to_capture: Removed. Passing any non-None value raises
            ValueError; use `max_seq_len_to_capture` instead.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode. Defaults to (and is capped at)
            `max_model_len`.
        max_logprobs: Upper bound on the number of log probabilities. Stored
            as-is; it is not enforced inside this class.
        disable_sliding_window: Whether to disable sliding window. If True,
            we will disable the sliding window functionality of the model.
            If the model does not support sliding window, this argument is
            ignored. It is forced to True for models with interleaved
            attention.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer.
        served_model_name: The model name used in metrics tag `model_name`,
            matches the model name exposed via the APIs. If multiple model
            names provided, the first name will be used. If not specified,
            the model name will be the same as `model`.
        limit_mm_per_prompt: Maximum number of data instances per modality
            per prompt. Only applicable for multimodal models.
        use_async_output_proc: Whether to enable async output processing. It
            may later be turned off by `verify_async_output_proc` or
            `verify_with_parallel_config`.
        override_neuron_config: Initialize non default neuron config or
            override default neuron config that are specific to Neuron devices,
            this argument will be used to configure the neuron config that
            can not be gathered from the vllm arguments. Ignored (set to None)
            when not running on Neuron.
        config_format: The config format which shall be loaded.
            Defaults to 'auto' which defaults to 'hf'.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor.
    """

    def __init__(self,
                 model: str,
                 tokenizer: str,
                 tokenizer_mode: str,
                 trust_remote_code: bool,
                 dtype: Union[str, torch.dtype],
                 seed: int,
                 revision: Optional[str] = None,
                 code_revision: Optional[str] = None,
                 rope_scaling: Optional[dict] = None,
                 rope_theta: Optional[float] = None,
                 tokenizer_revision: Optional[str] = None,
                 max_model_len: Optional[int] = None,
                 spec_target_max_model_len: Optional[int] = None,
                 quantization: Optional[str] = None,
                 quantization_param_path: Optional[str] = None,
                 enforce_eager: Optional[bool] = None,
                 max_context_len_to_capture: Optional[int] = None,
                 max_seq_len_to_capture: Optional[int] = None,
                 max_logprobs: int = 20,
                 disable_sliding_window: bool = False,
                 skip_tokenizer_init: bool = False,
                 served_model_name: Optional[Union[str, List[str]]] = None,
                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
                 use_async_output_proc: bool = True,
                 override_neuron_config: Optional[Dict[str, Any]] = None,
                 config_format: ConfigFormat = ConfigFormat.AUTO,
                 mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        # The tokenizer version is consistent with the model version by default.
        self.tokenizer_revision = (revision if tokenizer_revision is None else
                                   tokenizer_revision)
        self.quantization = quantization
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        if max_context_len_to_capture is not None:
            raise ValueError("`max_context_len_to_capture` is deprecated. "
                             "Use `max_seq_len_to_capture` instead.")
        self.max_seq_len_to_capture = max_seq_len_to_capture
        self.max_logprobs = max_logprobs
        self.disable_sliding_window = disable_sliding_window
        self.skip_tokenizer_init = skip_tokenizer_init

        # Load the HF configs; everything below derives from them.
        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision, rope_scaling, rope_theta,
                                    config_format)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.hf_image_processor_config = get_hf_image_processor_config(
            self.model, revision)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.use_async_output_proc = use_async_output_proc
        self.mm_processor_kwargs = mm_processor_kwargs

        # Set enforce_eager to False if the value is unset.
        if self.enforce_eager is None:
            self.enforce_eager = False

        # Interleaved attention: a per-layer list of windows, or gemma2.
        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
        has_interleaved_attention = (sliding_window is not None) and (
            isinstance(sliding_window, list) or
            (self.hf_text_config.model_type in ["gemma2"]))

        if not self.disable_sliding_window and has_interleaved_attention:
            sliding_window_len_min = get_min_sliding_window(
                self.hf_text_config.sliding_window)

            print_warning_once(
                f"{self.hf_text_config.model_type} has interleaved attention, "
                "which is currently not supported by vLLM. Disabling sliding "
                "window and capping the max length to the sliding window size "
                f"({sliding_window_len_min}).")
            self.disable_sliding_window = True

        # Must run after the sliding-window decision above, since it depends
        # on disable_sliding_window.
        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.hf_text_config,
            max_model_len=max_model_len,
            disable_sliding_window=self.disable_sliding_window,
            sliding_window_len=self.get_hf_config_sliding_window(),
            spec_target_max_model_len=spec_target_max_model_len)
        self.served_model_name = get_served_model_name(model,
                                                       served_model_name)
        self.multimodal_config = self._init_multimodal_config(
            limit_mm_per_prompt)
        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()

        self.is_attention_free = self._init_attention_free()
        self.has_inner_state = self._init_has_inner_state()

        self.override_neuron_config = override_neuron_config if is_neuron(
        ) else None
        self._verify_embedding_mode()
        self._verify_quantization()
        # _verify_cuda_graph needs max_model_len; _verify_bnb_config may
        # flip enforce_eager, so both run last.
        self._verify_cuda_graph()
        self._verify_bnb_config()

    def _init_multimodal_config(
        self, limit_mm_per_prompt: Optional[Mapping[str, int]]
    ) -> Optional["MultiModalConfig"]:
        """Build a MultiModalConfig for multimodal architectures.

        Returns None for non-multimodal models.

        Raises:
            ValueError: If `limit_mm_per_prompt` is given for a
                non-multimodal model.
        """
        architectures = getattr(self.hf_config, "architectures", [])
        if ModelRegistry.is_multimodal_model(architectures):
            return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})

        if limit_mm_per_prompt:
            raise ValueError("`limit_mm_per_prompt` is only supported for "
                             "multimodal models.")

        return None

    def _init_attention_free(self) -> bool:
        """Ask the model registry whether the architecture has no attention."""
        architectures = getattr(self.hf_config, "architectures", [])
        return ModelRegistry.is_attention_free_model(architectures)

    def _init_has_inner_state(self) -> bool:
        """Ask the model registry whether the architecture keeps inner state."""
        architectures = getattr(self.hf_config, "architectures", [])
        return ModelRegistry.model_has_inner_state(architectures)

    def _verify_tokenizer_mode(self) -> None:
        """Normalize `tokenizer_mode` to lower case and validate it."""
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow", "mistral"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto', 'slow' or 'mistral'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_embedding_mode(self) -> None:
        """Set `self.embedding_mode` from the model architecture."""
        architectures = getattr(self.hf_config, "architectures", [])

        # TODO: Allow the same model architecture to be specified as either
        # generation or embedding model
        if "Phi3VForCausalLM" in architectures:
            # Match both remote and local names
            embedding_mode = "/VLM2Vec" in self.model
        else:
            embedding_mode = ModelRegistry.is_embedding_model(architectures)

        self.embedding_mode = embedding_mode

    def _parse_quant_hf_config(self):
        """Return the HF quantization config, or None if the model has none."""
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        return quant_cfg

    def _verify_quantization(self) -> None:
        """Resolve and validate `self.quantization`.

        The user-supplied method is reconciled with the checkpoint's HF
        quantization config. The result is then checked against the methods
        supported on the current platform (ROCm, TPU, Neuron).

        Raises:
            ValueError: On a mismatch between the checkpoint and the argument,
                an unknown method, or a method unsupported on this platform.
        """
        supported_quantization = [*QUANTIZATION_METHODS]
        # This build restricts ROCm to the methods below. A broader ROCm list
        # that also included "fp8", "compressed_tensors" and "fbgemm_fp8" was
        # disabled here.
        rocm_supported_quantization = ["awq", "gptq", "compressed-tensors"]
        optimized_quantization_methods = [
            "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
            "awq_marlin", "fbgemm_fp8", "compressed_tensors",
            "compressed-tensors", "experts_int8"
        ]
        tpu_supported_quantization = ["tpu_int8"]
        neuron_supported_quantization = ["neuron_quant"]
        on_rocm = is_hip()

        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()

            # Detect which checkpoint format this is: the first registered
            # method that claims the config overrides quant_method.
            for method in QUANTIZATION_METHODS.values():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization)
                if not quantization_override:
                    continue
                # On ROCm, skip overrides that ROCm cannot run and keep
                # looking at the remaining methods.
                if on_rocm and (quantization_override
                                not in rocm_supported_quantization):
                    continue
                quant_method = quantization_override
                self.quantization = quantization_override
                break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({quant_method}) does not match the quantization "
                    "method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is None:
            return

        if self.quantization not in supported_quantization:
            raise ValueError(
                f"Unknown quantization method: {self.quantization}. Must "
                f"be one of {supported_quantization}.")
        if on_rocm and self.quantization not in rocm_supported_quantization:
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                "supported in ROCm.")
        if current_platform.is_tpu(
        ) and self.quantization not in tpu_supported_quantization:
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                "supported in TPU Backend.")
        if self.quantization not in optimized_quantization_methods:
            logger.warning(
                "%s quantization is not fully "
                "optimized yet. The speed can be slower than "
                "non-quantized models.", self.quantization)
        if (self.quantization == "awq" and on_rocm
                and not envs.VLLM_USE_TRITON_AWQ):
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ.")
            # Overrides the env-derived value on the module object so later
            # reads of envs.VLLM_USE_TRITON_AWQ see True.
            envs.VLLM_USE_TRITON_AWQ = True
        if is_neuron(
        ) and self.quantization not in neuron_supported_quantization:
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                "supported in Neuron Backend.")

    def _verify_cuda_graph(self) -> None:
        """Default `max_seq_len_to_capture` to, and cap it at, max_model_len."""
        if self.max_seq_len_to_capture is None:
            self.max_seq_len_to_capture = self.max_model_len
        self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                          self.max_model_len)

    def _verify_bnb_config(self) -> None:
        """Force eager mode for 8-bit bitsandbytes checkpoints.

        The current version of bitsandbytes (0.44.0) with 8-bit models does not
        yet support CUDA graph.
        """
        quant_config = getattr(self.hf_config, "quantization_config", None)
        is_8bit = (quant_config is not None
                   and quant_config.get("load_in_8bit", False))
        if (self.quantization == "bitsandbytes" and is_8bit
                and not self.enforce_eager):
            logger.warning(
                "CUDA graph is not supported on BitAndBytes 8bit yet, "
                "fallback to the eager mode.")
            self.enforce_eager = True

    def verify_async_output_proc(self, parallel_config, speculative_config,
                                 device_config) -> None:
        """Disable async output processing when the setup cannot support it.

        The checks run in order, and the method returns after the first one
        that disables the feature. It does nothing if the feature is already
        disabled.
        """
        if not self.use_async_output_proc:
            # Nothing to check
            return

        if parallel_config.pipeline_parallel_size > 1:
            logger.warning("Async output processing can not be enabled "
                           "with pipeline parallel")
            self.use_async_output_proc = False
            return

        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if device_config.device_type not in ("cuda", "tpu", "xpu"):
            logger.warning(
                "Async output processing is only supported for CUDA, TPU, XPU. "
                "Disabling it for other platforms.")
            self.use_async_output_proc = False
            return

        if envs.VLLM_USE_RAY_SPMD_WORKER:
            logger.warning(
                "Async output processing can not be enabled with ray spmd")
            self.use_async_output_proc = False
            return

        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if device_config.device_type == "cuda" and self.enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            self.use_async_output_proc = False
            return

        # Async postprocessor is not necessary with embedding mode
        # since there is no token generation
        if self.embedding_mode:
            self.use_async_output_proc = False
            return

        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if speculative_config:
            logger.warning("Async output processing is not supported with"
                           " speculative decoding currently.")
            self.use_async_output_proc = False

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the model can be split across the parallel layout.

        Raises:
            ValueError: If the attention heads do not divide evenly by the
                tensor parallel size.
            NotImplementedError: If pipeline parallelism is requested for a
                model without PP support.
        """
        total_num_attention_heads = getattr(self.hf_text_config,
                                            "num_attention_heads", 0)
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if pipeline_parallel_size > 1:
            architectures = getattr(self.hf_config, "architectures", [])
            if not ModelRegistry.is_pp_supported_model(architectures):
                raise NotImplementedError(
                    "Pipeline parallelism is not supported for this model. "
                    "Supported models implement the `SupportsPP` interface.")

            if self.use_async_output_proc:
                logger.warning("Async output processor is not supported with "
                               "pipeline parallelism currently. Disabling it.")
                self.use_async_output_proc = False

    def get_hf_config_sliding_window(
            self) -> Union[Optional[int], List[Optional[int]]]:
        """Get the sliding window size, or None if disabled."""

        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
        # addition to sliding window size. We check if that field is present
        # and if it's False, return None.
        if (hasattr(self.hf_text_config, "use_sliding_window")
                and not self.hf_text_config.use_sliding_window):
            return None
        return getattr(self.hf_text_config, "sliding_window", None)

    def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
        """Get the sliding window size, or None if disabled.

        Unlike `get_hf_config_sliding_window`, this also honours the user's
        `disable_sliding_window` flag.
        """
        # If user disables sliding window, return None.
        if self.disable_sliding_window:
            return None
        # Otherwise get the value from the hf config.
        return self.get_hf_config_sliding_window()

    def get_vocab_size(self) -> int:
        """Return the vocabulary size from the HF text config."""
        return self.hf_text_config.vocab_size

    def get_hidden_size(self) -> int:
        """Return the hidden size from the HF text config."""
        return self.hf_text_config.hidden_size

    def get_head_size(self) -> int:
        """Return the per-head attention dimension.

        DeepSeek-V2 is special-cased to 256, and attention-free models
        return 0.
        """
        # TODO remove hard code
        if getattr(self.hf_text_config, "model_type", None) == "deepseek_v2":
            # FlashAttention supports only head_size 32, 64, 128, 256,
            # we need to pad head_size 192 to 256
            return 256

        if self.is_attention_free:
            return 0

        if hasattr(self.hf_text_config, "head_dim"):
            return self.hf_text_config.head_dim
        # FIXME(woosuk): This may not be true for all models.
        return (self.hf_text_config.hidden_size //
                self.hf_text_config.num_attention_heads)

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type == "mpt":
            # MPT's attn_config is accessed as a mapping here.
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type == "dbrx":
            # DBRX's attn_config is accessed as an object here.
            return getattr(self.hf_config.attn_config, "kv_n_heads",
                           self.hf_config.num_attention_heads)

        if self.is_attention_free:
            return 0

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_attention_heads(self,
                                parallel_config: "ParallelConfig") -> int:
        """Return the number of attention heads per tensor-parallel rank."""
        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
        return num_heads // parallel_config.tensor_parallel_size

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of hidden layers owned by this pipeline rank."""
        from vllm.distributed.utils import get_pp_indices
        total_num_hidden_layers = getattr(self.hf_text_config,
                                          "num_hidden_layers", 0)
        pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
        pp_size = parallel_config.pipeline_parallel_size
        start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
        return end - start

    def get_num_attention_layers(self,
                                 parallel_config: "ParallelConfig") -> int:
        """Return how many layers use attention (0 if attention-free)."""
        if self.is_attention_free:
            return 0

        num_layers = self.get_num_layers(parallel_config)

        # Transformers supports layers_block_type @property
        # NOTE(review): `layers_block_type` describes the whole model, while
        # the fallback is sized for this pipeline rank only. Confirm this is
        # intended when pipeline_parallel_size > 1.
        layers = getattr(self.hf_config, "layers_block_type",
                         ["attention"] * num_layers)
        return sum(1 for block_type in layers if block_type == "attention")

    def get_multimodal_config(self) -> "MultiModalConfig":
        """
        Get the multimodal configuration of the model.

        Raises:
            ValueError: If the model is not multimodal.
        """
        if self.multimodal_config is None:
            raise ValueError("The model is not multimodal.")

        return self.multimodal_config

    @property
    def is_encoder_decoder_model(self) -> bool:
        """Extract the HF encoder/decoder model flag.

        Checks the top-level config first, then the nested `text_config`.
        """
        return getattr(self.hf_config, "is_encoder_decoder", False) or (
            (hasattr(self.hf_config, "text_config") and getattr(
                self.hf_config.text_config, "is_encoder_decoder", False)))

    @property
    def is_embedding_model(self) -> bool:
        """Extract the embedding model flag."""
        return self.embedding_mode

    @property
    def is_multimodal_model(self) -> bool:
        """True when a MultiModalConfig was built for this model."""
        return self.multimodal_config is not None

class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution. Values above 1.0 are rejected.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for kv cache storage. Must be "auto" or one of
            "fp8", "fp8_e4m3", "fp8_e5m2".
        is_attention_free: Whether the model is attention-free. Only stored
            here.
        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
            profiled num_gpu_blocks if specified. Does nothing if None.
        sliding_window: Sliding window size, if the model uses one. Cannot be
            combined with prefix caching.
        enable_prefix_caching: Whether to enable prefix caching.
        cpu_offload_gb: CPU offload size in GiB. Only stored here.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: float,
        cache_dtype: str,
        is_attention_free: bool = False,
        num_gpu_blocks_override: Optional[int] = None,
        sliding_window: Optional[int] = None,
        enable_prefix_caching: bool = False,
        cpu_offload_gb: float = 0,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        # Swap space is given in GiB but stored in bytes.
        self.swap_space_bytes = swap_space * GiB_bytes
        self.num_gpu_blocks_override = num_gpu_blocks_override
        self.cache_dtype = cache_dtype
        self.is_attention_free = is_attention_free
        self.sliding_window = sliding_window
        self.enable_prefix_caching = enable_prefix_caching
        self.cpu_offload_gb = cpu_offload_gb

        self._verify_args()
        self._verify_cache_dtype()
        self._verify_prefix_caching()

        # Will be set after profiling.
        self.num_gpu_blocks: Optional[int] = None
        self.num_cpu_blocks: Optional[int] = None

    def metrics_info(self):
        """Return every instance attribute stringified, for Prometheus info."""
        # convert cache_config to dict(key: str, value: str) for prometheus
        # metrics info
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self) -> None:
        """Reject a GPU memory utilization fraction above 1.0.

        Raises:
            ValueError: If ``gpu_memory_utilization`` exceeds 1.0.
        """
        # Exactly 1.0 is accepted, so the message must say "less than or
        # equal to" rather than "less than".
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than or equal to 1.0. "
                f"Got {self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        """Validate ``cache_dtype``; log a note when an fp8 variant is used.

        Raises:
            ValueError: If ``cache_dtype`` is not "auto" or an fp8 variant.
        """
        if self.cache_dtype == "auto":
            pass
        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "Meanwhile, it may cause accuracy drop without a proper "
                "scaling factor")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    def _verify_prefix_caching(self) -> None:
        """Reject prefix caching combined with a sliding window.

        Raises:
            NotImplementedError: If prefix caching is enabled while
                ``sliding_window`` is set.
        """
        if not self.enable_prefix_caching:
            return

        if self.sliding_window is not None:
            raise NotImplementedError(
                "Prefix caching is not supported with sliding window. "
                "Run with --disable-sliding-window to use prefix caching.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check the total swap space against the host's CPU memory.

        The swap space of all ``tensor_parallel_size`` workers is summed and
        compared with the total CPU memory. Above 70% this raises; above 40%
        it only logs a warning.

        Raises:
            ValueError: If the summed swap space exceeds 70% of CPU memory.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
               f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
               "is allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. %s", msg)
696

697

698
699
700
@dataclass
class TokenizerPoolConfig:
    """Settings for a pool of tokenizer workers.

    Args:
        pool_size: How many tokenizer workers the pool holds.
        pool_type: Either the string "ray" or a class (annotated as a
            BaseTokenizerGroup subclass) implementing the pool.
        extra_config: Pool-specific options; how they are interpreted is up
            to the chosen pool type. Must be a dict.
    """
    pool_size: int
    pool_type: Union[str, Type["BaseTokenizerGroup"]]
    extra_config: dict

    def __post_init__(self):
        pool_type_is_class = isinstance(self.pool_type, type)
        if not pool_type_is_class and self.pool_type not in ("ray", ):
            raise ValueError(f"Unknown pool type: {self.pool_type}")
        if not isinstance(self.extra_config, dict):
            raise ValueError("extra_config must be a dictionary.")

    @classmethod
    def create_config(
        cls, tokenizer_pool_size: int,
        tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]],
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Build a TokenizerPoolConfig, or return None for a zero-size pool.

        Args:
            tokenizer_pool_size: Number of tokenizer workers; a falsy value
                (e.g. 0) disables the pool and yields None.
            tokenizer_pool_type: Type of the pool ("ray" or a pool class).
            tokenizer_pool_extra_config: Extra pool options, either a dict or
                a JSON string to be decoded. A falsy non-string value becomes
                an empty dict.
        """
        if not tokenizer_pool_size:
            return None

        extra = tokenizer_pool_extra_config
        if isinstance(extra, str):
            extra = json.loads(extra)
        else:
            extra = extra or {}
        return cls(tokenizer_pool_size, tokenizer_pool_type, extra)


752
753
754
755
756
757
758
class LoadFormat(str, enum.Enum):
    """Formats that model weights can be loaded from.

    Because the enum mixes in ``str``, each member compares equal to its
    lower-case value, and ``LoadFormat("pt") is LoadFormat.PT``.
    """

    AUTO = "auto"  # safetensors when available, else pytorch bin files
    PT = "pt"  # pytorch bin format
    SAFETENSORS = "safetensors"  # safetensors format
    NPCACHE = "npcache"  # pytorch format plus a numpy cache
    DUMMY = "dummy"  # random weights, mainly for profiling
    TENSORIZER = "tensorizer"  # CoreWeave's tensorizer library
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"  # nf4 type weights
    MISTRAL = "mistral"
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781


@dataclass
class LoadConfig:
    """Configuration for loading model weights.

    Args:
        load_format: The format of the model weights to load. A string is
            lower-cased and converted to a ``LoadFormat`` (unknown strings
            raise ValueError); any non-string value, e.g. a
            ``BaseModelLoader``, is kept as-is.
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
            "tensorizer" will use CoreWeave's tensorizer library for
                fast weight loading.
            "bitsandbytes" will load nf4 type weights.
            The remaining ``LoadFormat`` values ("sharded_state", "gguf",
                "mistral") are accepted as well.
        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        model_loader_extra_config: Extra options for the model loader, given
            either as a dict or as a JSON string (decoded in __post_init__).
            Falsy values (None, "", {}) are normalized to an empty dict; any
            other non-string value is kept unchanged.
        ignore_patterns: The list of patterns to ignore when loading the model.
            Default to "original/**/*" to avoid repeated loading of llama's
            checkpoints.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(
        default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        # Always store the normalized value. Previously only JSON strings
        # were written back, so None or "" survived and broke loaders that
        # expect a mapping (e.g. by calling .copy() on it).
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
            model_loader_extra_config = json.loads(model_loader_extra_config)
        self.model_loader_extra_config = model_loader_extra_config
        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns)
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Normalize a string ``load_format`` into a ``LoadFormat`` member.

        Raises:
            ValueError: If the string is not a known format, or if the format
                is listed as unsupported on ROCm while running on ROCm.
        """
        # Non-string values (e.g. a model loader object) are used as-is.
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        # Formats ROCm cannot load; currently empty.
        rocm_not_supported_load_format: List[str] = []
        # Cheap membership test first, so is_hip() is only consulted when the
        # requested format is actually on the unsupported list.
        if load_format in rocm_not_supported_load_format and is_hip():
            # List format *values* ("auto", "pt", ...) to match the
            # lower-case exclusion list and what users pass on the CLI;
            # LoadFormat.__members__ would yield names such as "AUTO".
            rocm_supported_load_format = [
                f.value for f in LoadFormat
                if f.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")


828
class ParallelConfig:
    """Configuration for distributed execution.

    ``world_size`` is ``pipeline_parallel_size * tensor_parallel_size``.

    Args:
        pipeline_parallel_size: Pipeline-parallel degree (number of pipeline
            groups).
        tensor_parallel_size: Tensor-parallel degree (number of tensor
            parallel groups).
        worker_use_ray: Deprecated; prefer ``distributed_executor_backend``.
            When truthy it selects "ray" if no backend was given, and it is
            rejected when combined with a backend that does not use Ray.
        max_parallel_loading_workers: Maximum number of batches used when
            loading the model sequentially, to avoid running out of RAM with
            tensor parallelism and large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL. Forced to True on ROCm (HIP).
        tokenizer_pool_config: Config for the tokenizer pool. If None,
            tokenization is synchronous.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
            Requires a Ray-based backend.
        placement_group: Ray placement group for the distributed workers.
        distributed_executor_backend: "ray", "mp" (multiprocessing) or a
            custom ``ExecutorBase`` subclass. If omitted and world_size > 1,
            "mp" is chosen unless Ray is needed: "ray" is picked when fewer
            CUDA devices are visible than world_size (Ray must then be
            installed), or when Ray is available and a placement group was
            given or is currently active. On TPU with world_size > 1 only
            "ray" is allowed.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: Optional[bool] = None,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
        ray_workers_use_nsight: bool = False,
        placement_group: Optional["PlacementGroup"] = None,
        distributed_executor_backend: Optional[Union[
            str, Type["ExecutorBase"]]] = None,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.distributed_executor_backend = distributed_executor_backend
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.tokenizer_pool_config = tokenizer_pool_config
        self.ray_workers_use_nsight = ray_workers_use_nsight
        self.placement_group = placement_group
        self.world_size = pipeline_parallel_size * self.tensor_parallel_size

        # Legacy flag: translate worker_use_ray into a backend choice.
        if worker_use_ray:
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "ray"
            elif not self.use_ray:
                raise ValueError(f"worker-use-ray can't be used with "
                                 f"distributed executor backend "
                                 f"'{self.distributed_executor_backend}'.")

        multiple_workers = self.world_size > 1

        if current_platform.is_tpu() and multiple_workers:
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "ray"
            if self.distributed_executor_backend != "ray":
                raise ValueError(
                    "TPU backend only supports Ray for distributed inference.")

        if multiple_workers and self.distributed_executor_backend is None:
            chosen_backend = self._default_distributed_executor_backend()
            self.distributed_executor_backend = chosen_backend
            logger.info("Defaulting to use %s for distributed inference",
                        chosen_backend)

        self._verify_args()
        self.rank: int = 0

    def _default_distributed_executor_backend(self) -> str:
        """Pick "mp" or "ray" when no backend was configured.

        Multiprocessing is preferred when world_size fits on the current node
        and we are not inside a Ray placement group.
        """
        from vllm.executor import ray_utils
        ray_found = ray_utils.ray_is_available()

        needs_multi_node = (current_platform.is_cuda()
                            and cuda_device_count_stateless() < self.world_size)
        if needs_multi_node:
            if not ray_found:
                raise ValueError("Unable to load Ray which is "
                                 "required for multi-node inference, "
                                 "please install Ray with `pip install "
                                 "ray`.") from ray_utils.ray_import_err
            return "ray"

        if not ray_found:
            return "mp"
        if self.placement_group:
            return "ray"

        from ray import is_initialized as ray_is_initialized
        if not ray_is_initialized():
            return "mp"

        from ray.util import get_current_placement_group
        return "ray" if get_current_placement_group() else "mp"

    @property
    def use_ray(self) -> bool:
        """Whether the configured backend is Ray or a Ray-based executor."""
        backend = self.distributed_executor_backend
        if backend == "ray":
            return True
        return isinstance(backend, type) and backend.uses_ray

    def _verify_args(self) -> None:
        # Lazy import to avoid circular import
        from vllm.executor.executor_base import ExecutorBase

        backend = self.distributed_executor_backend
        is_builtin_backend = backend in ("ray", "mp", None)
        if not is_builtin_backend and not (isinstance(backend, type)
                                           and issubclass(backend,
                                                          ExecutorBase)):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
                "values are 'ray', 'mp' or custom ExecutorBase subclass.")

        if self.use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        if is_hip():
            self.disable_custom_all_reduce = True
            logger.info(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on AMD GPUs.")

        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")
949

950
951

class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens processed in a single
            iteration. If None, a default is derived. With chunked prefill it
            is 512, or max(max_model_len, 2048) when num_scheduler_steps > 1.
            Without chunked prefill it is max(max_model_len, 2048). The result
            is then raised to the embedding / multimodal minimums when those
            modes apply.
        max_num_seqs: Maximum number of sequences processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (prompt plus generated
            text).
        num_lookahead_slots: Slots allocated per sequence per step beyond the
            known token ids; used by speculative decoding to store KV
            activations of tokens that may or may not be accepted. Must be >= 0.
        delay_factor: Delay (delay factor times the previous prompt latency)
            applied before scheduling the next prompt.
        enable_chunked_prefill: If True, prefill requests can be chunked based
            on the remaining max_num_batched_tokens.
        embedding_mode: Whether the running model is for embedding.
        is_multimodal_model: Whether the model is multimodal. Only used to
            pick the default max_num_batched_tokens; not stored.
        preemption_mode: Whether to preempt by swapping or by recomputation.
            If not specified, we determine the mode as follows: recomputation
            is the default because it incurs lower overhead than swapping.
            However, when the sequence group has multiple sequences (e.g.,
            beam search), recomputation is not currently supported; in that
            case swapping is used instead.
        num_scheduler_steps: Number of scheduler steps; must be >= 1. Values
            above 1 mean multi-step scheduling (see ``is_multi_step``).
        multi_step_stream_outputs: Stored as-is; presumably controls output
            streaming in multi-step mode (not used here).
        send_delta_data: Private API. If used, the scheduler sends delta data
            to workers instead of the entire data. It should only be enabled
            together with the SPMD worker architecture, i.e.
            VLLM_USE_RAY_SPMD_WORKER=1.
        policy: The scheduling policy to use, "fcfs" (default) or "priority".
            Not validated here.
    """

    def __init__(self,
                 max_num_batched_tokens: Optional[int],
                 max_num_seqs: int,
                 max_model_len: int,
                 num_lookahead_slots: int = 0,
                 delay_factor: float = 0.0,
                 enable_chunked_prefill: bool = False,
                 embedding_mode: bool = False,
                 is_multimodal_model: bool = False,
                 preemption_mode: Optional[str] = None,
                 num_scheduler_steps: int = 1,
                 multi_step_stream_outputs: bool = False,
                 send_delta_data: bool = False,
                 policy: str = "fcfs") -> None:
        token_budget = max_num_batched_tokens
        if token_budget is None:
            token_budget = self._default_max_num_batched_tokens(
                max_model_len=max_model_len,
                enable_chunked_prefill=enable_chunked_prefill,
                num_scheduler_steps=num_scheduler_steps,
                embedding_mode=embedding_mode,
                is_multimodal_model=is_multimodal_model,
            )
        self.max_num_batched_tokens = token_budget

        if enable_chunked_prefill:
            logger.info(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                self.max_num_batched_tokens)

        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
        self.chunked_prefill_enabled = enable_chunked_prefill
        self.embedding_mode = embedding_mode
        self.preemption_mode = preemption_mode
        self.num_scheduler_steps = num_scheduler_steps
        self.multi_step_stream_outputs = multi_step_stream_outputs
        self.send_delta_data = send_delta_data
        self.policy = policy
        self._verify_args()

    @staticmethod
    def _default_max_num_batched_tokens(max_model_len: int,
                                        enable_chunked_prefill: bool,
                                        num_scheduler_steps: int,
                                        embedding_mode: bool,
                                        is_multimodal_model: bool) -> int:
        """Derive max_num_batched_tokens when the caller did not set it."""
        if enable_chunked_prefill and not num_scheduler_steps > 1:
            # Single-step chunked prefill: the value with the best balance
            # between ITL and TTFT on A100 (not optimized for throughput).
            budget = 512
        else:
            # Either no chunked prefill (use at least 2048 for throughput
            # when max_model_len is short), or multi-step chunked prefill,
            # which cannot chunk prompts yet; a budget of at least
            # max_model_len avoids rejecting sequences.
            budget = max(max_model_len, 2048)

        if embedding_mode:
            # For embedding, choose specific value for higher throughput
            budget = max(budget, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
        if is_multimodal_model:
            # The value needs to be at least the number of multimodal tokens
            budget = max(budget, _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS)
        return budget

    def _verify_args(self) -> None:
        budget = self.max_num_batched_tokens

        if budget < self.max_model_len and not self.chunked_prefill_enabled:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")

        if budget < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")

        if self.num_lookahead_slots < 0:
            raise ValueError(
                "num_lookahead_slots "
                f"({self.num_lookahead_slots}) must be greater than or "
                "equal to 0.")

        if self.num_scheduler_steps < 1:
            raise ValueError(
                "num_scheduler_steps "
                f"({self.num_scheduler_steps}) must be greater than or "
                "equal to 1.")

    @property
    def is_multi_step(self) -> bool:
        """Whether more than one scheduler step is taken per iteration."""
        return self.num_scheduler_steps > 1

1080

1081
class DeviceConfig:
    """Device the model executes on.

    Attributes:
        device_type: Device type string ("cuda", "neuron", "openvino", "tpu",
            "cpu", "xpu", or whatever was passed explicitly).
        device: ``torch.device`` used for inputs. It is the CPU for neuron and
            openvino, None for tpu, and ``torch.device(device_type)``
            otherwise.
    """
    device: Optional[torch.device]

    def __init__(self, device: str = "auto") -> None:
        if device == "auto":
            self.device_type = self._detect_device_type()
        else:
            # Device type is assigned explicitly
            self.device_type = device

        # Some device types require processing inputs on CPU
        if self.device_type in ("neuron", "openvino"):
            self.device = torch.device("cpu")
        elif self.device_type == "tpu":
            self.device = None
        else:
            self.device = torch.device(self.device_type)

    @staticmethod
    def _detect_device_type() -> str:
        """Probe the platform in priority order and return its device type.

        Raises:
            RuntimeError: If no supported platform is detected.
        """
        if current_platform.is_cuda_alike():
            return "cuda"
        if is_neuron():
            return "neuron"
        if is_openvino():
            return "openvino"
        if current_platform.is_tpu():
            return "tpu"
        if current_platform.is_cpu():
            return "cpu"
        if is_xpu():
            return "xpu"
        raise RuntimeError("Failed to infer device type")

1114

1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        speculative_model_quantization: Optional[str],
        speculative_draft_tensor_parallel_size: Optional[int],
        num_speculative_tokens: Optional[int],
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_max_model_len: Optional[int],
        enable_chunked_prefill: bool,
        disable_log_stats: bool,
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: Optional[float],
        typical_acceptance_sampler_posterior_alpha: Optional[float],
        disable_logprobs: Optional[bool],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.

        Returns None when no ``speculative_model`` is given. Otherwise the
        arguments are validated, a draft model/parallel config is built (or
        the target's configs are reused for ``"[ngram]"``), defaults are
        filled in, and a SpeculativeConfig is returned.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model.
            target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            target_dtype (str): The data type used for the target model. Not
                read by this method.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided. ``"[ngram]"`` selects ngram prompt lookup
                instead of a draft model.
            speculative_model_quantization (Optional[str]): Quantization method
                that was used to quantize the speculative model weights. If
                None, we assume the model weights are not quantized.
            speculative_draft_tensor_parallel_size (Optional[int]): The degree
                of the tensor parallelism for the draft model.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided. Will default to the number in the draft
                model config if present, otherwise is required.
            speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
                scorer for the speculative model and fall back to batch
                expansion for scoring.
            speculative_max_model_len (Optional[int]): The maximum model len of
                the speculative model. Used when testing the ability to skip
                speculation for some sequences.
            enable_chunked_prefill (bool): Whether vLLM is configured to use
                chunked prefill or not. Used for raising an error since its not
                yet compatible with spec decode.
            disable_log_stats (bool): Whether to disable periodic printing of
                stage times in speculative decoding.
            speculative_disable_by_batch_size (Optional[int]): Disable
                speculative decoding for new incoming requests when the number
                of enqueue requests is larger than this value, if provided.
            ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
                window, if provided.
            ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                window, if provided. Defaults to 1 for ``"[ngram]"``.
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (Optional[float]):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance. Defaults to
                0.09.
            typical_acceptance_sampler_posterior_alpha (Optional[float]):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler. Defaults to 0.3.
            disable_logprobs (Optional[bool]): If set to True, token log
                probabilities are not returned during speculative decoding.
                If set to False, token log probabilities are returned
                according to the log probability settings in SamplingParams.
                If not specified, it defaults to True.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                the necessary conditions are met, else None.

        Raises:
            ValueError: On inconsistent or out-of-range arguments (e.g.
                ``num_speculative_tokens`` without a model, chunked prefill
                enabled, invalid ngram window bounds, or too many speculative
                tokens for the draft model).
        """
        if speculative_model is None:
            if num_speculative_tokens is not None:
                raise ValueError("num_speculative_tokens was provided without "
                                 "speculative_model.")
            return None

        if (speculative_disable_by_batch_size is not None
                and speculative_disable_by_batch_size < 2):
            raise ValueError("Expect the batch size threshold of disabling "
                             "speculative decoding is > 1, but got "
                             f"{speculative_disable_by_batch_size=}")

        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if enable_chunked_prefill:
            raise ValueError(
                "Speculative decoding and chunked prefill are "
                f"currently mutually exclusive ({enable_chunked_prefill=}).")

        # TODO: The user should be able to specify revision/max model len
        # for the draft model. It is not currently supported.
        draft_revision = None
        draft_code_revision = None
        draft_quantization = speculative_model_quantization

        if speculative_model == "[ngram]":
            if ngram_prompt_lookup_min is None:
                ngram_prompt_lookup_min = 1
            if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
                raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
            if ngram_prompt_lookup_min < 1:
                raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
            if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
                raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
                                 f"larger than {ngram_prompt_lookup_max=}")

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            draft_model_config = target_model_config
            draft_parallel_config = target_parallel_config
        else:
            # A zero ngram window marks a draft-model config; __repr__ uses
            # ngram_prompt_lookup_max > 0 to tell the two modes apart.
            ngram_prompt_lookup_max = 0
            ngram_prompt_lookup_min = 0
            draft_model_config = ModelConfig(
                model=speculative_model,
                tokenizer=target_model_config.tokenizer,
                tokenizer_mode=target_model_config.tokenizer_mode,
                trust_remote_code=target_model_config.trust_remote_code,
                dtype=target_model_config.dtype,
                seed=target_model_config.seed,
                revision=draft_revision,
                code_revision=draft_code_revision,
                tokenizer_revision=target_model_config.tokenizer_revision,
                max_model_len=None,
                spec_target_max_model_len=target_model_config.max_model_len,
                quantization=draft_quantization,
                enforce_eager=target_model_config.enforce_eager,
                max_seq_len_to_capture=target_model_config.
                max_seq_len_to_capture,
                max_logprobs=target_model_config.max_logprobs,
            )

            draft_hf_config = draft_model_config.hf_config

            if (num_speculative_tokens is not None
                    and hasattr(draft_hf_config, "num_lookahead_tokens")):
                draft_hf_config.num_lookahead_tokens = num_speculative_tokens

            n_predict = getattr(draft_hf_config, "n_predict", None)
            if n_predict is not None:
                if num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    num_speculative_tokens = n_predict
                elif num_speculative_tokens > n_predict:
                    # Verify provided value doesn't exceed the maximum
                    # supported by the draft model.
                    raise ValueError(
                        "This speculative model supports a maximum of "
                        f"num_speculative_tokens={n_predict}, but "
                        f"{num_speculative_tokens=} was provided.")

            draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    speculative_max_model_len,
                    draft_model_config.max_model_len,
                    target_model_config.max_model_len,
                ))

            draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    target_parallel_config,
                    speculative_draft_tensor_parallel_size, draft_hf_config))

        if num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative_model unless the draft model config contains an "
                "n_predict parameter.")

        if typical_acceptance_sampler_posterior_threshold is None:
            typical_acceptance_sampler_posterior_threshold = 0.09
        if typical_acceptance_sampler_posterior_alpha is None:
            typical_acceptance_sampler_posterior_alpha = 0.3
        if disable_logprobs is None:
            disable_logprobs = True

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
            speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size,
            ngram_prompt_lookup_max,
            ngram_prompt_lookup_min,
            draft_token_acceptance_method=draft_token_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=\
                typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=\
                typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
        )

    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: Optional[int],
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.

        Raises:
            ValueError: If speculative_max_model_len exceeds either the draft
                or the target max model len.
        """

        if speculative_max_model_len is not None:

            if speculative_max_model_len > draft_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {draft_max_model_len=}")

            if speculative_max_model_len > target_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {target_max_model_len=}")

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: Optional[int],
        draft_hf_config: PretrainedConfig,
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config, except the tp_size.
        When no draft tp size is given it follows the target's, except for
        ``mlp_speculator`` drafts, which are forced to tp=1. An explicit draft
        tp size other than 1 is rejected.

        Raises:
            ValueError: If speculative_draft_tensor_parallel_size is given and
                is not 1.
        """
        if speculative_draft_tensor_parallel_size is None:
            if draft_hf_config.model_type == "mlp_speculator":
                speculative_draft_tensor_parallel_size = 1
                if target_parallel_config.tensor_parallel_size > 1:
                    logger.warning(
                        "MLPSpeculator cannot currently be run with tp>1; "
                        "setting speculative_draft_tensor_parallel_size=1")
            else:
                speculative_draft_tensor_parallel_size = \
                    target_parallel_config.tensor_parallel_size
        elif speculative_draft_tensor_parallel_size != 1:
            # TODO(wooyeon): allow tp values larger than 1
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
                "other value than 1")

        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.
            pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=target_parallel_config.
            distributed_executor_backend,
            max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.
            disable_custom_all_reduce,
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=target_parallel_config.
            ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ):
        """Create a SpeculativeConfig object.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.
            speculative_disable_mqa_scorer: Disable the MQA scorer for the
                speculative model and fall back to batch expansion for
                scoring.
            speculative_disable_by_batch_size: Disable speculative
                decoding for new incoming requests when the number of
                enqueue requests is larger than this value.
            ngram_prompt_lookup_max: Max size of ngram token window
                (None is stored as 0).
            ngram_prompt_lookup_min: Min size of ngram token window
                (None is stored as 0).
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (float):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance.
            typical_acceptance_sampler_posterior_alpha (float):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler.
            disable_logprobs: If set to True, token log probabilities will not
                be returned even if requested by sampling parameters. This
                reduces latency by skipping logprob calculation in proposal
                sampling, target sampling, and after accepted tokens are
                determined. If set to False, log probabilities will be
                returned.
            disable_log_stats: Whether to disable periodic printing of stage
                times in speculative decoding.

        Raises:
            ValueError: If the arguments fail validation in _verify_args.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens
        self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
        self.speculative_disable_by_batch_size = \
            speculative_disable_by_batch_size
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
        self.draft_token_acceptance_method = draft_token_acceptance_method
        self.typical_acceptance_sampler_posterior_threshold = \
            typical_acceptance_sampler_posterior_threshold
        self.typical_acceptance_sampler_posterior_alpha = \
            typical_acceptance_sampler_posterior_alpha
        self.disable_logprobs = disable_logprobs
        self.disable_log_stats = disable_log_stats

        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the stored settings.

        Raises:
            ValueError: If num_speculative_tokens is not positive, the
                acceptance method is missing or unknown, or either typical
                acceptance parameter is negative.
        """
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)

        # Validate draft token acceptance related settings.
        if self.draft_token_acceptance_method is None:
            raise ValueError("draft_token_acceptance_method is not set. "
                             "Expected values are rejection_sampler or "
                             "typical_acceptance_sampler.")

        if self.draft_token_acceptance_method not in (
                'rejection_sampler', 'typical_acceptance_sampler'):
            raise ValueError(
                "Expected draft_token_acceptance_method to be either "
                "rejection_sampler or typical_acceptance_sampler. Instead it "
                f"is {self.draft_token_acceptance_method}")

        # Zero is accepted, so the message must say ">= 0", not "> 0".
        if (self.typical_acceptance_sampler_posterior_threshold < 0
                or self.typical_acceptance_sampler_posterior_alpha < 0):
            raise ValueError(
                "Expected typical_acceptance_sampler_posterior_threshold "
                "and typical_acceptance_sampler_posterior_alpha to be >= 0. "
                "Instead found "
                "typical_acceptance_sampler_posterior_threshold = "
                f"{self.typical_acceptance_sampler_posterior_threshold} and "
                "typical_acceptance_sampler_posterior_alpha = "
                f"{self.typical_acceptance_sampler_posterior_alpha}")

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        # A positive ngram window means this config was built for "[ngram]"
        # lookup, whose draft_model_config is just the target's config.
        if self.ngram_prompt_lookup_max > 0:
            draft_model = "[ngram]"
        else:
            draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"


1528
1529
1530
1531
@dataclass
class LoRAConfig:
    """Configuration for serving LoRA adapters.

    Field values are validated in ``__post_init__``.
    """
    max_lora_rank: int
    max_loras: int
    fully_sharded_loras: bool = False
    # Defaults to ``max_loras`` when None; must not be smaller than it.
    max_cpu_loras: Optional[int] = None
    # None / "auto" resolve to the model dtype and any other string to the
    # torch attribute of that name (see verify_with_model_config).
    lora_dtype: Optional[Union[torch.dtype, str]] = None
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256
    # Variable-length tuple; ``Tuple[float]`` would denote exactly one float.
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None

    def __post_init__(self):
        """Validate rank, extra vocab size and adapter counts.

        Raises:
            ValueError: If a field is outside its supported set/range.
        """
        # Setting the maximum rank to 256 should be able to satisfy the vast
        # majority of applications.
        possible_max_ranks = (8, 16, 32, 64, 128, 256)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``lora_dtype`` against the model config.

        Also logs a warning when the model uses a quantization method other
        than awq/gptq, which has not been tested with LoRA.
        """
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)
        if model_config.quantization and model_config.quantization not in [
                "awq", "gptq"
        ]:
            # TODO support marlin
            logger.warning("%s quantization is not tested with LoRA yet.",
                           model_config.quantization)

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        """Reject scheduler settings that LoRA does not support yet.

        Raises:
            ValueError: If chunked prefill is enabled.
        """
        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
        # If the feature combo become valid
        if scheduler_config.chunked_prefill_enabled:
            raise ValueError("LoRA is not supported with chunked prefill yet.")
1579
1580


1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
@dataclass
class PromptAdapterConfig:
    """Configuration for serving prompt adapters.

    Field values are validated in ``__post_init__``.
    """
    max_prompt_adapters: int
    # 0 means "not set" and is rejected, as are negative values.
    max_prompt_adapter_token: int
    # Defaults to ``max_prompt_adapters`` when None.
    max_cpu_prompt_adapters: Optional[int] = None
    # None / "auto" resolve to the model dtype and any other string to the
    # torch attribute of that name (see verify_with_model_config), so a str
    # is an accepted input just like for LoRAConfig.lora_dtype.
    prompt_adapter_dtype: Optional[Union[torch.dtype, str]] = None

    def __post_init__(self):
        """Validate adapter counts and fill in the CPU adapter default.

        Raises:
            ValueError: If max_prompt_adapters < 1 or max_prompt_adapter_token
                is not positive.
        """
        if self.max_prompt_adapters < 1:
            raise ValueError(f"max_prompt_adapters "
                             f"({self.max_prompt_adapters}) must be >= 1.")
        if self.max_prompt_adapter_token == 0:
            raise ValueError("max_prompt_adapter_token must be set.")
        if self.max_prompt_adapter_token < 0:
            # Previously slipped through the "== 0" check above.
            raise ValueError(
                f"max_prompt_adapter_token ({self.max_prompt_adapter_token}) "
                "must be > 0.")
        if self.max_cpu_prompt_adapters is None:
            self.max_cpu_prompt_adapters = self.max_prompt_adapters

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``prompt_adapter_dtype`` against the model config."""
        if self.prompt_adapter_dtype in (None, "auto"):
            self.prompt_adapter_dtype = model_config.dtype
        elif isinstance(self.prompt_adapter_dtype, str):
            self.prompt_adapter_dtype = getattr(torch,
                                                self.prompt_adapter_dtype)


1606
@dataclass
class MultiModalConfig:
    """Controls the behavior of multimodal models."""

    # The maximum number of multi-modal input instances allowed per prompt
    # for each :class:`~vllm.multimodal.MultiModalPlugin`. A fresh empty
    # dict is created per instance when no mapping is given.
    limit_per_prompt: Mapping[str, int] = field(default_factory=dict)

    # TODO: Add configs to init vision tower or not.
1617

1618

1619
1620
1621
1622
1623
1624
1625
1626
# Lower-case dtype names accepted from users, mapped to torch dtypes.
_STR_DTYPE_TO_TORCH_DTYPE = dict(
    half=torch.float16,
    float16=torch.float16,
    float=torch.float32,
    float32=torch.float32,
    bfloat16=torch.bfloat16,
)

# NOTE(review): presumably dtypes unsupported on ROCm; currently empty.
_ROCM_NOT_SUPPORTED_DTYPE: List[str] = []
1628

1629
1630
1631

def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested ``dtype`` against the HF config's dtype.

    Args:
        config: HF model config. Its ``torch_dtype`` (float32 when absent or
            None) is the reference dtype; ``model_type`` picks the ``"auto"``
            downcast target for float32 checkpoints.
        dtype: ``"auto"``, a case-insensitive key of
            ``_STR_DTYPE_TO_TORCH_DTYPE``, or a ``torch.dtype``.

    Returns:
        The torch dtype the model should run in. Any difference from the
        config dtype is logged (info for up/downcasts involving float32,
        warning otherwise).

    Raises:
        ValueError: If ``dtype`` is an unknown string, or neither a string
            nor a ``torch.dtype``.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    config_dtype = torch.float32 if config_dtype is None else config_dtype

    if isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    elif isinstance(dtype, str):
        requested = dtype.lower()
        if requested != "auto":
            if requested not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {requested}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[requested]
        elif config_dtype != torch.float32:
            # "auto" keeps a non-float32 checkpoint dtype as-is.
            torch_dtype = config_dtype
        elif config.model_type == "gemma2":
            logger.info(
                "For Gemma 2, we downcast float32 to bfloat16 instead "
                "of float16 by default. Please specify `dtype` if you "
                "want to use float16.")
            torch_dtype = torch.bfloat16
        else:
            # Following the common practice, we use float16 for float32
            # models.
            torch_dtype = torch.float16
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    if torch_dtype == config_dtype:
        return torch_dtype

    if torch_dtype == torch.float32:
        # Upcasting to float32 is allowed.
        logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
    elif config_dtype == torch.float32:
        # Downcasting from float32 to float16 or bfloat16 is allowed.
        logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
    else:
        # Casting between float16 and bfloat16 is allowed with a warning.
        logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
    return torch_dtype
1680
1681
1682
1683
1684


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
    disable_sliding_window: bool,
    sliding_window_len: Optional[Union[int, List[Optional[int]]]],
    spec_target_max_model_len: Optional[int] = None,
) -> int:
    """Get and verify the model's maximum length.

    The limit is the smallest length-like key found in ``hf_config`` (and
    the sliding window size when sliding window is disabled), adjusted for
    RoPE scaling, then reconciled with the user-specified ``max_model_len``.

    Args:
        hf_config: HF model config to read length keys and ``rope_scaling``
            from.
        max_model_len: User-requested maximum length, or None to derive it.
        disable_sliding_window: Whether sliding window is disabled; if so the
            window size caps the derived length.
        sliding_window_len: Sliding window size(s) from the model config.
        spec_target_max_model_len: For a speculative draft model, the target
            model's max length; used when the config has no length keys.

    Returns:
        The maximum model length.

    Raises:
        NotImplementedError: If sliding window is disabled together with
            rope_scaling, or with a model_max_length-based override.
        ValueError: If ``max_model_len`` exceeds the derived limit and the
            VLLM_ALLOW_LONG_MAX_MODEL_LEN env var is not set.
    """
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Command-R
        "model_max_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    # Choose the smallest "max_length" from the possible keys, remembering
    # which key it came from for the error message below.
    max_len_key = None
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None and max_len < derived_max_model_len:
            max_len_key = key
            derived_max_model_len = max_len

    # If sliding window is manually disabled, max_length should be less
    # than the sliding window length in the model config.
    if disable_sliding_window and sliding_window_len is not None:
        sliding_window_len_min = get_min_sliding_window(sliding_window_len)
        if sliding_window_len_min < derived_max_model_len:
            max_len_key = "sliding_window"
            derived_max_model_len = sliding_window_len_min

    # If none of the keys were found in the config, prefer the user's value,
    # then the speculative target's, and only then a logged default.
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        if spec_target_max_model_len is not None:
            # If this is a speculative draft model, we use the max model len
            # from the target model.
            return spec_target_max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            "%s. Assuming the model's maximum length is %d.", possible_keys,
            default_max_len)
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        # No need to consider "type" key because of patch_rope_scaling when
        # loading HF config
        rope_type = rope_scaling["rope_type"]

        if rope_type not in ("su", "longrope", "llama3"):
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that supports rope_scaling
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "with rope_scaling. Please raise an issue so we can "
                    "investigate.")

            # NOTE: rope_type == "default" does not define factor
            # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
            scaling_factor = rope_scaling.get("factor", 1.0)

            # YaRN scales from the original (pre-extension) length.
            if rope_type == "yarn":
                derived_max_model_len = rope_scaling[
                    "original_max_position_embeddings"]
            derived_max_model_len *= scaling_factor

    # If the user specified a max length, make sure it is smaller than the
    # derived length from the HF model config.
    if max_model_len is None:
        return int(derived_max_model_len)

    if max_model_len > derived_max_model_len:
        # Some models might have a separate key for specifying model_max_length
        # that will be bigger than derived_max_model_len. We compare user input
        # with model_max_length and allow this override when it's smaller.
        model_max_length = getattr(hf_config, "model_max_length", None)
        if model_max_length is not None and max_model_len <= model_max_length:
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that has model_max_length
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "with model_max_length in the config. Please raise an "
                    "issue so we can investigate.")
        else:
            msg = (
                f"User-specified max_model_len ({max_model_len}) is greater "
                f"than the derived max_model_len ({max_len_key}="
                f"{derived_max_model_len} or model_max_length="
                f"{model_max_length} in model's config.json). This may lead "
                "to incorrect model outputs or CUDA errors.")
            if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
                logger.warning(
                    "%s Make sure the value is correct and within the "
                    "model context size.", msg)
            else:
                raise ValueError(
                    f"{msg} To allow overriding this maximum, set "
                    "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
    return int(max_model_len)


def get_min_sliding_window(
        sliding_window: Union[int, List[Optional[int]]]) -> int:
    """Reduce a sliding-window setting to a single window size.

    Args:
        sliding_window: Either a scalar window size, or a list of window
            sizes in which ``None`` entries are ignored.

    Returns:
        The scalar unchanged, or the minimum of the non-``None`` list
        entries.

    Raises:
        ValueError: (from ``min``) if the list has no non-``None`` entries.
    """
    if not isinstance(sliding_window, list):
        return sliding_window

    present_sizes = [size for size in sliding_window if size is not None]
    return min(present_sizes)


def get_served_model_name(model: str,
                          served_model_name: Optional[Union[str, List[str]]]):
    """Resolve the name under which the model is exposed.

    Args:
        model: Name or path of the model; used as the fallback.
        served_model_name: A single name, a list of names (only the first
            entry is used), or ``None``.

    Returns:
        The first element when ``served_model_name`` is a non-empty list,
        ``served_model_name`` itself when it is a non-empty string, and
        ``model`` when it is ``None``, an empty string or an empty list.
    """
    if served_model_name:
        if isinstance(served_model_name, list):
            return served_model_name[0]
        return served_model_name
    # Falsy input (None, "" or []) falls back to the model identifier.
    return model


@dataclass
class DecodingConfig:
    """Dataclass which contains the decoding strategy of the engine"""

    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
    guided_decoding_backend: str = 'outlines'

    def __post_init__(self):
        """Validate ``guided_decoding_backend``.

        Raises:
            ValueError: If the backend is not one of the supported names.
        """
        valid_guided_backends = ['outlines', 'lm-format-enforcer']
        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
            # Close the quote around the backend name and separate the two
            # sentence fragments; previously this rendered as
            # "'bogus,must be one of [...]".
            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")


@dataclass
class ObservabilityConfig:
    """Configuration for observability."""
    # OTLP traces endpoint; ``None`` (the default) means none is configured.
    otlp_traces_endpoint: Optional[str] = None

    # Collecting detailed timing information for each request can be expensive.

    # If set, collects the model forward time for the request.
    collect_model_forward_time: bool = False

    # If set, collects the model execute time for the request.
    collect_model_execute_time: bool = False

    def __post_init__(self):
        """Validate the tracing-related options.

        Raises:
            ValueError: If ``otlp_traces_endpoint`` is set while
                OpenTelemetry is not available, or if either
                ``collect_model_*_time`` flag is set without an
                ``otlp_traces_endpoint``.
        """
        if not is_otel_available() and self.otlp_traces_endpoint is not None:
            raise ValueError(
                "OpenTelemetry is not available. Unable to configure "
                "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
                f"installed. Original error:\n{otel_import_error_traceback}")

        # The detailed timing flags are only meaningful when traces are
        # exported, so they require an endpoint.
        if ((self.collect_model_forward_time
             or self.collect_model_execute_time)
                and self.otlp_traces_endpoint is None):
            raise ValueError(
                "collect_model_forward_time or collect_model_execute_time "
                "requires --otlp-traces-endpoint to be set.")


@dataclass(frozen=True)
class EngineConfig:
    """Dataclass which contains all engine-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    load_config: LoadConfig
    lora_config: Optional[LoRAConfig]
    speculative_config: Optional[SpeculativeConfig]
    decoding_config: Optional[DecodingConfig]
    observability_config: Optional[ObservabilityConfig]
    prompt_adapter_config: Optional[PromptAdapterConfig]

    def __post_init__(self):
        """Verify configs are valid & consistent with each other.

        Cross-checks the model config against the parallel/speculative/
        device configs, and the model and cache configs against the
        parallel config. LoRA and prompt-adapter configs are verified only
        when they are set (truthy).
        """
        self.model_config.verify_async_output_proc(self.parallel_config,
                                                   self.speculative_config,
                                                   self.device_config)
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)

    def to_dict(self):
        """Return the configs as a dictionary, for use in **kwargs.

        The mapping is shallow: each value is the config object itself
        (unlike ``dataclasses.asdict``, which would recursively convert
        nested dataclass configs into plain dicts).
        """
        # Loop variable is not named ``field`` to avoid shadowing the
        # module-level ``dataclasses.field`` import.
        return {f.name: getattr(self, f.name) for f in fields(self)}