config.py 80.3 KB
Newer Older
1
import enum
2
import json
3
from dataclasses import dataclass, field, fields
4
5
from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple,
                    Type, Union)
6
7

import torch
8
from transformers import PretrainedConfig
9

10
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
13
from vllm.model_executor.models import ModelRegistry
14
from vllm.platforms import current_platform
15
from vllm.tracing import is_otel_available, otel_import_error_traceback
16
17
18
from vllm.transformers_utils.config import (get_config,
                                            get_hf_image_processor_config,
                                            get_hf_text_config)
19
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
20
                        cuda_device_count_stateless, get_cpu_memory, is_cpu,
21
                        is_hip, is_neuron, is_openvino, is_xpu,
22
                        print_warning_once)
23

24
25
26
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

27
    from vllm.executor.executor_base import ExecutorBase
28
    from vllm.model_executor.model_loader.loader import BaseModelLoader
29
30
    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
        BaseTokenizerGroup)
31

32
33
# Module-level logger created through vLLM's logging helper.
logger = init_logger(__name__)

# NOTE(review): presumably the batched-token budget applied to embedding
# models; its use is not visible in this part of the file -- confirm.
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768

# Architecture names (as they appear in the HF config's `architectures`
# field) that ModelConfig.verify_with_parallel_config accepts when
# pipeline_parallel_size > 1.
_PP_SUPPORTED_MODELS = [
    "AquilaModel",
    "AquilaForCausalLM",
    "DeepseekV2ForCausalLM",
    "InternLMForCausalLM",
    "JAISLMHeadModel",
    "LlamaForCausalLM",
    "LLaMAForCausalLM",
    "MistralForCausalLM",
    "Phi3ForCausalLM",
    "GPT2LMHeadModel",
    "MixtralForCausalLM",
    "NemotronForCausalLM",
    "Qwen2ForCausalLM",
    "Qwen2MoeForCausalLM",
    "QWenLMHeadModel",
]

54
55

class ModelConfig:
56
57
58
59
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
60
61
            It is also used as the content for `model_name` tag in metrics 
            output when `served_model_name` is not specified. 
62
        tokenizer: Name or path of the huggingface tokenizer to use.
63
64
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
65
66
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
67
68
69
70
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
Jasmond L's avatar
Jasmond L committed
71
72
73
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
74
        code_revision: The specific revision to use for the model code on
75
            Hugging Face Hub. It can be a branch name, a tag name, or a
76
            commit id. If unspecified, will use the default version.
77
78
79
        rope_scaling: Dictionary containing the scaling configuration for the
            RoPE embeddings. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
80
81
82
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
83
84
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
85
86
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
87
88
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
89
90
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
            be used to load activation and weight scaling factors when the
91
            model dtype is FP8_E4M3 on ROCm.
92
93
94
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
95
96
97
            If None, the user did not specify, so default to False -
            except for encoder/decoder models, which currently require
            eager mode.
98
99
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
100
101
102
103
            to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode
104
105
106
107
        disable_sliding_window: Whether to disable sliding window. If True,
            we will disable the sliding window functionality of the model.
            If the model does not support sliding window, this argument is
            ignored.
108
109
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer.
110
111
112
113
        served_model_name: The model name used in metrics tag `model_name`,
            matches the model name exposed via the APIs. If multiple model 
            names provided, the first name will be used. If not specified, 
            the model name will be the same as `model`.
114
115
        limit_mm_per_prompt: Maximum number of data instances per modality 
            per prompt. Only applicable for multimodal models.
