config.py 83.5 KB
Newer Older
1
import enum
2
import json
3
from dataclasses import dataclass, field, fields
4
5
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
                    Optional, Tuple, Type, Union)
6
7

import torch
8
from transformers import PretrainedConfig
9

10
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
13
from vllm.model_executor.models import ModelRegistry
14
from vllm.platforms import current_platform
15
from vllm.tracing import is_otel_available, otel_import_error_traceback
16
from vllm.transformers_utils.config import (ConfigFormat, get_config,
17
18
                                            get_hf_image_processor_config,
                                            get_hf_text_config)
19
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
20
                        is_hip, is_neuron, is_openvino, is_xpu,
21
                        print_warning_once)
22

23
24
25
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

26
    from vllm.executor.executor_base import ExecutorBase
27
    from vllm.model_executor.model_loader.loader import BaseModelLoader
28
29
    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
        BaseTokenizerGroup)
30

31
32
logger = init_logger(__name__)

33
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
34
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096
35

36
37
_PP_SUPPORTED_MODELS = [
    "AquilaForCausalLM",
38
    "AquilaModel",
39
    "DeepseekV2ForCausalLM",
40
41
    "GPT2LMHeadModel",
    "InternLM2ForCausalLM",
42
    "InternLMForCausalLM",
43
    "InternVLChatModel",
44
    "JAISLMHeadModel",
45
46
47
    "LlamaForCausalLM",
    "LLaMAForCausalLM",
    "MistralForCausalLM",
48
    "MixtralForCausalLM",
49
    "NemotronForCausalLM",
50
    "Phi3ForCausalLM",
51
52
    "Qwen2ForCausalLM",
    "Qwen2MoeForCausalLM",
53
    "QWenLMHeadModel",
54
    "Qwen2VLForConditionalGeneration",
55
56
]

57
58

class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
            It is also used as the content for `model_name` tag in metrics
            output when `served_model_name` is not specified.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, "slow" will always use the slow tokenizer, and
            "mistral" will always use the tokenizer from `mistral_common`.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        rope_scaling: Dictionary containing the scaling configuration for the
            RoPE embeddings. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
            be used to load activation and weight scaling factors when the
            model dtype is FP8_E4M3 on ROCm.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If None, the user did not specify, so default to False.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_sliding_window: Whether to disable sliding window. If True,
            we will disable the sliding window functionality of the model.
            If the model does not support sliding window, this argument is
            ignored.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer.
        served_model_name: The model name used in metrics tag `model_name`,
            matches the model name exposed via the APIs. If multiple model
            names are provided, the first name will be used. If not specified,
            the model name will be the same as `model`.
        limit_mm_per_prompt: Maximum number of data instances per modality
            per prompt. Only applicable for multimodal models.
        override_neuron_config: Initialize a non-default neuron config or
            override the default neuron config that is specific to Neuron
            devices. This argument is used to configure neuron settings that
            cannot be gathered from the vllm arguments.
        config_format: The config format which shall be loaded.
            Defaults to 'auto' which defaults to 'hf'.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor.
    """
129

130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
    def __init__(self,
                 model: str,
                 tokenizer: str,
                 tokenizer_mode: str,
                 trust_remote_code: bool,
                 dtype: Union[str, torch.dtype],
                 seed: int,
                 revision: Optional[str] = None,
                 code_revision: Optional[str] = None,
                 rope_scaling: Optional[dict] = None,
                 rope_theta: Optional[float] = None,
                 tokenizer_revision: Optional[str] = None,
                 max_model_len: Optional[int] = None,
                 spec_target_max_model_len: Optional[int] = None,
                 quantization: Optional[str] = None,
                 quantization_param_path: Optional[str] = None,
                 enforce_eager: Optional[bool] = None,
                 max_context_len_to_capture: Optional[int] = None,
                 max_seq_len_to_capture: Optional[int] = None,
                 max_logprobs: int = 20,
                 disable_sliding_window: bool = False,
                 skip_tokenizer_init: bool = False,
                 served_model_name: Optional[Union[str, List[str]]] = None,
                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
                 use_async_output_proc: bool = True,
                 override_neuron_config: Optional[Dict[str, Any]] = None,
                 config_format: ConfigFormat = ConfigFormat.AUTO,
                 mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
        """Store the arguments, load the HF config and run the verifiers.

        Raises:
            ValueError: If `max_context_len_to_capture` is given (it is
                rejected in favour of `max_seq_len_to_capture`), or if
                `limit_mm_per_prompt` is set for a non-multimodal model.
                The `_verify_*` helpers called at the end can raise too.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        # The tokenizer version is consistent with the model version by default.
        self.tokenizer_revision = (revision if tokenizer_revision is None else
                                   tokenizer_revision)
        self.quantization = quantization
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        if max_context_len_to_capture is not None:
            raise ValueError("`max_context_len_to_capture` is deprecated. "
                             "Use `max_seq_len_to_capture` instead.")
        self.max_seq_len_to_capture = max_seq_len_to_capture
        self.max_logprobs = max_logprobs
        self.disable_sliding_window = disable_sliding_window
        self.skip_tokenizer_init = skip_tokenizer_init

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision, rope_scaling, rope_theta,
                                    config_format)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.hf_image_processor_config = get_hf_image_processor_config(
            self.model, revision)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.use_async_output_proc = use_async_output_proc
        self.mm_processor_kwargs = mm_processor_kwargs

        # Set enforce_eager to False if the value is unset.
        if self.enforce_eager is None:
            self.enforce_eager = False

        # This must run before max_model_len is derived below, because the
        # derivation reads self.disable_sliding_window.
        if (not self.disable_sliding_window
                and self.hf_text_config.model_type == "gemma2"
                and self.hf_text_config.sliding_window is not None):
            print_warning_once(
                "Gemma 2 uses sliding window attention for every odd layer, "
                "which is currently not supported by vLLM. Disabling sliding "
                "window and capping the max length to the sliding window size "
                f"({self.hf_text_config.sliding_window}).")
            self.disable_sliding_window = True

        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.hf_text_config,
            max_model_len=max_model_len,
            disable_sliding_window=self.disable_sliding_window,
            sliding_window_len=self.get_hf_config_sliding_window(),
            spec_target_max_model_len=spec_target_max_model_len)
        self.served_model_name = get_served_model_name(model,
                                                       served_model_name)
        self.multimodal_config = self._init_multimodal_config(
            limit_mm_per_prompt)
        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()

        # The neuron override is only kept when running on Neuron.
        self.override_neuron_config = (override_neuron_config
                                       if is_neuron() else None)
        self._verify_embedding_mode()
        self._verify_quantization()
        self._verify_cuda_graph()
        self._verify_bnb_config()
226

227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
    def _init_multimodal_config(
        self, limit_mm_per_prompt: Optional[Mapping[str, int]]
    ) -> Optional["MultiModalConfig"]:
        """Build the multimodal config if any architecture is multimodal.

        Args:
            limit_mm_per_prompt: Per-modality instance limits, or None.

        Returns:
            A `MultiModalConfig` when `ModelRegistry` reports at least one of
            the config's architectures as multimodal, otherwise None.

        Raises:
            ValueError: If limits are given for a non-multimodal model.
        """
        # The attribute can exist with the value None; treat that like an
        # empty list instead of failing on iteration.
        architectures = getattr(self.hf_config, "architectures", None) or []
        is_multimodal = any(
            ModelRegistry.is_multimodal_model(arch) for arch in architectures)
        if is_multimodal:
            return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})

        if limit_mm_per_prompt:
            raise ValueError(
                "limit_mm_per_prompt is only supported for multimodal "
                "models.")
        return None

242
243
    def _verify_tokenizer_mode(self) -> None:
        tokenizer_mode = self.tokenizer_mode.lower()
244
        if tokenizer_mode not in ["auto", "slow", "mistral"]:
245
246
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
247
                "either 'auto', 'slow' or 'mistral'.")
248
        self.tokenizer_mode = tokenizer_mode
249

250
251
252
253
254
    def _verify_embedding_mode(self) -> None:
        """Set `self.embedding_mode` from the config's architectures.

        The flag is True when `ModelRegistry` reports at least one listed
        architecture as an embedding model.
        """
        # The attribute can exist with the value None; treat that like an
        # empty list instead of failing on iteration.
        architectures = getattr(self.hf_config, "architectures", None) or []
        self.embedding_mode = any(
            ModelRegistry.is_embedding_model(arch) for arch in architectures)

255
256
257
    def _parse_quant_hf_config(self):
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
258
            # compressed-tensors uses a "compression_config" key
259
            quant_cfg = getattr(self.hf_config, "compression_config", None)
260
261
        return quant_cfg

262
    def _verify_quantization(self) -> None:
        """Reconcile and validate `self.quantization`.

        The user-supplied method is lower-cased, compared with the method
        found in the HF config (if any), and then checked against the
        per-platform allow-lists defined below.

        Raises:
            ValueError: If the argument disagrees with the model config, if
                the method is unknown, or if it is not supported on the
                current ROCm / TPU / Neuron platform.
        """
        # Kept as a list: it is interpolated into an error message below.
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
            "fbgemm_fp8"
        ]
        optimized_quantization_methods = [
            "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
            "awq_marlin", "fbgemm_fp8", "compressed_tensors",
            "compressed-tensors", "experts_int8"
        ]
        tpu_supported_quantization = ["tpu_int8"]
        neuron_supported_quantization = ["neuron_quant"]

        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()

            # Detect which checkpoint it is: the first registered method that
            # returns an override wins.
            for method in QUANTIZATION_METHODS.values():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization)
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is None:
            # Nothing left to validate for unquantized models.
            return

        if self.quantization not in supported_quantization:
            raise ValueError(
                f"Unknown quantization method: {self.quantization}. Must "
                f"be one of {supported_quantization}.")
        if is_hip(
        ) and self.quantization not in rocm_supported_quantization:
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in ROCm.")
        if current_platform.is_tpu(
        ) and self.quantization not in tpu_supported_quantization:
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in TPU Backend.")
        if self.quantization not in optimized_quantization_methods:
            logger.warning(
                "%s quantization is not fully "
                "optimized yet. The speed can be slower than "
                "non-quantized models.", self.quantization)
        if (self.quantization == "awq" and is_hip()
                and not envs.VLLM_USE_TRITON_AWQ):
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ.")
            # Mutates the shared envs module, not just this instance.
            envs.VLLM_USE_TRITON_AWQ = True
        if is_neuron(
        ) and self.quantization not in neuron_supported_quantization:
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in Neuron Backend.")
334

335
    def _verify_cuda_graph(self) -> None:
336
337
338
339
        if self.max_seq_len_to_capture is None:
            self.max_seq_len_to_capture = self.max_model_len
        self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                          self.max_model_len)
340

341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
    def _verify_bnb_config(self) -> None:
        """
        The current version of bitsandbytes (0.44.0) with 8-bit models does not 
        yet support CUDA graph.
        """
        is_bitsandbytes = self.quantization == "bitsandbytes"
        has_quantization_config = (getattr(self.hf_config,
                                           "quantization_config", None)
                                   is not None)
        is_8bit = (self.hf_config.quantization_config.get(
            "load_in_8bit", False) if has_quantization_config else False)
        if all([
                is_bitsandbytes,
                has_quantization_config,
                is_8bit,
                not self.enforce_eager,
        ]):
            logger.warning(
                "CUDA graph is not supported on BitAndBytes 8bit yet, "
                "fallback to the eager mode.")
            self.enforce_eager = True

363
364
365
366
367
368
369
370
371
372
373
374
    def verify_async_output_proc(self, parallel_config, speculative_config,
                                 device_config) -> None:
        if not self.use_async_output_proc:
            # Nothing to check
            return

        if parallel_config.pipeline_parallel_size > 1:
            logger.warning("Async output processing can not be enabled "
                           "with pipeline parallel")
            self.use_async_output_proc = False
            return

375
        if device_config.device_type not in ("cuda", "tpu"):
376
            logger.warning(
377
378
                "Async output processing is only supported for CUDA or TPU. "
                "Disabling it for other platforms.")
379
380
381
382
383
384
385
386
387
            self.use_async_output_proc = False
            return

        if envs.VLLM_USE_RAY_SPMD_WORKER:
            logger.warning(
                "Async output processing can not be enabled with ray spmd")
            self.use_async_output_proc = False
            return

388
        if device_config.device_type == "cuda" and self.enforce_eager:
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            self.use_async_output_proc = not self.enforce_eager
            return

        # Async postprocessor is not necessary with embedding mode
        # since there is no token generation
        if self.embedding_mode:
            self.use_async_output_proc = False

        if speculative_config:
            logger.warning("Async output processing is not supported with"
                           " speculative decoding currently.")
            self.use_async_output_proc = False

406
407
408
409
    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
410
411
        total_num_attention_heads = getattr(self.hf_text_config,
                                            "num_attention_heads", 0)
412
413
414
415
416
417
418
419
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        pipeline_parallel_size = parallel_config.pipeline_parallel_size
420
421
422
423
424
425
426
        architectures = getattr(self.hf_config, "architectures", [])
        if not all(arch in _PP_SUPPORTED_MODELS
                   for arch in architectures) and pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported for the following "
                f" architectures: {_PP_SUPPORTED_MODELS}.")

427
428
429
430
431
        if pipeline_parallel_size > 1 and self.use_async_output_proc:
            logger.warning("Async output processor is not supported with "
                           "pipeline parallelism currently. Disabling it.")
            self.use_async_output_proc = False

432
    def get_hf_config_sliding_window(self) -> Optional[int]:
Woosuk Kwon's avatar
Woosuk Kwon committed
433
        """Get the sliding window size, or None if disabled."""
434
435
436
437

        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
        # addition to sliding window size. We check if that field is present
        # and if it's False, return None.
438
439
        if (hasattr(self.hf_text_config, "use_sliding_window")
                and not self.hf_text_config.use_sliding_window):
440
            return None
441
        return getattr(self.hf_text_config, "sliding_window", None)
442