116
    """
117
118
119
120

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        rope_scaling: Optional[dict] = None,
        rope_theta: Optional[float] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        spec_target_max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        quantization_param_path: Optional[str] = None,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: Optional[int] = None,
        max_logprobs: int = 20,
        disable_sliding_window: bool = False,
        skip_tokenizer_init: bool = False,
        served_model_name: Optional[Union[str, List[str]]] = None,
        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
        use_async_output_proc: bool = True,
    ) -> None:
        """Record the user options, load the HF config and validate both.

        Raises:
            ValueError: If the deprecated `max_context_len_to_capture` is
                passed, or if eager mode is explicitly disabled for an
                encoder/decoder model. The verification helpers called at
                the end may raise as well.
        """
        # --- Options stored as given -------------------------------------
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        # Unless a tokenizer revision is given, follow the model revision.
        self.tokenizer_revision = (revision if tokenizer_revision is None
                                   else tokenizer_revision)
        self.quantization = quantization
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        if max_context_len_to_capture is not None:
            raise ValueError("`max_context_len_to_capture` is deprecated. "
                             "Use `max_seq_len_to_capture` instead.")
        self.max_seq_len_to_capture = max_seq_len_to_capture
        self.max_logprobs = max_logprobs
        self.disable_sliding_window = disable_sliding_window
        self.skip_tokenizer_init = skip_tokenizer_init

        # --- Values derived from the HuggingFace config -------------------
        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision, rope_scaling, rope_theta)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.hf_image_processor_config = get_hf_image_processor_config(
            self.model, revision)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.use_async_output_proc = use_async_output_proc

        # --- Default for enforce_eager when the user left it unset --------
        is_enc_dec = getattr(self.hf_config, 'is_encoder_decoder', False)
        if self.enforce_eager is None:
            if is_enc_dec:
                # Encoder/decoder models get eager mode by default. This is
                # logged because defaulting to True is not what a user
                # would normally expect.
                logger.info("Forcing enforce_eager == True because "
                            "enforce_eager setting was unspecified and "
                            "CUDAGraph is not supported with encoder/ "
                            "decoder models.")
                self.enforce_eager = True
            else:
                # Decoder-only models default to CUDA graphs enabled.
                self.enforce_eager = False
        if is_enc_dec and not self.enforce_eager:
            # The user explicitly asked for CUDA graphs on an
            # encoder/decoder model, which is not supported.
            raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH)

        # --- Gemma 2: force the sliding window off -------------------------
        text_cfg = self.hf_text_config
        if (not self.disable_sliding_window
                and text_cfg.model_type == "gemma2"
                and text_cfg.sliding_window is not None):
            print_warning_once(
                "Gemma 2 uses sliding window attention for every odd layer, "
                "which is currently not supported by vLLM. Disabling sliding "
                "window and capping the max length to the sliding window size "
                f"({self.hf_text_config.sliding_window}).")
            self.disable_sliding_window = True

        # Must run after `disable_sliding_window` has its final value.
        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.hf_text_config,
            max_model_len=max_model_len,
            disable_sliding_window=self.disable_sliding_window,
            sliding_window_len=self.get_hf_config_sliding_window(),
            spec_target_max_model_len=spec_target_max_model_len)
        self.served_model_name = get_served_model_name(model,
                                                       served_model_name)
        self.multimodal_config = self._init_multimodal_config(
            limit_mm_per_prompt)

        # --- Final consistency checks --------------------------------------
        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()
        self._verify_embedding_mode()
        self._verify_quantization()
        self._verify_cuda_graph()
231

232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
    def _init_multimodal_config(
        self, limit_mm_per_prompt: Optional[Mapping[str, int]]
    ) -> Optional["MultiModalConfig"]:
        architectures = getattr(self.hf_config, "architectures", [])
        if any(
                ModelRegistry.is_multimodal_model(arch)
                for arch in architectures):
            return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
        else:
            if limit_mm_per_prompt:
                raise ValueError(
                    "limit_mm_per_prompt is only supported for multimodal "
                    "models.")
            return None

247
248
249
250
251
252
253
    def _verify_tokenizer_mode(self) -> None:
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode
254

255
256
257
258
259
    def _verify_embedding_mode(self) -> None:
        architectures = getattr(self.hf_config, "architectures", [])
        self.embedding_mode = any(
            ModelRegistry.is_embedding_model(arch) for arch in architectures)

260
261
262
    def _parse_quant_hf_config(self):
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
263
            # compressed-tensors uses a "compression_config" key
264
            quant_cfg = getattr(self.hf_config, "compression_config", None)
265
266
        return quant_cfg

267
    def _verify_quantization(self) -> None:
        """Reconcile and validate the quantization method.

        Lower-cases `self.quantization`, lets each registered quantization
        method override the method found in the HF config, adopts the HF
        config's method when none was requested, and then checks the result
        against the known methods and the platform-specific allow-lists.

        Raises:
            ValueError: If the requested method disagrees with the HF
                config, is unknown, or is not supported on ROCm / TPU.
        """
        known_methods = list(QUANTIZATION_METHODS)
        rocm_methods = ("gptq", "squeezellm", "fp8")
        tpu_methods = ("tpu_int8", )
        optimized_methods = (
            "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
            "fbgemm_fp8", "compressed_tensors", "compressed-tensors",
            "experts_int8",
        )

        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Pick up the quantization method declared by the checkpoint.
        hf_quant_cfg = self._parse_quant_hf_config()
        if hf_quant_cfg is not None:
            hf_method = hf_quant_cfg.get("quant_method", "").lower()

            # The first registered method that claims this checkpoint wins.
            for method_cls in QUANTIZATION_METHODS.values():
                override = method_cls.override_quantization_method(
                    hf_quant_cfg, self.quantization)
                if override:
                    hf_method = override
                    self.quantization = override
                    break

            # The user's choice must agree with the checkpoint.
            if self.quantization is None:
                self.quantization = hf_method
            elif self.quantization != hf_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({hf_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is None:
            # Unquantized model: nothing more to validate.
            return

        if self.quantization not in known_methods:
            raise ValueError(
                f"Unknown quantization method: {self.quantization}. Must "
                f"be one of {known_methods}.")
        if is_hip() and self.quantization not in rocm_methods:
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in ROCm.")
        if (current_platform.is_tpu()
                and self.quantization not in tpu_methods):
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in TPU Backend.")
        if self.quantization not in optimized_methods:
            logger.warning(
                "%s quantization is not fully "
                "optimized yet. The speed can be slower than "
                "non-quantized models.", self.quantization)
324

325
    def _verify_cuda_graph(self) -> None:
326
327
328
329
        if self.max_seq_len_to_capture is None:
            self.max_seq_len_to_capture = self.max_model_len
        self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                          self.max_model_len)
330

331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
    def verify_async_output_proc(self, parallel_config, speculative_config,
                                 device_config) -> None:
        if not self.use_async_output_proc:
            # Nothing to check
            return

        if parallel_config.pipeline_parallel_size > 1:
            logger.warning("Async output processing can not be enabled "
                           "with pipeline parallel")
            self.use_async_output_proc = False
            return

        if device_config.device_type != "cuda":
            logger.warning(
                "Async output processing is only supported for CUDA."
                " Disabling it for other platforms.")
            self.use_async_output_proc = False
            return

        if envs.VLLM_USE_RAY_SPMD_WORKER:
            logger.warning(
                "Async output processing can not be enabled with ray spmd")
            self.use_async_output_proc = False
            return

        if self.enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            self.use_async_output_proc = not self.enforce_eager
            return

        # Async postprocessor is not necessary with embedding mode
        # since there is no token generation
        if self.embedding_mode:
            self.use_async_output_proc = False

        if speculative_config:
            logger.warning("Async output processing is not supported with"
                           " speculative decoding currently.")
            self.use_async_output_proc = False

374
375
376
377
    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
378
379
        total_num_attention_heads = getattr(self.hf_text_config,
                                            "num_attention_heads", 0)
380
381
382
383
384
385
386
387
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        pipeline_parallel_size = parallel_config.pipeline_parallel_size
388
389
390
391
392
393
394
        architectures = getattr(self.hf_config, "architectures", [])
        if not all(arch in _PP_SUPPORTED_MODELS
                   for arch in architectures) and pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported for the following "
                f" architectures: {_PP_SUPPORTED_MODELS}.")

395
396
397
398
399
400
        if self.quantization == "bitsandbytes" and (
                parallel_config.tensor_parallel_size > 1
                or parallel_config.pipeline_parallel_size > 1):
            raise ValueError(
                "BitAndBytes quantization with TP or PP is not supported yet.")

401
        if self.quantization == "bitsandbytes" and self.enforce_eager is False:
402
403
404
            logger.warning("CUDA graph is not supported on BitAndBytes yet, "
                           "fallback to the eager mode.")
            self.enforce_eager = True
405

406
407
408
409
410
        if pipeline_parallel_size > 1 and self.use_async_output_proc:
            logger.warning("Async output processor is not supported with "
                           "pipeline parallelism currently. Disabling it.")
            self.use_async_output_proc = False

411
    def get_hf_config_sliding_window(self) -> Optional[int]:
Woosuk Kwon's avatar
Woosuk Kwon committed
412
        """Get the sliding window size, or None if disabled."""
413
414
415
416

        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
        # addition to sliding window size. We check if that field is present
        # and if it's False, return None.
417
418
        if (hasattr(self.hf_text_config, "use_sliding_window")
                and not self.hf_text_config.use_sliding_window):
419
            return None
420
        return getattr(self.hf_text_config, "sliding_window", None)
421

422
423
424
425
426
427
428
429
430
    def get_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled.
        """
        # If user disables sliding window, return None.
        if self.disable_sliding_window:
            return None
        # Otherwise get the value from the hf config.
        return self.get_hf_config_sliding_window()

431
    def get_vocab_size(self) -> int:
432
        return self.hf_text_config.vocab_size
433

434
    def get_hidden_size(self) -> int:
435
        return self.hf_text_config.hidden_size
436
437

    def get_head_size(self) -> int:
wangding zeng's avatar
wangding zeng committed
438
439
440
441
442
443
        # TODO remove hard code
        if hasattr(self.hf_text_config, "model_type"
                   ) and self.hf_text_config.model_type == 'deepseek_v2':
            # FlashAttention supports only head_size 32, 64, 128, 256,
            # we need to pad head_size 192 to 256
            return 256
444
445
        if hasattr(self.hf_text_config, "head_dim"):
            return self.hf_text_config.head_dim
446
        # FIXME(woosuk): This may not be true for all models.
447
448
        return (self.hf_text_config.hidden_size //
                self.hf_text_config.num_attention_heads)