443
444
445
446
447
448
449
450
451
    def get_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled.
        """
        # If user disables sliding window, return None.
        if self.disable_sliding_window:
            return None
        # Otherwise get the value from the hf config.
        return self.get_hf_config_sliding_window()

452
    def get_vocab_size(self) -> int:
453
        return self.hf_text_config.vocab_size
454

455
    def get_hidden_size(self) -> int:
456
        return self.hf_text_config.hidden_size
457
458

    def get_head_size(self) -> int:
wangding zeng's avatar
wangding zeng committed
459
460
461
462
463
464
        # TODO remove hard code
        if hasattr(self.hf_text_config, "model_type"
                   ) and self.hf_text_config.model_type == 'deepseek_v2':
            # FlashAttention supports only head_size 32, 64, 128, 256,
            # we need to pad head_size 192 to 256
            return 256
465
466
        if hasattr(self.hf_text_config, "head_dim"):
            return self.hf_text_config.head_dim
467
        # FIXME(woosuk): This may not be true for all models.
468
469
        return (self.hf_text_config.hidden_size //
                self.hf_text_config.num_attention_heads)
470

471
472
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
Zhuohan Li's avatar
Zhuohan Li committed
473
        # For GPTBigCode & Falcon:
474
        # NOTE: for falcon, when new_decoder_architecture is True, the
Zhuohan Li's avatar
Zhuohan Li committed
475
476
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
477
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
478
        new_decoder_arch_falcon = (
479
            self.hf_config.model_type in falcon_model_types
480
            and getattr(self.hf_config, "new_decoder_architecture", False))
481
        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
482
                                                   "multi_query", False):
Zhuohan Li's avatar
Zhuohan Li committed
483
            # Multi-query attention, only one KV head.
Woosuk Kwon's avatar
Woosuk Kwon committed
484
            # Currently, tensor parallelism is not supported in this case.
Zhuohan Li's avatar
Zhuohan Li committed
485
            return 1
486

487
        # For DBRX and MPT
488
489
490
491
492
        if self.hf_config.model_type == "mpt":
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type == "dbrx":
493
494
495
            return getattr(self.hf_config.attn_config, "kv_n_heads",
                           self.hf_config.num_attention_heads)

496
497
498
499
500
501
502
503
504
505
        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
506
            num_kv_heads = getattr(self.hf_text_config, attr, None)
507
508
509
510
511
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
512
        return self.hf_text_config.num_attention_heads
513
514
515
516
517
518
519
520
521
522

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)
523

524
525
    def get_num_attention_heads(self,
                                parallel_config: "ParallelConfig") -> int:
526
527
        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
        return num_heads // parallel_config.tensor_parallel_size
528

529
    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of hidden layers in this pipeline stage.

        The stage's layer range comes from `get_pp_indices`; a text config
        without `num_hidden_layers` counts as zero layers.
        """
        # Imported at call time, as in the original code.
        from vllm.distributed.utils import get_pp_indices

        total_num_hidden_layers = getattr(self.hf_text_config,
                                          "num_hidden_layers", 0)
        pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
        pp_size = parallel_config.pipeline_parallel_size
        start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
        return end - start
537

Mor Zusman's avatar
Mor Zusman committed
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
    def contains_seqlen_agnostic_layers(
            self, parallel_config: "ParallelConfig") -> bool:
        """True for Mamba/SSM models (Jamba)"""
        non_attention_layers = self._get_num_seqlen_agnostic_layers(
            parallel_config)
        return non_attention_layers > 0

    def get_layers_block_type(self,
                              parallel_config: "ParallelConfig") -> List[str]:
        """Return the block type of each layer.

        Uses the HF config's `layers_block_type` when it exists; otherwise
        every layer from `get_num_layers` is reported as "attention".
        """
        # The fallback is always computed, even when the HF config supplies
        # `layers_block_type` (Transformers exposes it as a @property).
        all_attention = ["attention"] * self.get_num_layers(parallel_config)
        return getattr(self.hf_config, "layers_block_type", all_attention)

    def get_num_attention_layers(self,
                                 parallel_config: "ParallelConfig") -> int:
        """Count the layers whose block type is "attention"."""
        block_types = self.get_layers_block_type(parallel_config)
        return sum(1 for block_type in block_types
                   if block_type == "attention")

    def _get_num_seqlen_agnostic_layers(
            self, parallel_config: "ParallelConfig") -> int:
        """Count the layers whose block type is not "attention"."""
        block_types = self.get_layers_block_type(parallel_config)
        return sum(1 for block_type in block_types
                   if block_type != "attention")

564
565
566
567
568
569
570
571
572
573
574
575
    def get_multimodal_config(self) -> "MultiModalConfig":
        """
        Get the multimodal configuration of the model.

        Raises:
            ValueError: If the model is not multimodal.
        """
        if self.multimodal_config is None:
            raise ValueError("The model is not multimodal.")

        return self.multimodal_config

576
577
578
    @property
    def is_encoder_decoder_model(self) -> bool:
        """Extract the HF encoder/decoder model flag."""
579
580
581
        return getattr(self.hf_config, "is_encoder_decoder", False) or (
            (hasattr(self.hf_config, "text_config") and getattr(
                self.hf_config.text_config, "is_encoder_decoder", False)))
582
583
584
585
586
587

    @property
    def is_embedding_model(self) -> bool:
        """Extract the embedding model flag."""
        return self.embedding_mode

588
589
590
591
    @property
    def is_multimodal_model(self) -> bool:
        return self.multimodal_config is not None

592
593

class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for kv cache storage.
        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
            profiled num_gpu_blocks if specified. Does nothing if None.
        sliding_window: Sliding window size, or None. Prefix caching is
            rejected when this is set.
        enable_prefix_caching: Whether prefix caching is enabled.
        cpu_offload_gb: Stored as-is on the config (presumably the amount of
            CPU offload in GiB; confirm against the consumer).
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: float,
        cache_dtype: str,
        num_gpu_blocks_override: Optional[int] = None,
        sliding_window: Optional[int] = None,
        enable_prefix_caching: bool = False,
        cpu_offload_gb: float = 0,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        # The swap space is given in GiB but tracked in bytes.
        self.swap_space_bytes = swap_space * GiB_bytes
        self.num_gpu_blocks_override = num_gpu_blocks_override
        self.cache_dtype = cache_dtype
        self.sliding_window = sliding_window
        self.enable_prefix_caching = enable_prefix_caching
        self.cpu_offload_gb = cpu_offload_gb
        self._verify_args()
        self._verify_cache_dtype()
        self._verify_prefix_caching()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def metrics_info(self):
        """Return every attribute of this config as a ``str -> str`` dict."""
        # convert cache_config to dict(key: str, value: str) for prometheus
        # metrics info
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self) -> None:
        """Reject a GPU memory utilization above 1.0.

        Raises:
            ValueError: If ``gpu_memory_utilization`` is greater than 1.0.
        """
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        """Check that ``cache_dtype`` is "auto" or one of the fp8 variants.

        Raises:
            ValueError: If ``cache_dtype`` is not a known kv cache dtype.
        """
        if self.cache_dtype == "auto":
            pass
        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "Meanwhile, it may cause accuracy drop without a proper "
                "scaling factor")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    def _verify_prefix_caching(self) -> None:
        """Reject prefix caching combined with a sliding window.

        Raises:
            NotImplementedError: If prefix caching is enabled while
                ``sliding_window`` is set.
        """
        if not self.enable_prefix_caching:
            return

        if self.sliding_window is not None:
            raise NotImplementedError(
                "Prefix caching is not supported with sliding window. "
                "Run with --disable-sliding-window to use prefix caching.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check the total swap space against the available CPU memory.

        Raises:
            ValueError: If the swap space of all GPUs in the tensor parallel
                group exceeds 70% of the total CPU memory. Above 40% only a
                warning is logged.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
               f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
               "is allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. %s", msg)
682

683

684
685
686
@dataclass
class TokenizerPoolConfig:
    """Configuration for the tokenizer pool.

    Args:
        pool_size: Number of tokenizer workers in the pool.
        pool_type: Type of the pool.
        extra_config: Additional config for the pool.
            The way the config will be used depends on the
            pool type.
    """
    pool_size: int
    pool_type: Union[str, Type["BaseTokenizerGroup"]]
    extra_config: dict

    def __post_init__(self):
        """Validate the pool type and the extra config.

        Raises:
            ValueError: If ``pool_type`` is neither "ray" nor a class, or if
                ``extra_config`` is not a dictionary.
        """
        if self.pool_type not in ("ray", ) and not isinstance(
                self.pool_type, type):
            raise ValueError(f"Unknown pool type: {self.pool_type}")
        if not isinstance(self.extra_config, dict):
            raise ValueError("extra_config must be a dictionary.")

    @classmethod
    def create_config(
        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Create a TokenizerPoolConfig from the given parameters.

        If tokenizer_pool_size is 0, return None.

        Args:
            tokenizer_pool_size: Number of tokenizer workers in the pool.
            tokenizer_pool_type: Type of the pool.
            tokenizer_pool_extra_config: Additional config for the pool.
                The way the config will be used depends on the
                pool type. This can be a JSON string (will be parsed).
        """
        if not tokenizer_pool_size:
            return None

        if isinstance(tokenizer_pool_extra_config, str):
            tokenizer_pool_extra_config_parsed = json.loads(
                tokenizer_pool_extra_config)
        else:
            # None (or an empty dict) means "no extra config".
            tokenizer_pool_extra_config_parsed = (tokenizer_pool_extra_config
                                                  or {})
        return cls(tokenizer_pool_size, tokenizer_pool_type,
                   tokenizer_pool_extra_config_parsed)


737
738
739
740
741
742
743
class LoadFormat(str, enum.Enum):
    """Names of the supported model weight load formats."""
    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"
    DUMMY = "dummy"
    TENSORIZER = "tensorizer"
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766


@dataclass
class LoadConfig:
    """Configuration for loading the model weights.

        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
            "tensorizer" will use CoreWeave's tensorizer library for
                fast weight loading.
            "bitsandbytes" will load nf4 type weights.
        model_loader_extra_config: Extra config for the model loader. A JSON
            string is parsed; None or an empty value becomes an empty dict.
        ignore_patterns: The list of patterns to ignore when loading the model.
            Default to "original/**/*" to avoid repeated loading of llama's
            checkpoints.

    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(
        default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        """Normalize the extra config, load format and ignore patterns."""
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
            model_loader_extra_config = json.loads(model_loader_extra_config)
        # Always store the normalized value so that a None/empty input does
        # not leak through to consumers expecting a dict.
        self.model_loader_extra_config = model_loader_extra_config
        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns)
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Convert a string ``load_format`` into a ``LoadFormat`` member.

        Non-string values (e.g. a custom loader) are left untouched.

        Raises:
            ValueError: If the string is not a known load format, or if the
                format is listed as unsupported on ROCm.
        """
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            # Compare by value: the unsupported list holds lower-case format
            # values, not the upper-case enum member names.
            rocm_supported_load_format = [
                f.value for f in LoadFormat
                if f.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")


813
class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Deprecated, use distributed_executor_backend instead.
        max_parallel_loading_workers: Maximum number of multiple batches
            when load model sequentially. To avoid RAM OOM when using tensor
            parallel and large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL.
        tokenizer_pool_config: Config for the tokenizer pool.
            If None, will use synchronous tokenization.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
        placement_group: ray distributed model workers placement group.
        distributed_executor_backend: Backend to use for distributed model
            workers, either "ray" or "mp" (multiprocessing), or a custom
            ExecutorBase subclass. If unset and the world size is greater
            than 1, "mp" is used unless the world size exceeds the local CUDA
            device count or a Ray placement group is in use, in which case
            "ray" is chosen. On TPU with world size > 1 it defaults to "ray".
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: Optional[bool] = None,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
        ray_workers_use_nsight: bool = False,
        placement_group: Optional["PlacementGroup"] = None,
        distributed_executor_backend: Optional[Union[
            str, Type["ExecutorBase"]]] = None,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.distributed_executor_backend = distributed_executor_backend
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.tokenizer_pool_config = tokenizer_pool_config
        self.ray_workers_use_nsight = ray_workers_use_nsight
        self.placement_group = placement_group
        self.world_size = pipeline_parallel_size * self.tensor_parallel_size

        # The deprecated flag maps onto the "ray" backend and conflicts with
        # any explicitly chosen backend that does not use Ray.
        if worker_use_ray:
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "ray"
            elif not self.use_ray:
                raise ValueError(f"worker-use-ray can't be used with "
                                 f"distributed executor backend "
                                 f"'{self.distributed_executor_backend}'.")

        if current_platform.is_tpu() and self.world_size > 1:
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "ray"
            if self.distributed_executor_backend != "ray":
                raise ValueError(
                    "TPU backend only supports Ray for distributed inference.")

        if self.distributed_executor_backend is None and self.world_size > 1:
            # We use multiprocessing by default if world_size fits on the
            # current node and we aren't in a ray placement group.

            from vllm.executor import ray_utils
            backend = "mp"
            ray_found = ray_utils.ray_is_available()
            if (current_platform.is_cuda()
                    and cuda_device_count_stateless() < self.world_size):
                if not ray_found:
                    raise ValueError("Unable to load Ray which is "
                                     "required for multi-node inference, "
                                     "please install Ray with `pip install "
                                     "ray`.") from ray_utils.ray_import_err
                backend = "ray"
            elif ray_found:
                if self.placement_group:
                    backend = "ray"
                else:
                    from ray import is_initialized as ray_is_initialized
                    if ray_is_initialized():
                        from ray.util import get_current_placement_group
                        if get_current_placement_group():
                            backend = "ray"
            self.distributed_executor_backend = backend
            logger.info("Defaulting to use %s for distributed inference",
                        backend)

        self._verify_args()
        self.rank: int = 0

    @property
    def use_ray(self) -> bool:
        """Whether the executor backend is "ray" or a class using Ray."""
        return self.distributed_executor_backend == "ray" or (
            isinstance(self.distributed_executor_backend, type)
            and self.distributed_executor_backend.uses_ray)

    def _verify_args(self) -> None:
        """Validate the backend choice and backend-dependent options.

        Raises:
            ValueError: If the backend is not "ray", "mp", None or an
                ExecutorBase subclass, or if nsight profiling is requested
                without a Ray backend.
        """
        # Lazy import to avoid circular import
        from vllm.executor.executor_base import ExecutorBase

        if self.distributed_executor_backend not in (
                "ray", "mp", None) and not (isinstance(
                    self.distributed_executor_backend, type) and issubclass(
                        self.distributed_executor_backend, ExecutorBase)):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
                "values are 'ray', 'mp' or custom ExecutorBase subclass.")
        if self.use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()
        if is_hip():
            self.disable_custom_all_reduce = True
            logger.info(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on AMD GPUs.")
        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")
934

935
936

class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration. If None, a default is derived from the other
            arguments.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
        num_lookahead_slots: The number of slots to allocate per sequence per
            step, beyond the known token ids. This is used in speculative
            decoding to store KV activations of tokens which may or may not be
            accepted.
        delay_factor: Apply a delay (of delay factor multiplied by previous
            prompt latency) before scheduling next prompt.
        enable_chunked_prefill: If True, prefill requests can be chunked based
            on the remaining max_num_batched_tokens.
        embedding_mode: Whether the running model is for embedding.
        is_multimodal_model: Whether the running model is multimodal. Only
            used to raise the default max_num_batched_tokens.
        preemption_mode: Whether to perform preemption by swapping or
            recomputation. If not specified, we determine the mode as follows:
            We use recomputation by default since it incurs lower overhead than
            swapping. However, when the sequence group has multiple sequences
            (e.g., beam search), recomputation is not currently supported. In
            such a case, we use swapping instead.
        num_scheduler_steps: Number of scheduler steps; must be at least 1.
            A value greater than 1 makes ``is_multi_step`` True.
        multi_step_stream_outputs: Stored as-is on the config (presumably
            controls output streaming in multi-step mode; confirm against
            the consumer).
        send_delta_data: Private API. If used, scheduler sends delta data to
            workers instead of an entire data. It should be enabled only
            when SPMD worker architecture is enabled. I.e.,
            VLLM_USE_RAY_SPMD_WORKER=1
        policy: The scheduling policy to use. "fcfs" (default) or "priority".
    """

    def __init__(self,
                 max_num_batched_tokens: Optional[int],
                 max_num_seqs: int,
                 max_model_len: int,
                 use_v2_block_manager: bool = True,
                 num_lookahead_slots: int = 0,
                 delay_factor: float = 0.0,
                 enable_chunked_prefill: bool = False,
                 embedding_mode: bool = False,
                 is_multimodal_model: bool = False,
                 preemption_mode: Optional[str] = None,
                 num_scheduler_steps: int = 1,
                 multi_step_stream_outputs: bool = False,
                 send_delta_data: bool = False,
                 policy: str = "fcfs") -> None:
        if max_num_batched_tokens is None:
            if enable_chunked_prefill:
                if num_scheduler_steps > 1:
                    # Multi-step Chunked-Prefill doesn't allow prompt-chunking
                    # for now. Have max_num_batched_tokens set to max_model_len
                    # so we don't reject sequences on account of a short
                    # max_num_batched_tokens.
                    max_num_batched_tokens = max(max_model_len, 2048)
                else:
                    # It is the values that have the best balance between ITL
                    # and TTFT on A100. Note it is not optimized for throughput.
                    max_num_batched_tokens = 512
            else:
                # If max_model_len is too short, use 2048 as the default value
                # for higher throughput.
                max_num_batched_tokens = max(max_model_len, 2048)

            if embedding_mode:
                # For embedding, choose specific value for higher throughput
                max_num_batched_tokens = max(
                    max_num_batched_tokens,
                    _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
                )
            if is_multimodal_model:
                # The value needs to be at least the number of multimodal tokens
                max_num_batched_tokens = max(
                    max_num_batched_tokens,
                    _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                )

        self.max_num_batched_tokens = max_num_batched_tokens

        if enable_chunked_prefill:
            logger.info(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                self.max_num_batched_tokens)

        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.use_v2_block_manager = use_v2_block_manager
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
        self.chunked_prefill_enabled = enable_chunked_prefill
        self.embedding_mode = embedding_mode
        self.preemption_mode = preemption_mode
        self.num_scheduler_steps = num_scheduler_steps
        self.multi_step_stream_outputs = multi_step_stream_outputs
        self.send_delta_data = send_delta_data
        self.policy = policy
        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the token, sequence, lookahead and step limits.

        Raises:
            ValueError: If any of the limits is inconsistent or out of range.
        """
        if (self.max_num_batched_tokens < self.max_model_len
                and not self.chunked_prefill_enabled):
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")

        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")

        if self.num_lookahead_slots < 0:
            raise ValueError(
                "num_lookahead_slots "
                f"({self.num_lookahead_slots}) must be greater than or "
                "equal to 0.")

        if self.num_scheduler_steps < 1:
            raise ValueError(
                "num_scheduler_steps "
                f"({self.num_scheduler_steps}) must be greater than or "
                "equal to 1.")

    @property
    def is_multi_step(self) -> bool:
        """Whether more than one scheduler step is configured."""
        return self.num_scheduler_steps > 1

1068

1069
class DeviceConfig:
    """Configuration of the device the model runs on.

    Args:
        device: Device type name, or "auto" to detect it from the current
            platform.
    """
    device: Optional[torch.device]

    def __init__(self, device: str = "auto") -> None:
        if device == "auto":
            # Automated device type detection
            if current_platform.is_cuda_alike():
                self.device_type = "cuda"
            elif is_neuron():
                self.device_type = "neuron"
            elif is_openvino():
                self.device_type = "openvino"
            elif current_platform.is_tpu():
                self.device_type = "tpu"
            elif current_platform.is_cpu():
                self.device_type = "cpu"
            elif is_xpu():
                self.device_type = "xpu"
            else:
                raise RuntimeError("Failed to infer device type")
        else:
            # Device type is assigned explicitly
            self.device_type = device

        # Some device types require processing inputs on CPU
        if self.device_type in ["neuron", "openvino"]:
            self.device = torch.device("cpu")
        elif self.device_type in ["tpu"]:
            # No torch device is created for TPU.
            self.device = None
        else:
            # Set device with device type
            self.device = torch.device(self.device_type)

1102

1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        speculative_model_quantization: Optional[str],
        speculative_draft_tensor_parallel_size: Optional[int],
        num_speculative_tokens: Optional[int],
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_max_model_len: Optional[int],
        enable_chunked_prefill: bool,
        use_v2_block_manager: bool,
        disable_log_stats: bool,
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: Optional[float],
        typical_acceptance_sampler_posterior_alpha: Optional[float],
        disable_logprobs: Optional[bool],
    ) -> Optional["SpeculativeConfig"]:
        """Build a SpeculativeConfig from the given options, or return None.

        When no speculative model is given, speculative decoding is disabled
        and None is returned. Otherwise the options are validated, the draft
        model and parallel configs are derived (or, for the "[ngram]"
        speculative model, borrowed from the target), defaults are filled
        in, and a SpeculativeConfig is constructed.

        Args:
            target_model_config (ModelConfig): The configuration of the
                target model.
            target_parallel_config (ParallelConfig): The parallel
                configuration for the target model.
            target_dtype (str): The data type used for the target model.
                Not read by this function.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided. The special value "[ngram]" selects
                ngram prompt lookup instead of a draft model.
            speculative_model_quantization (Optional[str]): Quantization
                method that was used to quantize the speculative model
                weights. If None, we assume the model weights are not
                quantized.
            speculative_draft_tensor_parallel_size (Optional[int]): The
                degree of the tensor parallelism for the draft model.
            num_speculative_tokens (Optional[int]): The number of
                speculative tokens, if provided. Will default to the
                ``n_predict`` value in the draft model config if present,
                otherwise is required.
            speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
                scorer for the speculative model and fall back to batch
                expansion for scoring.
            speculative_max_model_len (Optional[int]): The maximum model len
                of the speculative model. Used when testing the ability to
                skip speculation for some sequences.
            enable_chunked_prefill (bool): Whether vLLM is configured to use
                chunked prefill or not. Used for raising an error since it
                is not yet compatible with spec decode.
            use_v2_block_manager (bool): Whether vLLM is configured to use
                the v2 block manager or not. Used for raising an error since
                the v2 block manager is required with spec decode.
            disable_log_stats (bool): Forwarded unchanged to the
                SpeculativeConfig constructor.
            speculative_disable_by_batch_size (Optional[int]): Disable
                speculative decoding for new incoming requests when the
                number of enqueued requests is larger than this value, if
                provided. Must be at least 2.
            ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
                window, if provided. Required for the "[ngram]" model.
            ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                window, if provided. Defaults to 1 for the "[ngram]" model.
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (Optional[float]):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance. Defaults to
                0.09.
            typical_acceptance_sampler_posterior_alpha (Optional[float]):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler. Defaults to 0.3.
            disable_logprobs (Optional[bool]): If set to True, token log
                probabilities are not returned during speculative decoding.
                If set to False, token log probabilities are returned
                according to the log probability settings in SamplingParams.
                If not specified, it defaults to True.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig
                if a speculative model was given, else None.

        Raises:
            ValueError: If any option is inconsistent or unsupported; see
                the individual checks below.
        """
        if speculative_model is None:
            # Speculation is off; a token count on its own is a user error.
            if num_speculative_tokens is not None:
                raise ValueError("num_speculative_tokens was provided without "
                                 "speculative_model.")
            return None

        if speculative_disable_by_batch_size is not None:
            if speculative_disable_by_batch_size < 2:
                raise ValueError(
                    "Expect the batch size threshold of disabling "
                    "speculative decoding is > 1, but got "
                    f"{speculative_disable_by_batch_size=}")

        if enable_chunked_prefill:
            raise ValueError(
                "Speculative decoding and chunked prefill are "
                f"currently mutually exclusive ({enable_chunked_prefill=}).")

        if not use_v2_block_manager:
            raise ValueError(
                "Speculative decoding requires usage of the V2 "
                "block manager. Enable it with --use-v2-block-manager.")

        if speculative_model == "[ngram]":
            ngram_prompt_lookup_min = (1 if ngram_prompt_lookup_min is None
                                       else ngram_prompt_lookup_min)
            if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
                raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
            if ngram_prompt_lookup_min < 1:
                raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
            if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
                raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
                                 f"larger than {ngram_prompt_lookup_max=}")

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            draft_model_config = target_model_config
            draft_parallel_config = target_parallel_config
        else:
            # A real draft model is used, so ngram lookup is switched off.
            ngram_prompt_lookup_max = 0
            ngram_prompt_lookup_min = 0

            draft_model_config = ModelConfig(
                model=speculative_model,
                tokenizer=target_model_config.tokenizer,
                tokenizer_mode=target_model_config.tokenizer_mode,
                trust_remote_code=target_model_config.trust_remote_code,
                dtype=target_model_config.dtype,
                seed=target_model_config.seed,
                # TODO: The user should be able to specify revision/max model
                # len for the draft model. It is not currently supported.
                revision=None,
                code_revision=None,
                tokenizer_revision=target_model_config.tokenizer_revision,
                max_model_len=None,
                spec_target_max_model_len=target_model_config.max_model_len,
                quantization=speculative_model_quantization,
                enforce_eager=target_model_config.enforce_eager,
                max_seq_len_to_capture=(
                    target_model_config.max_seq_len_to_capture),
                max_logprobs=target_model_config.max_logprobs,
            )

            hf_cfg = draft_model_config.hf_config

            # Push the requested token count into draft configs that expose
            # a num_lookahead_tokens attribute.
            if (num_speculative_tokens is not None
                    and hasattr(hf_cfg, "num_lookahead_tokens")):
                hf_cfg.num_lookahead_tokens = num_speculative_tokens

            max_draft_tokens = getattr(hf_cfg, "n_predict", None)
            if max_draft_tokens is not None:
                if num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    num_speculative_tokens = max_draft_tokens
                elif num_speculative_tokens > max_draft_tokens:
                    # Verify provided value doesn't exceed the maximum
                    # supported by the draft model.
                    raise ValueError(
                        "This speculative model supports a maximum of "
                        f"num_speculative_tokens={max_draft_tokens}, but "
                        f"{num_speculative_tokens=} was provided.")

            draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    speculative_max_model_len,
                    draft_model_config.max_model_len,
                    target_model_config.max_model_len,
                ))

            draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    target_parallel_config,
                    speculative_draft_tensor_parallel_size,
                    hf_cfg,
                ))

        if num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative_model unless the draft model config contains an "
                "n_predict parameter.")

        # Fill in defaults for options the caller left unspecified.
        if typical_acceptance_sampler_posterior_threshold is None:
            typical_acceptance_sampler_posterior_threshold = 0.09
        if typical_acceptance_sampler_posterior_alpha is None:
            typical_acceptance_sampler_posterior_alpha = 0.3
        if disable_logprobs is None:
            disable_logprobs = True

        return SpeculativeConfig(
            draft_model_config=draft_model_config,
            draft_parallel_config=draft_parallel_config,
            num_speculative_tokens=num_speculative_tokens,
            speculative_disable_mqa_scorer=speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size=(
                speculative_disable_by_batch_size),
            ngram_prompt_lookup_max=ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=ngram_prompt_lookup_min,
            draft_token_acceptance_method=draft_token_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=(
                typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                typical_acceptance_sampler_posterior_alpha),
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
        )

    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: Optional[int],
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Pick the max sequence length to use for the draft model.

        Without an explicit ``speculative_max_model_len`` the result is the
        smaller of the draft and target limits, so that sequences do not
        exceed the capacity of the draft model or the target model. An
        explicit value is returned as-is after checking that it exceeds
        neither limit.

        speculative_max_model_len is mainly used for testing that sequences
        can skip speculation.

        Raises:
            ValueError: If ``speculative_max_model_len`` is larger than the
                draft or the target max model length.
        """
        if speculative_max_model_len is None:
            return min(draft_max_model_len, target_max_model_len)

        if speculative_max_model_len > draft_max_model_len:
            raise ValueError(f"{speculative_max_model_len=} cannot be "
                             f"larger than {draft_max_model_len=}")

        if speculative_max_model_len > target_max_model_len:
            raise ValueError(f"{speculative_max_model_len=} cannot be "
                             f"larger than {target_max_model_len=}")

        return speculative_max_model_len

    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: Optional[int],
        draft_hf_config: PretrainedConfig,
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config, except the
        tp_size. When no draft tp size is given, it is 1 for draft configs
        whose ``model_type`` is "mlp_speculator" and the target tp size
        otherwise. An explicit draft tp size must be 1.

        Raises:
            ValueError: If an explicit draft tp size other than 1 is given.
        """
        draft_tp_size = speculative_draft_tensor_parallel_size
        if draft_tp_size is None:
            if draft_hf_config.model_type == "mlp_speculator":
                draft_tp_size = 1
                if target_parallel_config.tensor_parallel_size > 1:
                    logger.warning(
                        "MLPSpeculator cannot currently be run with tp>1; "
                        "setting speculative_draft_tensor_parallel_size=1")
            else:
                draft_tp_size = target_parallel_config.tensor_parallel_size
        elif draft_tp_size != 1:
            # TODO(wooyeon): allow tp values larger than 1
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
                f"other value than 1")

        return ParallelConfig(
            pipeline_parallel_size=(
                target_parallel_config.pipeline_parallel_size),
            tensor_parallel_size=draft_tp_size,
            distributed_executor_backend=(
                target_parallel_config.distributed_executor_backend),
            max_parallel_loading_workers=(
                target_parallel_config.max_parallel_loading_workers),
            disable_custom_all_reduce=(
                target_parallel_config.disable_custom_all_reduce),
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=(
                target_parallel_config.ray_workers_use_nsight),
            placement_group=target_parallel_config.placement_group,
        )

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ):
        """Create a SpeculativeConfig object.

        All arguments are stored on the instance and then checked by
        ``_verify_args``.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.
            speculative_disable_mqa_scorer: Disable the MQA scorer for the
                speculative model and fall back to batch expansion for
                scoring.
            speculative_disable_by_batch_size: Disable speculative
                decoding for new incoming requests when the number of
                enqueued requests is larger than this value.
            ngram_prompt_lookup_max: Max size of ngram token window. Stored
                as 0 when None.
            ngram_prompt_lookup_min: Min size of ngram token window. Stored
                as 0 when None.
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (float):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance.
            typical_acceptance_sampler_posterior_alpha (float):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler.
            disable_logprobs: If set to True, token log probabilities will
                not be returned even if requested by sampling parameters.
                This reduces latency by skipping logprob calculation in
                proposal sampling, target sampling, and after accepted
                tokens are determined. If set to False, log probabilities
                will be returned.
            disable_log_stats: Whether to disable periodic printing of stage
                times in speculative decoding.

        Raises:
            ValueError: If ``_verify_args`` rejects the stored values.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens
        self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
        self.speculative_disable_by_batch_size = (
            speculative_disable_by_batch_size)
        # None (and any other falsy value) is normalised to 0.
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
        self.draft_token_acceptance_method = draft_token_acceptance_method
        self.typical_acceptance_sampler_posterior_threshold = (
            typical_acceptance_sampler_posterior_threshold)
        self.typical_acceptance_sampler_posterior_alpha = (
            typical_acceptance_sampler_posterior_alpha)
        self.disable_logprobs = disable_logprobs
        self.disable_log_stats = disable_log_stats

        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the stored settings.

        Raises:
            ValueError: If the speculative token count is not positive, the
                acceptance method is unset or unknown, or either typical
                acceptance sampler parameter is negative.
        """
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)

        # Validate draft token acceptance related settings.
        acceptance_method = self.draft_token_acceptance_method
        if acceptance_method is None:
            raise ValueError("draft_token_acceptance_method is not set. "
                             "Expected values are rejection_sampler or "
                             "typical_acceptance_sampler.")

        if acceptance_method not in ('rejection_sampler',
                                     'typical_acceptance_sampler'):
            raise ValueError(
                "Expected draft_token_acceptance_method to be either "
                "rejection_sampler or typical_acceptance_sampler. Instead it "
                f"is {acceptance_method}")

        posterior_threshold = (
            self.typical_acceptance_sampler_posterior_threshold)
        posterior_alpha = self.typical_acceptance_sampler_posterior_alpha
        if posterior_threshold < 0 or posterior_alpha < 0:
            raise ValueError(
                "Expected typical_acceptance_sampler_posterior_threshold "
                "and typical_acceptance_sampler_posterior_alpha to be > 0. "
                "Instead found "
                f"typical_acceptance_sampler_posterior_threshold = "
                f"{posterior_threshold} and "
                f"typical_acceptance_sampler_posterior_alpha = "
                f"{posterior_alpha}")

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        """Return a short summary naming the draft model and token count."""
        # The local names below appear verbatim in the output through the
        # ``{name=}`` f-string form, so they must not be renamed.
        draft_model = ("[ngram]" if self.ngram_prompt_lookup_max > 0 else
                       self.draft_model_config.model)
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"


1523
1524
1525
1526
@dataclass
class LoRAConfig:
    """LoRA-related settings, validated on construction."""

    # Must be one of 8, 16, 32, 64, 128 or 256.
    max_lora_rank: int
    # Must be at least 1.
    max_loras: int
    fully_sharded_loras: bool = False
    # Defaults to max_loras when left as None; may not be smaller than it.
    max_cpu_loras: Optional[int] = None
    # None, "auto" or a dtype name until verify_with_model_config resolves it.
    lora_dtype: Optional[torch.dtype] = None
    # Must be one of 0, 256 or 512.
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256
    long_lora_scaling_factors: Optional[Tuple[float]] = None

    def __post_init__(self):
        """Validate the field values and default ``max_cpu_loras``.

        Raises:
            ValueError: If a rank, vocab size or LoRA count is not allowed.
        """
        # Setting the maximum rank to 256 should be able to satisfy the vast
        # majority of applications.
        allowed_ranks = (8, 16, 32, 64, 128, 256)
        allowed_extra_vocab_sizes = (0, 256, 512)

        if self.max_lora_rank not in allowed_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{allowed_ranks}.")

        if self.lora_extra_vocab_size not in allowed_extra_vocab_sizes:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {allowed_extra_vocab_sizes}.")

        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")

        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
            return

        if self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``lora_dtype`` and warn about untested quantization.

        ``lora_dtype`` becomes the model's dtype when it is None or "auto",
        and the ``torch`` attribute of that name when it is another string.
        A warning is logged for any quantization other than "awq" or "gptq".
        """
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)

        quantization = model_config.quantization
        # TODO support marlin
        if quantization and quantization not in ("awq", "gptq"):
            logger.warning("%s quantization is not tested with LoRA yet.",
                           quantization)

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        """Reject scheduler configs that have chunked prefill enabled.

        Raises:
            ValueError: If ``scheduler_config.chunked_prefill_enabled`` is
                truthy.
        """
        if not scheduler_config.chunked_prefill_enabled:
            return
        raise ValueError("LoRA is not supported with chunked prefill yet.")