449

450
451
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
Zhuohan Li's avatar
Zhuohan Li committed
452
        # For GPTBigCode & Falcon:
453
        # NOTE: for falcon, when new_decoder_architecture is True, the
Zhuohan Li's avatar
Zhuohan Li committed
454
455
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
456
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
457
        new_decoder_arch_falcon = (
458
            self.hf_config.model_type in falcon_model_types
459
            and getattr(self.hf_config, "new_decoder_architecture", False))
460
        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
461
                                                   "multi_query", False):
Zhuohan Li's avatar
Zhuohan Li committed
462
            # Multi-query attention, only one KV head.
Woosuk Kwon's avatar
Woosuk Kwon committed
463
            # Currently, tensor parallelism is not supported in this case.
Zhuohan Li's avatar
Zhuohan Li committed
464
            return 1
465

466
        # For DBRX and MPT
467
468
469
470
471
        if self.hf_config.model_type == "mpt":
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type == "dbrx":
472
473
474
            return getattr(self.hf_config.attn_config, "kv_n_heads",
                           self.hf_config.num_attention_heads)

475
476
477
478
479
480
481
482
483
484
        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
485
            num_kv_heads = getattr(self.hf_text_config, attr, None)
486
487
488
489
490
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
491
        return self.hf_text_config.num_attention_heads
492
493
494
495
496
497
498
499
500
501

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)
502

503
504
    def get_num_attention_heads(self,
                                parallel_config: "ParallelConfig") -> int:
505
506
        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
        return num_heads // parallel_config.tensor_parallel_size
507

508
    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return how many hidden layers fall on this pipeline stage.

        The stage index is `rank // tensor_parallel_size`; the layer range
        for that stage comes from `get_pp_indices`.
        """
        from vllm.distributed.utils import get_pp_indices

        layer_count = getattr(self.hf_text_config, "num_hidden_layers", 0)
        stage_index = (parallel_config.rank //
                       parallel_config.tensor_parallel_size)
        stage_count = parallel_config.pipeline_parallel_size
        first, last = get_pp_indices(layer_count, stage_index, stage_count)
        return last - first
516

Mor Zusman's avatar
Mor Zusman committed
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
    def contains_seqlen_agnostic_layers(
            self, parallel_config: "ParallelConfig") -> bool:
        """True for Mamba/SSM models (Jamba)"""
        return self._get_num_seqlen_agnostic_layers(parallel_config) > 0

    def get_layers_block_type(self,
                              parallel_config: "ParallelConfig") -> List[str]:
        """Return the block type of each layer.

        Uses the HF config's `layers_block_type` when it exists (a
        @property in Transformers); otherwise every layer of this pipeline
        stage is reported as "attention".
        """
        all_attention = ["attention"] * self.get_num_layers(parallel_config)
        return getattr(self.hf_config, "layers_block_type", all_attention)

    def get_num_attention_layers(self,
                                 parallel_config: "ParallelConfig") -> int:
        return len([
            t for t in self.get_layers_block_type(parallel_config)
            if t == "attention"
        ])

    def _get_num_seqlen_agnostic_layers(
            self, parallel_config: "ParallelConfig") -> int:
        return len([
            t for t in self.get_layers_block_type(parallel_config)
            if t != "attention"
        ])

543
544
545
546
547
548
549
550
551
552
553
554
    def get_multimodal_config(self) -> "MultiModalConfig":
        """
        Get the multimodal configuration of the model.

        Raises:
            ValueError: If the model is not multimodal.
        """
        if self.multimodal_config is None:
            raise ValueError("The model is not multimodal.")

        return self.multimodal_config

555
556
557
558
559
560
561
562
563
564
    @property
    def is_encoder_decoder_model(self) -> bool:
        """Extract the HF encoder/decoder model flag."""
        return getattr(self.hf_config, "is_encoder_decoder", False)

    @property
    def is_embedding_model(self) -> bool:
        """Extract the embedding model flag."""
        return self.embedding_mode

565
566

class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution. Must lie in the range (0, 1].
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for kv cache storage. One of "auto", "fp8",
            "fp8_e4m3" or "fp8_e5m2".
        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
            profiled num_gpu_blocks if specified. Does nothing if None.
        sliding_window: Sliding window size, or None. Prefix caching is
            rejected while this is set.
        enable_prefix_caching: Whether prefix caching is enabled.
        cpu_offload_gb: Stored unchanged on the instance.
            NOTE(review): presumably the amount of CPU memory (GiB) used
            for offloading -- confirm against the consumer of this field.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: float,
        cache_dtype: str,
        num_gpu_blocks_override: Optional[int] = None,
        sliding_window: Optional[int] = None,
        enable_prefix_caching: bool = False,
        cpu_offload_gb: float = 0,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        # Swap space is supplied in GiB and kept in bytes.
        self.swap_space_bytes = swap_space * GiB_bytes
        self.num_gpu_blocks_override = num_gpu_blocks_override
        self.cache_dtype = cache_dtype
        self.sliding_window = sliding_window
        self.enable_prefix_caching = enable_prefix_caching
        self.cpu_offload_gb = cpu_offload_gb
        self._verify_args()
        self._verify_cache_dtype()
        self._verify_prefix_caching()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def metrics_info(self):
        """Return every attribute as a str -> str dict.

        The string form is what prometheus metrics info expects.
        """
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self) -> None:
        """Check that `gpu_memory_utilization` is a usable fraction.

        Raises:
            ValueError: If the value is not in the range (0, 1]. The
                chained comparison also rejects NaN.
        """
        if not 0 < self.gpu_memory_utilization <= 1.0:
            raise ValueError(
                "GPU memory utilization must be in the range (0, 1]. Got "
                f"{self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        """Validate `cache_dtype`, logging a note for the fp8 variants.

        Raises:
            ValueError: If the dtype is not one of the accepted names.
        """
        if self.cache_dtype == "auto":
            pass
        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "Meanwhile, it may cause accuracy drop without a proper "
                "scaling factor")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    def _verify_prefix_caching(self) -> None:
        """Reject prefix caching when a sliding window is configured.

        Raises:
            NotImplementedError: If prefix caching is enabled while
                `sliding_window` is set.
        """
        if not self.enable_prefix_caching:
            return

        if self.sliding_window is not None:
            raise NotImplementedError(
                "Prefix caching is not supported with sliding window. "
                "Run with --disable-sliding-window to use prefix caching.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the requested swap space fits in CPU memory.

        Args:
            parallel_config: Supplies `tensor_parallel_size`, used as the
                number of GPUs per node.

        Raises:
            ValueError: If the swap space exceeds 70% of total CPU memory.
                Above 40% only a warning is logged.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
               f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
               "is allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. %s", msg)
655

656

657
658
659
@dataclass
class TokenizerPoolConfig:
    """Configuration for the tokenizer pool.

    Args:
        pool_size: Number of tokenizer workers in the pool.
        pool_type: Type of the pool.
        extra_config: Additional config for the pool.
            The way the config will be used depends on the
            pool type.
    """
    pool_size: int
    pool_type: Union[str, Type["BaseTokenizerGroup"]]
    extra_config: dict

    def __post_init__(self):
        """Validate the pool type and the extra config.

        Raises:
            ValueError: If ``pool_type`` is neither ``"ray"`` nor a class,
                or if ``extra_config`` is not a dict.
        """
        if self.pool_type not in ("ray", ) and not isinstance(
                self.pool_type, type):
            raise ValueError(f"Unknown pool type: {self.pool_type}")
        if not isinstance(self.extra_config, dict):
            raise ValueError("extra_config must be a dictionary.")

    @classmethod
    def create_config(
        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Create a TokenizerPoolConfig from the given parameters.

        If tokenizer_pool_size is 0, return None.

        Args:
            tokenizer_pool_size: Number of tokenizer workers in the pool.
            tokenizer_pool_type: Type of the pool.
            tokenizer_pool_extra_config: Additional config for the pool.
                The way the config will be used depends on the
                pool type. This can be a JSON string (will be parsed).

        Returns:
            A validated TokenizerPoolConfig, or None when no pool is
            requested.
        """
        if not tokenizer_pool_size:
            return None

        if isinstance(tokenizer_pool_extra_config, str):
            extra_config = json.loads(tokenizer_pool_extra_config)
        else:
            # None (or any other falsy value) means "no extra config".
            extra_config = tokenizer_pool_extra_config or {}
        return cls(tokenizer_pool_size, tokenizer_pool_type, extra_config)


710
711
712
713
714
715
716
class LoadFormat(str, enum.Enum):
    """Recognized model weight load formats.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value.
    """
    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"
    DUMMY = "dummy"
    TENSORIZER = "tensorizer"
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738


@dataclass
class LoadConfig:
    """Configuration for loading the model weights.

    Args:
        download_dir: Directory to download and load the weights, default to
            the default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
            "tensorizer" will use CoreWeave's tensorizer library for
                fast weight loading.
            "bitsandbytes" will load nf4 type weights.
        model_loader_extra_config: Extra config for the model loader. A
            non-empty JSON string is parsed into a Python object in
            ``__post_init__``.
        ignore_patterns: The list of patterns to ignore when loading the
            model. Default to "original/**/*" to avoid repeated loading of
            llama's checkpoints.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(
        default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        """Normalize the extra config, load format and ignore patterns."""
        # An empty string is falsy, so it is deliberately left unparsed.
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
            self.model_loader_extra_config = json.loads(
                model_loader_extra_config)
        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns)
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Convert a string ``load_format`` to a ``LoadFormat`` member.

        Non-string values (a ``LoadFormat`` that is not a plain ``str``
        cannot occur, so in practice: loader objects) are left untouched.

        Raises:
            ValueError: If the string is not a valid ``LoadFormat`` value,
                or if the format is listed as unsupported on ROCm.
        """
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            # Compare member *values* (lowercase) against the lowercase
            # unsupported list; member names are uppercase and would never
            # match, so nothing would be filtered out.
            rocm_supported_load_format = [
                fmt.value for fmt in LoadFormat
                if fmt.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")


785
class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Deprecated, use distributed_executor_backend instead.
        max_parallel_loading_workers: Maximum number of multiple batches
            when load model sequentially. To avoid RAM OOM when using tensor
            parallel and large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL.
        tokenizer_pool_config: Config for the tokenizer pool.
            If None, will use synchronous tokenization.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
        placement_group: ray distributed model workers placement group.
        distributed_executor_backend: Backend to use for distributed model
            workers, either "ray" or "mp" (multiprocessing), or an
            ExecutorBase subclass. If left as None while the world size
            (pipeline_parallel_size * tensor_parallel_size) is greater
            than 1, it defaults to "mp" unless the visible CUDA device count
            is smaller than the world size, or Ray is available and a
            placement group is given or currently active; in those cases
            it defaults to "ray".
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: Optional[bool] = None,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
        ray_workers_use_nsight: bool = False,
        placement_group: Optional["PlacementGroup"] = None,
        distributed_executor_backend: Optional[Union[
            str, Type["ExecutorBase"]]] = None,
    ) -> None:
        """Store the settings, pick a default backend and validate.

        Raises:
            ValueError: If ``worker_use_ray`` conflicts with a non-Ray
                backend, if Ray is required but unavailable, or if
                ``_verify_args`` rejects the final configuration.
        """
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.distributed_executor_backend = distributed_executor_backend
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.tokenizer_pool_config = tokenizer_pool_config
        self.ray_workers_use_nsight = ray_workers_use_nsight
        self.placement_group = placement_group
        self.world_size = pipeline_parallel_size * self.tensor_parallel_size

        if worker_use_ray:
            if self.distributed_executor_backend is None:
                self.distributed_executor_backend = "ray"
            elif not self.use_ray:
                raise ValueError(f"worker-use-ray can't be used with "
                                 f"distributed executor backend "
                                 f"'{self.distributed_executor_backend}'.")

        if self.distributed_executor_backend is None and self.world_size > 1:
            # We use multiprocessing by default if world_size fits on the
            # current node and we aren't in a ray placement group.

            from vllm.executor import ray_utils
            backend = "mp"
            ray_found = ray_utils.ray_is_available()
            if cuda_device_count_stateless() < self.world_size:
                if not ray_found:
                    raise ValueError("Unable to load Ray which is "
                                     "required for multi-node inference, "
                                     "please install Ray with `pip install "
                                     "ray`.") from ray_utils.ray_import_err
                backend = "ray"
            elif ray_found:
                if self.placement_group:
                    backend = "ray"
                else:
                    from ray import is_initialized as ray_is_initialized
                    if ray_is_initialized():
                        from ray.util import get_current_placement_group
                        if get_current_placement_group():
                            backend = "ray"
            self.distributed_executor_backend = backend
            logger.info("Defaulting to use %s for distributed inference",
                        backend)

        self._verify_args()
        self.rank: int = 0

    @property
    def use_ray(self) -> bool:
        """Whether the configured backend is "ray" or a class using Ray."""
        return self.distributed_executor_backend == "ray" or (
            isinstance(self.distributed_executor_backend, type)
            and self.distributed_executor_backend.uses_ray)

    def _verify_args(self) -> None:
        """Validate the backend choice and backend-dependent options.

        Custom all-reduce is force-disabled when ``is_hip()`` is true.

        Raises:
            ValueError: If the backend is not "ray", "mp", None or an
                ExecutorBase subclass, or if nsight profiling is requested
                without a Ray backend.
        """
        # Lazy import to avoid circular import
        from vllm.executor.executor_base import ExecutorBase

        if self.distributed_executor_backend not in (
                "ray", "mp", None) and not (isinstance(
                    self.distributed_executor_backend, type) and issubclass(
                        self.distributed_executor_backend, ExecutorBase)):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
                "values are 'ray', 'mp' or custom ExecutorBase subclass.")
        if self.use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()
        if is_hip():
            self.disable_custom_all_reduce = True
            logger.info(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on AMD GPUs.")
        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")
898

899
900

class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration. If None, a default is chosen: 512 with
            chunked prefill, a larger embedding-specific value in embedding
            mode, otherwise max(max_model_len, 2048).
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
        num_lookahead_slots: The number of slots to allocate per sequence per
            step, beyond the known token ids. This is used in speculative
            decoding to store KV activations of tokens which may or may not be
            accepted.
        delay_factor: Apply a delay (of delay factor multiplied by previous
            prompt latency) before scheduling next prompt.
        enable_chunked_prefill: If True, prefill requests can be chunked based
            on the remaining max_num_batched_tokens.
        embedding_mode: Whether the running model is for embedding.
        preemption_mode: Whether to perform preemption by swapping or
            recomputation. If not specified, we determine the mode as follows:
            We use recomputation by default since it incurs lower overhead than
            swapping. However, when the sequence group has multiple sequences
            (e.g., beam search), recomputation is not currently supported. In
            such a case, we use swapping instead.
        num_scheduler_steps: Number of scheduler steps; must be at least 1.
            Values greater than 1 make ``is_multi_step`` true.
        send_delta_data: Private API. If used, scheduler sends delta data to
            workers instead of an entire data. It should be enabled only
            when SPMD worker architecture is enabled. I.e.,
            VLLM_USE_RAY_SPMD_WORKER=1
    """

    def __init__(self,
                 max_num_batched_tokens: Optional[int],
                 max_num_seqs: int,
                 max_model_len: int,
                 use_v2_block_manager: bool = False,
                 num_lookahead_slots: int = 0,
                 delay_factor: float = 0.0,
                 enable_chunked_prefill: bool = False,
                 embedding_mode: Optional[bool] = False,
                 preemption_mode: Optional[str] = None,
                 num_scheduler_steps: int = 1,
                 send_delta_data: bool = False) -> None:
        """Resolve the token budget, store the settings and validate them.

        Raises:
            ValueError: If ``_verify_args`` rejects the configuration.
        """
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
        else:
            if enable_chunked_prefill:
                # It is the values that have the best balance between ITL
                # and TTFT on A100. Note it is not optimized for throughput.
                self.max_num_batched_tokens = 512
            elif embedding_mode:
                # For embedding, choose specific value for higher throughput
                self.max_num_batched_tokens = max(
                    max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
            else:
                # If max_model_len is too short, use 2048 as the default value
                # for higher throughput.
                self.max_num_batched_tokens = max(max_model_len, 2048)
        if enable_chunked_prefill:
            logger.info(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                self.max_num_batched_tokens)

        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.use_v2_block_manager = use_v2_block_manager
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
        self.chunked_prefill_enabled = enable_chunked_prefill
        self.embedding_mode = embedding_mode
        self.preemption_mode = preemption_mode
        self.num_scheduler_steps = num_scheduler_steps
        self.send_delta_data = send_delta_data
        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the stored scheduler limits.

        Raises:
            ValueError: If the token budget is smaller than max_model_len
                without chunked prefill, if it is smaller than max_num_seqs,
                if num_lookahead_slots is negative, or if
                num_scheduler_steps is less than 1.
        """
        if (self.max_num_batched_tokens < self.max_model_len
                and not self.chunked_prefill_enabled):
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")

        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")

        if self.num_lookahead_slots < 0:
            raise ValueError(
                "num_lookahead_slots "
                f"({self.num_lookahead_slots}) must be greater than or "
                "equal to 0.")

        if self.num_scheduler_steps < 1:
            raise ValueError(
                "num_scheduler_steps "
                f"({self.num_scheduler_steps}) must be greater than or "
                "equal to 1.")

    @property
    def is_multi_step(self) -> bool:
        """Whether more than one scheduler step is configured."""
        return self.num_scheduler_steps > 1

1010

1011
class DeviceConfig:
    """Device selection for model execution.

    Attributes:
        device_type: Name of the device type, either auto-detected or taken
            verbatim from the ``device`` argument.
        device: The ``torch.device`` derived from ``device_type``: CPU for
            "neuron" and "openvino", None for "tpu", and
            ``torch.device(device_type)`` otherwise.
    """
    device: Optional[torch.device]

    def __init__(self, device: str = "auto") -> None:
        """Resolve ``device_type`` and ``device``.

        Args:
            device: Explicit device type, or "auto" to detect it.
        """
        if device == "auto":
            # Automated device type detection
            if is_neuron():
                self.device_type = "neuron"
            elif is_openvino():
                self.device_type = "openvino"
            elif current_platform.is_tpu():
                self.device_type = "tpu"
            elif is_cpu():
                self.device_type = "cpu"
            elif is_xpu():
                self.device_type = "xpu"
            else:
                # We don't call torch.cuda.is_available() here to
                # avoid initializing CUDA before workers are forked
                self.device_type = "cuda"
        else:
            # Device type is assigned explicitly
            self.device_type = device

        # Some device types require processing inputs on CPU
        if self.device_type in ["neuron", "openvino"]:
            self.device = torch.device("cpu")
        elif self.device_type in ["tpu"]:
            self.device = None
        else:
            # Set device with device type
            self.device = torch.device(self.device_type)

1044

1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        speculative_model_quantization: Optional[str],
        speculative_draft_tensor_parallel_size: Optional[int],
        num_speculative_tokens: Optional[int],
        speculative_max_model_len: Optional[int],
        enable_chunked_prefill: bool,
        use_v2_block_manager: bool,
        disable_log_stats: bool,
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: Optional[float],
        typical_acceptance_sampler_posterior_alpha: Optional[float],
        disable_logprobs: Optional[bool],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.

        This function attempts to create a SpeculativeConfig object based on the
        provided parameters. If the necessary conditions are met, it returns an
        instance of SpeculativeConfig. Otherwise, it returns None.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model.
            target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            target_dtype (str): The data type used for the target model.
                Currently unused here; the draft model config takes its
                dtype from target_model_config.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided.
            speculative_model_quantization (Optional[str]): Quantization method
                that was used to quantize the speculative model weights. If
                None, we assume the model weights are not quantized.
            speculative_draft_tensor_parallel_size (Optional[int]): The degree
                of the tensor parallelism for the draft model.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided. Will default to the number in the draft
                model config if present, otherwise is required.
            speculative_max_model_len (Optional[int]): The maximum model len of
                the speculative model. Used when testing the ability to skip
                speculation for some sequences.
            enable_chunked_prefill (bool): Whether vLLM is configured to use
                chunked prefill or not. Used for raising an error since its not
                yet compatible with spec decode.
            use_v2_block_manager (bool): Whether vLLM is configured to use the
                v2 block manager or not. Used for raising an error since the v2
                block manager is required with spec decode.
            disable_log_stats (bool): Forwarded unchanged to the
                SpeculativeConfig constructor.
            speculative_disable_by_batch_size (Optional[int]): Disable
                speculative decoding for new incoming requests when the number
                of enqueue requests  is larger than this value, if provided.
            ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
                window, if provided.
            ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                window, if provided.
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (Optional[float]):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance.
            typical_acceptance_sampler_posterior_alpha (Optional[float]):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler.
            disable_logprobs (Optional[bool]): If set to True, token log
                probabilities are not returned during speculative decoding.
                If set to False, token log probabilities are returned
                according to the log probability settings in SamplingParams.
                If not specified, it defaults to True.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                the necessary conditions are met, else None.

        Raises:
            ValueError: If num_speculative_tokens is given without a
                speculative model or cannot be determined, if the batch-size
                threshold is below 2, if chunked prefill is enabled, if the
                v2 block manager is disabled, if the ngram lookup bounds are
                invalid, or if num_speculative_tokens exceeds the draft
                model's n_predict.
        """

        if speculative_model is None:
            if num_speculative_tokens is not None:
                raise ValueError("num_speculative_tokens was provided without "
                                 "speculative_model.")
            return None

        if (speculative_disable_by_batch_size is not None
                and speculative_disable_by_batch_size < 2):
            raise ValueError("Expect the batch size threshold of disabling "
                             "speculative decoding is > 1, but got "
                             f"{speculative_disable_by_batch_size=}")

        if enable_chunked_prefill:
            raise ValueError(
                "Speculative decoding and chunked prefill are "
                f"currently mutually exclusive ({enable_chunked_prefill=}).")

        if not use_v2_block_manager:
            raise ValueError(
                "Speculative decoding requires usage of the V2 "
                "block manager. Enable it with --use-v2-block-manager.")

        # TODO: The user should be able to specify revision/max model len
        # for the draft model. It is not currently supported.
        draft_revision = None
        draft_code_revision = None
        draft_quantization = speculative_model_quantization

        if speculative_model == "[ngram]":
            if ngram_prompt_lookup_min is None:
                ngram_prompt_lookup_min = 1
            if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
                raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
            if ngram_prompt_lookup_min < 1:
                raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
            if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
                raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
                                 f"larger than {ngram_prompt_lookup_max=}")

            # TODO: current we still need extract vocab_size from target model
            # config, in future, we may try refactor it out, and set
            # draft related config as None here.
            draft_model_config = target_model_config
            draft_parallel_config = target_parallel_config
        else:
            ngram_prompt_lookup_max = 0
            ngram_prompt_lookup_min = 0
            draft_model_config = ModelConfig(
                model=speculative_model,
                tokenizer=target_model_config.tokenizer,
                tokenizer_mode=target_model_config.tokenizer_mode,
                trust_remote_code=target_model_config.trust_remote_code,
                dtype=target_model_config.dtype,
                seed=target_model_config.seed,
                revision=draft_revision,
                code_revision=draft_code_revision,
                tokenizer_revision=target_model_config.tokenizer_revision,
                max_model_len=None,
                spec_target_max_model_len=target_model_config.max_model_len,
                quantization=draft_quantization,
                enforce_eager=target_model_config.enforce_eager,
                max_seq_len_to_capture=target_model_config.
                max_seq_len_to_capture,
                max_logprobs=target_model_config.max_logprobs,
            )

            draft_hf_config = draft_model_config.hf_config

            if (num_speculative_tokens is not None
                    and hasattr(draft_hf_config, "num_lookahead_tokens")):
                draft_hf_config.num_lookahead_tokens = num_speculative_tokens

            n_predict = getattr(draft_hf_config, "n_predict", None)
            if n_predict is not None:
                if num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    num_speculative_tokens = n_predict
                elif num_speculative_tokens > n_predict:
                    # Verify provided value doesn't exceed the maximum
                    # supported by the draft model.
                    raise ValueError(
                        "This speculative model supports a maximum of "
                        f"num_speculative_tokens={n_predict}, but "
                        f"{num_speculative_tokens=} was provided.")

            draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    speculative_max_model_len,
                    draft_model_config.max_model_len,
                    target_model_config.max_model_len,
                ))

            draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    target_parallel_config,
                    speculative_draft_tensor_parallel_size, draft_hf_config))

        if num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative_model unless the draft model config contains an "
                "n_predict parameter.")

        if typical_acceptance_sampler_posterior_threshold is None:
            typical_acceptance_sampler_posterior_threshold = 0.09
        if typical_acceptance_sampler_posterior_alpha is None:
            typical_acceptance_sampler_posterior_alpha = 0.3
        if disable_logprobs is None:
            disable_logprobs = True

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
            speculative_disable_by_batch_size,
            ngram_prompt_lookup_max,
            ngram_prompt_lookup_min,
            draft_token_acceptance_method=draft_token_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=(
                typical_acceptance_sampler_posterior_threshold),
            typical_acceptance_sampler_posterior_alpha=(
                typical_acceptance_sampler_posterior_alpha),
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
        )

1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: Optional[int],
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.
        """

        if speculative_max_model_len is not None:

            if speculative_max_model_len > draft_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {draft_max_model_len=}")

            if speculative_max_model_len > target_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {target_max_model_len=}")

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

1296
1297
    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: Optional[int],
        draft_hf_config: PretrainedConfig,
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config, except the tp_size.

        Args:
            target_parallel_config: Parallel config of the target model; all
                settings other than the tensor parallel size are copied
                from it.
            speculative_draft_tensor_parallel_size: Tensor parallel size for
                the draft model. If None, it is 1 for "mlp_speculator"
                draft models and the target's tensor parallel size
                otherwise.
            draft_hf_config: HF config of the draft model; only its
                model_type is inspected.

        Returns:
            A new ParallelConfig for the draft model.

        Raises:
            ValueError: If an explicit tensor parallel size other than 1 is
                requested.
        """
        if speculative_draft_tensor_parallel_size is None:
            if draft_hf_config.model_type == "mlp_speculator":
                speculative_draft_tensor_parallel_size = 1
                # Only warn when the draft actually diverges from the target.
                if target_parallel_config.tensor_parallel_size > 1:
                    logger.warning(
                        "MLPSpeculator cannot currently be run with tp>1; "
                        "setting speculative_draft_tensor_parallel_size=1")
            else:
                speculative_draft_tensor_parallel_size = \
                    target_parallel_config.tensor_parallel_size
        elif speculative_draft_tensor_parallel_size != 1:
            # TODO(wooyeon): allow tp values larger than 1
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
                f"other value than 1")

        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.
            pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=target_parallel_config.
            distributed_executor_backend,
            max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.
            disable_custom_all_reduce,
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=target_parallel_config.
            ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ):
        """Create a SpeculativeConfig object.

        The arguments are stored on the instance and then validated via
        _verify_args().

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.
            speculative_disable_by_batch_size: Disable speculative
                decoding for new incoming requests when the number of
                enqueued requests is larger than this value.
            ngram_prompt_lookup_max: Max size of ngram token window.
                None is stored as 0.
            ngram_prompt_lookup_min: Min size of ngram token window.
                None is stored as 0.
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (float):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance.
            typical_acceptance_sampler_posterior_alpha (float):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler.
            disable_logprobs: If set to True, token log probabilities will not
                be returned even if requested by sampling parameters. This
                reduces latency by skipping logprob calculation in proposal
                sampling, target sampling, and after accepted tokens are
                determined. If set to False, log probabilities will be
                returned.
            disable_log_stats: Whether to disable periodic printing of stage
                times in speculative decoding.

        Raises:
            ValueError: If validation in _verify_args() fails.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens
        self.speculative_disable_by_batch_size = \
            speculative_disable_by_batch_size
        # Normalize "unset" (None) to 0 so later comparisons such as
        # `ngram_prompt_lookup_max > 0` work on plain ints.
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
        self.draft_token_acceptance_method = draft_token_acceptance_method
        self.typical_acceptance_sampler_posterior_threshold = \
            typical_acceptance_sampler_posterior_threshold
        self.typical_acceptance_sampler_posterior_alpha = \
            typical_acceptance_sampler_posterior_alpha
        self.disable_logprobs = disable_logprobs
        self.disable_log_stats = disable_log_stats

        self._verify_args()

    def _verify_args(self) -> None:
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
            # Validate and set draft token acceptance related settings.

        if (self.draft_token_acceptance_method is None):
            raise ValueError("draft_token_acceptance_method is not set. "
                             "Expected values are rejection_sampler or "
                             "typical_acceptance_sampler.")

        if (self.draft_token_acceptance_method != 'rejection_sampler'
                and self.draft_token_acceptance_method !=
                'typical_acceptance_sampler'):
            raise ValueError(
                "Expected draft_token_acceptance_method to be either "
                "rejection_sampler or typical_acceptance_sampler. Instead it "
                f"is {self.draft_token_acceptance_method}")

        if (self.typical_acceptance_sampler_posterior_threshold < 0
                or self.typical_acceptance_sampler_posterior_alpha < 0):
            raise ValueError(
                "Expected typical_acceptance_sampler_posterior_threshold "
                "and typical_acceptance_sampler_posterior_alpha to be > 0. "
                "Instead found "
                f"typical_acceptance_sampler_posterior_threshold = "
                f"{self.typical_acceptance_sampler_posterior_threshold} and "
                f"typical_acceptance_sampler_posterior_alpha = "
                f"{self.typical_acceptance_sampler_posterior_alpha}")
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
1450
1451
1452
1453
        if self.ngram_prompt_lookup_max > 0:
            draft_model = "[ngram]"
        else:
            draft_model = self.draft_model_config.model
1454
1455
1456
1457
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"


1458
1459
1460
1461
@dataclass
class LoRAConfig:
    """Configuration for LoRA adapters.

    Field values are validated in __post_init__; cross-checks against the
    model and scheduler configs live in the verify_with_* methods.
    """
    max_lora_rank: int
    max_loras: int
    fully_sharded_loras: bool = False
    # Defaults to max_loras when left as None (see __post_init__).
    max_cpu_loras: Optional[int] = None
    # May also arrive as a dtype name string or "auto"; it is resolved to a
    # torch.dtype in verify_with_model_config.
    lora_dtype: Optional[torch.dtype] = None
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None

    def __post_init__(self):
        """Validate field values and fill in the max_cpu_loras default.

        Raises:
            ValueError: If max_lora_rank or lora_extra_vocab_size is not one
                of the supported values, if max_loras < 1, or if
                max_cpu_loras < max_loras.
        """
        # Setting the maximum rank to 256 should be able to satisfy the vast
        # majority of applications.
        possible_max_ranks = (8, 16, 32, 64, 128, 256)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve lora_dtype and warn about untested quantization.

        None or "auto" adopts model_config.dtype; any other string is looked
        up as an attribute of the torch module.
        """
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)
        if model_config.quantization and model_config.quantization not in [
                "awq", "gptq"
        ]:
            # TODO support marlin and squeezellm
            logger.warning("%s quantization is not tested with LoRA yet.",
                           model_config.quantization)

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        """Reject scheduler settings that LoRA does not support.

        Raises:
            ValueError: If chunked prefill is enabled.
        """
        if scheduler_config.chunked_prefill_enabled:
            raise ValueError("LoRA is not supported with chunked prefill yet.")
1507
1508


1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
@dataclass
class PromptAdapterConfig:
    """Configuration for prompt adapters.

    Constructing this config requires the 'peft' package to be importable.
    """
    max_prompt_adapters: int
    max_prompt_adapter_token: int
    # Defaults to max_prompt_adapters when left as None (see __post_init__).
    max_cpu_prompt_adapters: Optional[int] = None
    # May also arrive as a dtype name string or "auto"; it is resolved to a
    # torch.dtype in verify_with_model_config.
    prompt_adapter_dtype: Optional[torch.dtype] = None

    def __post_init__(self):
        """Check that 'peft' is installed and validate the field values.

        Raises:
            ImportError: If the 'peft' package cannot be imported.
            ValueError: If max_prompt_adapters < 1 or
                max_prompt_adapter_token is 0.
        """
        library_name = 'peft'
        try:
            __import__(library_name)
        except ImportError as e:
            raise ImportError(
                f"'{library_name}' is not installed for prompt adapter "
                f"support. Please install it using "
                f"'pip install {library_name}'.") from e

        if self.max_prompt_adapters < 1:
            raise ValueError(f"max_prompt_adapters "
                             f"({self.max_prompt_adapters}) must be >= 1.")
        if self.max_prompt_adapter_token == 0:
            raise ValueError("max_prompt_adapter_token must be set.")
        if self.max_cpu_prompt_adapters is None:
            self.max_cpu_prompt_adapters = self.max_prompt_adapters

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve prompt_adapter_dtype to a torch.dtype.

        None or "auto" adopts model_config.dtype; any other string is looked
        up as an attribute of the torch module.
        """
        if self.prompt_adapter_dtype in (None, "auto"):
            self.prompt_adapter_dtype = model_config.dtype
        elif isinstance(self.prompt_adapter_dtype, str):
            self.prompt_adapter_dtype = getattr(torch,
                                                self.prompt_adapter_dtype)


1542
@dataclass
class MultiModalConfig:
    """Controls the behavior of multimodal models."""

    limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
    """
    The maximum number of multi-modal input instances allowed per prompt
    for each :class:`~vllm.multimodal.MultiModalPlugin`.
    """

    # TODO: Add configs to init vision tower or not.
1553

1554

1555
1556
1557
1558
1559
1560
1561
1562
# Maps user-facing dtype names to torch dtypes. "half" and "float" are
# aliases for "float16" and "float32" respectively. Consulted by
# _get_and_verify_dtype for any string dtype other than "auto".
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

1563
# NOTE(review): currently empty; presumably lists dtype names that are not
# supported on ROCm -- confirm against the code that reads it.
_ROCM_NOT_SUPPORTED_DTYPE: List[str] = []
1564

1565
1566
1567

def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested dtype against the dtype in the model config.

    Args:
        config: Model config; its `torch_dtype` attribute (float32 when
            missing or None) is the reference dtype, and `model_type` is
            consulted for the "auto" case.
        dtype: Either a torch.dtype, a key of _STR_DTYPE_TO_TORCH_DTYPE
            (case-insensitive), or "auto". With "auto", float32 configs are
            downcast to float16 (bfloat16 for "gemma2") and any other config
            dtype is used unchanged.

    Returns:
        The resolved torch dtype.

    Raises:
        ValueError: If dtype is an unknown string or has an unsupported type.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type == "gemma2":
                    logger.info(
                        "For Gemma 2, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16.")
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype. Every mismatch is allowed; only the log level differs.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
1616
1617
1618
1619
1620


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
    disable_sliding_window: bool,
    sliding_window_len: Optional[int],
    spec_target_max_model_len: Optional[int] = None,
) -> int:
    """Get and verify the model's maximum length.

    The maximum is derived as the smallest value among a set of known
    "max length" attributes on hf_config, optionally capped by the sliding
    window length and scaled by the rope scaling factor. A user-specified
    max_model_len is then checked against the derived value.

    Args:
        hf_config: HF model config to read the length attributes from.
        max_model_len: User-specified maximum length, or None to use the
            derived one.
        disable_sliding_window: Whether the sliding window is manually
            disabled; if so, sliding_window_len caps the derived length.
        sliding_window_len: Sliding window length from the model config.
        spec_target_max_model_len: Max length of the target model; used only
            when no length attribute is found and max_model_len is None.

    Returns:
        The maximum model length as an int.

    Raises:
        ValueError: If rope_scaling has neither a 'type' nor a 'rope_type'
            key, or if max_model_len exceeds the derived maximum and
            VLLM_ALLOW_LONG_MAX_MODEL_LEN is not set.
        NotImplementedError: If the sliding window is disabled for a config
            combination that is not supported yet.
    """
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Command-R
        "model_max_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    # Choose the smallest "max_length" from the possible keys.
    max_len_key = None
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None:
            max_len_key = key if max_len < derived_max_model_len \
                else max_len_key
            derived_max_model_len = min(derived_max_model_len, max_len)

    # If sliding window is manually disabled, max_length should be less
    # than the sliding window length in the model config.
    if disable_sliding_window and sliding_window_len is not None:
        max_len_key = "sliding_window" \
            if sliding_window_len < derived_max_model_len else max_len_key
        derived_max_model_len = min(derived_max_model_len, sliding_window_len)

    # If none of the keys were found in the config, use a default and
    # log a warning.
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        if spec_target_max_model_len is not None:
            # If this is a speculative draft model, we use the max model len
            # from the target model.
            return spec_target_max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            "%s. Assuming the model's maximum length is %d.", possible_keys,
            default_max_len)
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        if "type" in rope_scaling:
            rope_type = rope_scaling["type"]
        elif "rope_type" in rope_scaling:
            rope_type = rope_scaling["rope_type"]
        else:
            raise ValueError(
                "rope_scaling must have a 'type' or 'rope_type' key.")

        # The correct one should be "longrope", kept "su" here
        # to be backward compatible
        if rope_type not in ("su", "longrope", "llama3"):
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that supports rope_scaling
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "with rope_scaling. Please raise an issue so we can "
                    "investigate.")

            assert "factor" in rope_scaling
            scaling_factor = rope_scaling["factor"]
            if rope_type == "yarn":
                # For yarn, the factor applies to the original (pre-scaling)
                # position embedding length rather than the derived one.
                derived_max_model_len = rope_scaling[
                    "original_max_position_embeddings"]
            derived_max_model_len *= scaling_factor

    # If the user specified a max length, make sure it is smaller than the
    # derived length from the HF model config.
    if max_model_len is None:
        max_model_len = int(derived_max_model_len)
    elif max_model_len > derived_max_model_len:
        # Some models might have a separate key for specifying model_max_length
        # that will be bigger than derived_max_model_len. We compare user input
        # with model_max_length and allow this override when it's smaller.
        model_max_length = getattr(hf_config, "model_max_length", None)
        if model_max_length is not None and max_model_len <= model_max_length:
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that has model_max_length
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "with model_max_length in the config. Please raise an "
                    "issue so we can investigate.")
        else:
            msg = (
                f"User-specified max_model_len ({max_model_len}) is greater "
                f"than the derived max_model_len ({max_len_key}="
                f"{derived_max_model_len} or model_max_length="
                f"{model_max_length} in model's config.json). This may lead "
                "to incorrect model outputs or CUDA errors.")
            if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
                logger.warning(
                    "%s Make sure the value is correct and within the "
                    "model context size.", msg)
            else:
                raise ValueError(
                    f"{msg} To allow overriding this maximum, set "
                    "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
    return int(max_model_len)
1740
1741


1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
def get_served_model_name(model: str,
                          served_model_name: Optional[Union[str, List[str]]]):
    """Resolve the name under which the model is served.

    A non-empty list yields its first entry, and a non-empty string is
    returned as-is. Any empty value (None, an empty string or an empty
    list) falls back to `model`.
    """
    if isinstance(served_model_name, list) and served_model_name:
        return served_model_name[0]
    # A truthy non-list value is returned unchanged; anything empty falls
    # back to the model name.
    return served_model_name or model


1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
@dataclass
class DecodingConfig:
    """Dataclass which contains the decoding strategy of the engine"""

    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
    guided_decoding_backend: str = 'outlines'

    def __post_init__(self):
        """Validate guided_decoding_backend.

        Raises:
            ValueError: If the backend is not one of the supported names.
        """
        valid_guided_backends = ['outlines', 'lm-format-enforcer']
        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")


1773
1774
1775
1776
1777
@dataclass
class ObservabilityConfig:
    """Configuration for observability."""
    otlp_traces_endpoint: Optional[str] = None

    # Collecting detailed timing information for each request can be expensive.

    # If set, collects the model forward time for the request.
    collect_model_forward_time: bool = False

    # If set, collects the model execute time for the request.
    collect_model_execute_time: bool = False

    def __post_init__(self):
        """Validate the tracing-related settings.

        Raises:
            ValueError: If an OTLP traces endpoint is configured while
                OpenTelemetry is unavailable, or if a timing collection flag
                is set without an OTLP traces endpoint.
        """
        if not is_otel_available() and self.otlp_traces_endpoint is not None:
            raise ValueError(
                "OpenTelemetry is not available. Unable to configure "
                "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
                f"installed. Original error:\n{otel_import_error_traceback}")

        if ((self.collect_model_forward_time
             or self.collect_model_execute_time)
                and self.otlp_traces_endpoint is None):
            raise ValueError(
                "collect_model_forward_time or collect_model_execute_time "
                "requires --otlp-traces-endpoint to be set.")

1800

1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
@dataclass(frozen=True)
class EngineConfig:
    """Dataclass which contains all engine-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    load_config: LoadConfig
    lora_config: Optional[LoRAConfig]
    speculative_config: Optional[SpeculativeConfig]
    decoding_config: Optional[DecodingConfig]
    observability_config: Optional[ObservabilityConfig]
    prompt_adapter_config: Optional[PromptAdapterConfig]

    def __post_init__(self):
        """Verify configs are valid & consistent with each other.
        """
        self.model_config.verify_async_output_proc(self.parallel_config,
                                                   self.speculative_config,
                                                   self.device_config)
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        # The optional configs are only cross-checked when present.
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)

    def to_dict(self):
        """Return the configs as a dictionary, for use in **kwargs.

        The mapping is shallow: each dataclass field name maps to the config
        object stored in that field.
        """
        # `f` rather than `field`, to avoid shadowing dataclasses.field.
        return {f.name: getattr(self, f.name) for f in fields(self)}