1572
1573


1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
@dataclass
class PromptAdapterConfig:
    """Prompt-adapter settings, validated on construction."""

    # Must be at least 1.
    max_prompt_adapters: int
    # Must not be 0.
    max_prompt_adapter_token: int
    # Defaults to max_prompt_adapters when left as None.
    max_cpu_prompt_adapters: Optional[int] = None
    # None, "auto" or a dtype name until verify_with_model_config resolves it.
    prompt_adapter_dtype: Optional[torch.dtype] = None

    def __post_init__(self):
        """Validate the field values and default ``max_cpu_prompt_adapters``.

        Raises:
            ValueError: If ``max_prompt_adapters`` is below 1 or
                ``max_prompt_adapter_token`` equals 0.
        """
        if self.max_prompt_adapters < 1:
            raise ValueError(f"max_prompt_adapters "
                             f"({self.max_prompt_adapters}) must be >= 1.")

        if self.max_prompt_adapter_token == 0:
            raise ValueError("max_prompt_adapter_token must be set.")

        if self.max_cpu_prompt_adapters is None:
            self.max_cpu_prompt_adapters = self.max_prompt_adapters

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``prompt_adapter_dtype`` against the model config.

        None and "auto" become the model's dtype; any other string is looked
        up as an attribute of ``torch``. Other values are left unchanged.
        """
        requested = self.prompt_adapter_dtype
        if requested in (None, "auto"):
            self.prompt_adapter_dtype = model_config.dtype
        elif isinstance(requested, str):
            self.prompt_adapter_dtype = getattr(torch, requested)


1599
@dataclass
class MultiModalConfig:
    """Settings that control how multimodal models behave."""

    limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
    """
    Upper bound on the number of multi-modal input instances accepted in a
    single prompt, for each :class:`~vllm.multimodal.MultiModalPlugin`.
    Every instance starts with its own empty mapping.
    """

    # TODO: Add configs to init vision tower or not.
1610

1611

1612
1613
1614
1615
1616
1617
1618
1619
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

1620
_ROCM_NOT_SUPPORTED_DTYPE: List[str] = []  #
1621

1622
1623
1624

def _get_and_verify_dtype(
    config: PretrainedConfig,
1625
    dtype: Union[str, torch.dtype],
1626
1627
1628
1629
1630
1631
1632
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

1633
1634
1635
1636
    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
Woosuk Kwon's avatar
Woosuk Kwon committed
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
                if config.model_type == "gemma2":
                    logger.info(
                        "For Gemma 2, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16.")
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
1647
1648
            else:
                torch_dtype = config_dtype
1649
        else:
1650
1651
1652
1653
1654
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
1655
    else:
1656
        raise ValueError(f"Unknown dtype: {dtype}")
1657
1658
1659
1660
1661

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
1662
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
1663
1664
1665
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
1666
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
1667
1668
            pass
        else:
Woosuk Kwon's avatar
Woosuk Kwon committed
1669
            # Casting between float16 and bfloat16 is allowed with a warning.
1670
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
1671
1672

    return torch_dtype
1673
1674
1675
1676
1677


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
    disable_sliding_window: bool,
    sliding_window_len: Optional[int],
    spec_target_max_model_len: Optional[int] = None,
) -> int:
    """Get and verify the model's maximum length.

    The maximum length is derived from the HF config by taking the smallest
    value among a set of well-known length attributes, optionally capping it
    at the sliding window length, and then applying RoPE scaling. A
    user-specified ``max_model_len`` is validated against that derived value.

    Args:
        hf_config: Model config; only read through ``getattr`` for the length
            keys, ``rope_scaling`` and ``model_max_length``.
        max_model_len: User-specified maximum length, or ``None`` to use the
            derived value.
        disable_sliding_window: If set (and ``sliding_window_len`` is given),
            the derived length is capped at the sliding window length.
        sliding_window_len: Sliding window length from the model config, if
            any.
        spec_target_max_model_len: Fallback used when the config has none of
            the known length keys and ``max_model_len`` is ``None``.

    Returns:
        The maximum model length as an ``int``.

    Raises:
        ValueError: If ``rope_scaling`` has neither a ``type`` nor a
            ``rope_type`` key, or if ``max_model_len`` exceeds the derived
            length and ``VLLM_ALLOW_LONG_MAX_MODEL_LEN`` is not set.
        NotImplementedError: If sliding window is disabled for a model that
            uses a scaled ``rope_scaling`` type, or whose ``model_max_length``
            is what permits the user-specified length.
        AssertionError: If a ``rope_scaling`` type that needs it has no
            ``factor`` entry.
    """
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Command-R
        "model_max_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    # Choose the smallest "max_length" from the possible keys, remembering
    # which key supplied it so error messages can name it.
    max_len_key = None
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None and max_len < derived_max_model_len:
            max_len_key = key
            derived_max_model_len = max_len

    # If sliding window is manually disabled, the maximum length must not
    # exceed the sliding window length in the model config.
    if (disable_sliding_window and sliding_window_len is not None
            and sliding_window_len < derived_max_model_len):
        max_len_key = "sliding_window"
        derived_max_model_len = sliding_window_len

    # If none of the keys were found in the config, use a default and
    # log a warning.
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        if spec_target_max_model_len is not None:
            # If this is a speculative draft model, we use the max model len
            # from the target model.
            return spec_target_max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            "%s. Assuming the model's maximum length is %d.", possible_keys,
            default_max_len)
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        if "type" in rope_scaling:
            rope_type = rope_scaling["type"]
        elif "rope_type" in rope_scaling:
            rope_type = rope_scaling["rope_type"]
        else:
            raise ValueError(
                "rope_scaling must have a 'type' or 'rope_type' key.")

        # The correct one should be "longrope", kept "su" here
        # to be backward compatible
        if rope_type not in ("su", "longrope", "llama3"):
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that supports rope_scaling
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "with rope_scaling. Please raise an issue so we can "
                    "investigate.")

            if rope_type == "mrope":
                scaling_factor = 1
            else:
                assert "factor" in rope_scaling, (
                    "rope_scaling must have a 'factor' key.")
                scaling_factor = rope_scaling["factor"]

            # For yarn, the scaling factor is applied to the original
            # (pre-scaling) position-embedding length from rope_scaling.
            if rope_type == "yarn":
                derived_max_model_len = rope_scaling[
                    "original_max_position_embeddings"]
            derived_max_model_len *= scaling_factor

    # If the user specified a max length, make sure it is smaller than the
    # derived length from the HF model config.
    if max_model_len is None:
        max_model_len = int(derived_max_model_len)
    elif max_model_len > derived_max_model_len:
        # Some models might have a separate key for specifying model_max_length
        # that will be bigger than derived_max_model_len. We compare user input
        # with model_max_length and allow this override when it's smaller.
        model_max_length = getattr(hf_config, "model_max_length", None)
        if model_max_length is not None and max_model_len <= model_max_length:
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that has model_max_length
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "with model_max_length in the config. Please raise an "
                    "issue so we can investigate.")
        else:
            msg = (
                f"User-specified max_model_len ({max_model_len}) is greater "
                f"than the derived max_model_len ({max_len_key}="
                f"{derived_max_model_len} or model_max_length="
                f"{model_max_length} in model's config.json). This may lead "
                "to incorrect model outputs or CUDA errors.")
            if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
                logger.warning(
                    "%s Make sure the value is correct and within the "
                    "model context size.", msg)
            else:
                raise ValueError(
                    f"{msg} To allow overriding this maximum, set "
                    "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
    return int(max_model_len)
1800
1801


1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
def get_served_model_name(model: str,
                          served_model_name: Optional[Union[str, List[str]]]):
    """Resolve the name under which the model is served.

    A non-empty list yields its first entry and a non-empty string is
    returned unchanged. Any empty value (``None``, ``""`` or ``[]``) falls
    back to ``model``.
    """
    if isinstance(served_model_name, list) and served_model_name:
        return served_model_name[0]
    # Empty values are falsy, so `or` picks the fallback for them.
    return served_model_name or model


1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
@dataclass
class DecodingConfig:
    """Dataclass which contains the decoding strategy of the engine"""

    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
    guided_decoding_backend: str = 'outlines'

    def __post_init__(self):
        """Validate the configured backend.

        Raises:
            ValueError: If `guided_decoding_backend` is not one of the
                supported backends.
        """
        valid_guided_backends = ['outlines', 'lm-format-enforcer']
        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
            # Close the quote around the backend name and separate the two
            # halves of the sentence with a space.
            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")


1833
1834
1835
1836
1837
@dataclass
class ObservabilityConfig:
    """Configuration for observability."""
    # Endpoint for OpenTelemetry traces. Setting it requires OpenTelemetry
    # to be available (checked in __post_init__).
    otlp_traces_endpoint: Optional[str] = None

    # Collecting detailed timing information for each request can be expensive.

    # If set, collects the model forward time for the request.
    collect_model_forward_time: bool = False

    # If set, collects the model execute time for the request.
    collect_model_execute_time: bool = False

    def __post_init__(self):
        """Validate the observability settings.

        Raises:
            ValueError: If an endpoint is configured while OpenTelemetry is
                unavailable, or if timing collection is requested without
                an endpoint.
        """
        # Test the endpoint first so the availability probe only runs when
        # tracing was actually requested.
        if self.otlp_traces_endpoint is not None and not is_otel_available():
            raise ValueError(
                "OpenTelemetry is not available. Unable to configure "
                "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
                f"installed. Original error:\n{otel_import_error_traceback}")

        wants_timing = (self.collect_model_forward_time
                        or self.collect_model_execute_time)
        if wants_timing and self.otlp_traces_endpoint is None:
            raise ValueError(
                "collect_model_forward_time or collect_model_execute_time "
                "requires --otlp-traces-endpoint to be set.")

1860

1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
@dataclass(frozen=True)
class EngineConfig:
    """Dataclass which contains all engine-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    load_config: LoadConfig
    lora_config: Optional[LoRAConfig]
    speculative_config: Optional[SpeculativeConfig]
    decoding_config: Optional[DecodingConfig]
    observability_config: Optional[ObservabilityConfig]
    prompt_adapter_config: Optional[PromptAdapterConfig]

    def __post_init__(self):
        """Verify configs are valid & consistent with each other.

        Delegates to the `verify_*` methods of the individual configs; the
        optional LoRA and prompt-adapter configs are only checked when set.
        """
        self.model_config.verify_async_output_proc(self.parallel_config,
                                                   self.speculative_config,
                                                   self.device_config)
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)

    def to_dict(self):
        """Return the configs as a dictionary, for use in **kwargs.

        Keys are the dataclass field names; values are the config objects
        themselves (not copies).
        """
        # `f` rather than `field` to avoid shadowing `dataclasses.field`.
        return {f.name: getattr(self, f.name) for f in fields(self)}