config.py 70.8 KB
Newer Older
1
import enum
2
import json
3
from dataclasses import dataclass, field, fields
4
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
5
6

import torch
7
from transformers import PretrainedConfig
8

Woosuk Kwon's avatar
Woosuk Kwon committed
9
from vllm.logger import init_logger
10
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
11
from vllm.model_executor.models import ModelRegistry
12
from vllm.tracing import is_otel_installed
13
from vllm.transformers_utils.config import get_config, get_hf_text_config
14
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
15
                        is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
16
                        print_warning_once)
17

18
19
20
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

21
    from vllm.executor.executor_base import ExecutorBase
22
    from vllm.model_executor.model_loader.loader import BaseModelLoader
23
24
    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
        BaseTokenizerGroup)
25

26
27
logger = init_logger(__name__)

# One gibibyte, in bytes.
_GB = 1024**3
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32 * 1024

# HF `architectures` entries for which pipeline parallelism is allowed;
# consulted by `ModelConfig.verify_with_parallel_config`.
_PP_SUPPORTED_MODELS = [
    # Aquila
    "AquilaModel",
    "AquilaForCausalLM",
    # DeepSeek
    "DeepseekV2ForCausalLM",
    # InternLM / LLaMA family
    "InternLMForCausalLM",
    "LlamaForCausalLM",
    "LLaMAForCausalLM",
    # Mistral / Phi-3 / GPT-2 / Mixtral
    "MistralForCausalLM",
    "Phi3ForCausalLM",
    "GPT2LMHeadModel",
    "MixtralForCausalLM",
]


class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
            It is also used as the content for `model_name` tag in metrics
            output when `served_model_name` is not specified.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        rope_scaling: Dictionary containing the scaling configuration for the
            RoPE embeddings. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        rope_theta: Optional RoPE theta value; forwarded to `get_config`
            together with `rope_scaling`.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the model `revision`.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized
            (unless the HF config declares a quantization method).
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
            be used to load activation and weight scaling factors when the
            model dtype is FP8_E4M3 on ROCm.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: DEPRECATED. Passing any non-None value
            raises a ValueError; use `max_seq_len_to_capture` instead.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Defaults to, and is capped at, `max_model_len`.
        max_logprobs: Limit on the number of log probabilities; this class
            only stores the value.
        disable_sliding_window: Whether to disable sliding window. If True,
            we will disable the sliding window functionality of the model.
            If the model does not support sliding window, this argument is
            ignored.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer.
        served_model_name: The model name used in metrics tag `model_name`,
            matches the model name exposed via the APIs. If multiple model
            names provided, the first name will be used. If not specified,
            the model name will be the same as `model`.
        multimodal_config: Optional multimodal configuration; stored as-is.

    Raises:
        ValueError: If `max_context_len_to_capture` is given, or if any of the
            tokenizer-mode / quantization checks run at construction fail.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        rope_scaling: Optional[dict] = None,
        rope_theta: Optional[float] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        quantization_param_path: Optional[str] = None,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: Optional[int] = None,
        max_logprobs: int = 20,
        disable_sliding_window: bool = False,
        skip_tokenizer_init: bool = False,
        served_model_name: Optional[Union[str, List[str]]] = None,
        multimodal_config: Optional["MultiModalConfig"] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        # The tokenizer version is consistent with the model version by default.
        if tokenizer_revision is None:
            self.tokenizer_revision = revision
        else:
            self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        if max_context_len_to_capture is not None:
            raise ValueError("`max_context_len_to_capture` is deprecated. "
                             "Use `max_seq_len_to_capture` instead.")
        self.max_seq_len_to_capture = max_seq_len_to_capture
        self.max_logprobs = max_logprobs
        self.disable_sliding_window = disable_sliding_window
        self.skip_tokenizer_init = skip_tokenizer_init

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision, rope_scaling, rope_theta)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

        # Gemma 2 interleaves sliding-window layers, which is unsupported:
        # force sliding window off (the window size is still passed below
        # as `sliding_window_len`).
        if (not self.disable_sliding_window
                and self.hf_text_config.model_type == "gemma2"
                and self.hf_text_config.sliding_window is not None):
            print_warning_once(
                "Gemma 2 uses sliding window attention for every odd layer, "
                "which is currently not supported by vLLM. Disabling sliding "
                "window and capping the max length to the sliding window size "
                f"({self.hf_text_config.sliding_window}).")
            self.disable_sliding_window = True

        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.hf_text_config,
            max_model_len=max_model_len,
            disable_sliding_window=self.disable_sliding_window,
            sliding_window_len=self.get_hf_config_sliding_window())
        self.served_model_name = get_served_model_name(model,
                                                       served_model_name)
        self.multimodal_config = multimodal_config

        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()
        self._verify_embedding_mode()
        self._verify_quantization()
        self._verify_cuda_graph()

    def _verify_tokenizer_mode(self) -> None:
        """Normalize `tokenizer_mode` to lowercase; reject unknown modes."""
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_embedding_mode(self) -> None:
        """Set `embedding_mode` if any HF architecture is an embedding model."""
        architectures = getattr(self.hf_config, "architectures", [])
        self.embedding_mode = any(
            ModelRegistry.is_embedding_model(arch) for arch in architectures)

    def _parse_quant_hf_config(self):
        """Return the HF quantization config dict, or None if absent."""
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        return quant_cfg

    def _verify_quantization(self) -> None:
        """Reconcile the requested quantization with the HF config.

        Lowercases `self.quantization`, lets each registered method override
        it based on the checkpoint config, and validates the result.

        Raises:
            ValueError: On a mismatch between the argument and the checkpoint,
                an unknown method, or a method unsupported on ROCm.
        """
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = ["gptq", "squeezellm"]
        # Methods exempt from the "not fully optimized" warning below.
        # "compressed-tensors" is the hyphenated method name; the underscore
        # spelling is kept for backward compatibility.
        optimized_quantization = ("fp8", "marlin", "gptq_marlin_24",
                                  "gptq_marlin", "awq_marlin", "fbgemm_fp8",
                                  "compressed_tensors", "compressed-tensors")

        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()

            # Detect which checkpoint is it: the first method that claims the
            # checkpoint overrides both the detected and requested method.
            for method in QUANTIZATION_METHODS.values():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization)
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if (is_hip()
                    and self.quantization not in rocm_supported_quantization):
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm.")
            if self.quantization not in optimized_quantization:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.", self.quantization)

    def _verify_cuda_graph(self) -> None:
        """Default `max_seq_len_to_capture` to, and cap it at, max_model_len."""
        if self.max_seq_len_to_capture is None:
            self.max_seq_len_to_capture = self.max_model_len
        self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                          self.max_model_len)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the model can run with the given parallel layout.

        Raises:
            ValueError: If attention heads are not divisible by the TP size,
                or bitsandbytes quantization is combined with TP/PP.
            NotImplementedError: If PP > 1 is requested for an architecture
                not in `_PP_SUPPORTED_MODELS`.
        """
        total_num_attention_heads = getattr(self.hf_text_config,
                                            "num_attention_heads", 0)
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        architectures = getattr(self.hf_config, "architectures", [])
        if not all(arch in _PP_SUPPORTED_MODELS
                   for arch in architectures) and pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is only supported for the following "
                f"architectures: {_PP_SUPPORTED_MODELS}.")

        if self.quantization == "bitsandbytes" and (
                parallel_config.tensor_parallel_size > 1
                or parallel_config.pipeline_parallel_size > 1):
            raise ValueError(
                "BitsAndBytes quantization with TP or PP is not supported yet.")

    def get_hf_config_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled."""

        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
        # addition to sliding window size. We check if that field is present
        # and if it's False, return None.
        if (hasattr(self.hf_text_config, "use_sliding_window")
                and not self.hf_text_config.use_sliding_window):
            return None
        return getattr(self.hf_text_config, "sliding_window", None)

    def get_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled.
        """
        # If user disables sliding window, return None.
        if self.disable_sliding_window:
            return None
        # Otherwise get the value from the hf config.
        return self.get_hf_config_sliding_window()

    def get_vocab_size(self) -> int:
        return self.hf_text_config.vocab_size

    def get_hidden_size(self) -> int:
        return self.hf_text_config.hidden_size

    def get_head_size(self) -> int:
        """Return the attention head size used for the KV cache."""
        # TODO remove hard code
        if getattr(self.hf_text_config, "model_type", None) == "deepseek_v2":
            # FlashAttention supports only head_size 32, 64, 128, 256,
            # we need to pad head_size 192 to 256
            return 256

        if hasattr(self.hf_text_config, "head_dim"):
            return self.hf_text_config.head_dim
        # FIXME(woosuk): This may not be true for all models.
        return (self.hf_text_config.hidden_size //
                self.hf_text_config.num_attention_heads)

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type == "mpt":
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type == "dbrx":
            return getattr(self.hf_config.attn_config, "kv_n_heads",
                           self.hf_config.num_attention_heads)

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_attention_heads(self,
                                parallel_config: "ParallelConfig") -> int:
        """Returns the number of attention heads per TP rank."""
        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
        return num_heads // parallel_config.tensor_parallel_size

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of hidden layers owned by this PP stage."""
        from vllm.distributed.utils import get_pp_indices
        total_num_hidden_layers = getattr(self.hf_text_config,
                                          "num_hidden_layers", 0)
        pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
        pp_size = parallel_config.pipeline_parallel_size
        start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
        return end - start

    def contains_seqlen_agnostic_layers(
            self, parallel_config: "ParallelConfig") -> bool:
        """True for Mamba/SSM models (Jamba)"""
        return self._get_num_seqlen_agnostic_layers(parallel_config) > 0

    def get_layers_block_type(self,
                              parallel_config: "ParallelConfig") -> List[str]:
        num_layers = self.get_num_layers(parallel_config)
        # Transformers supports layers_block_type @property
        return getattr(self.hf_config, "layers_block_type",
                       ["attention"] * num_layers)

    def get_num_attention_layers(self,
                                 parallel_config: "ParallelConfig") -> int:
        return len([
            t for t in self.get_layers_block_type(parallel_config)
            if t == "attention"
        ])

    def _get_num_seqlen_agnostic_layers(
            self, parallel_config: "ParallelConfig") -> int:
        return len([
            t for t in self.get_layers_block_type(parallel_config)
            if t != "attention"
        ])


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution. Values above 1.0 are rejected.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for kv cache storage. One of "auto", "fp8",
            "fp8_e4m3" or "fp8_e5m2".
        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
            profiled num_gpu_blocks if specified. Does nothing if None.
        sliding_window: Sliding window size, or None. Prefix caching cannot be
            enabled when this is set.
        enable_prefix_caching: Whether to enable prefix caching.
        cpu_offload_gb: CPU offload size in GiB; stored as-is by this class.

    Raises:
        ValueError: On invalid memory utilization or unknown cache dtype.
        NotImplementedError: If prefix caching is combined with a sliding
            window or with any fp8 cache dtype.
    """

    # Cache dtypes that store the KV cache in an fp8 format. Shared by the
    # dtype check and the prefix-caching check so they cannot drift apart.
    _FP8_CACHE_DTYPES: ClassVar[Tuple[str, ...]] = ("fp8", "fp8_e4m3",
                                                    "fp8_e5m2")

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
        cache_dtype: str,
        num_gpu_blocks_override: Optional[int] = None,
        sliding_window: Optional[int] = None,
        enable_prefix_caching: bool = False,
        cpu_offload_gb: float = 0,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self.num_gpu_blocks_override = num_gpu_blocks_override
        self.cache_dtype = cache_dtype
        self.sliding_window = sliding_window
        self.enable_prefix_caching = enable_prefix_caching
        self.cpu_offload_gb = cpu_offload_gb
        self._verify_args()
        self._verify_cache_dtype()
        self._verify_prefix_caching()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def metrics_info(self):
        # convert cache_config to dict(key: str, value: str) for prometheus
        # metrics info
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self) -> None:
        # 1.0 itself is allowed; only values strictly above it are rejected.
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than or equal to 1.0. "
                f"Got {self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        if self.cache_dtype == "auto":
            pass
        elif self.cache_dtype in self._FP8_CACHE_DTYPES:
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "Meanwhile, it may cause accuracy drop without a proper "
                "scaling factor")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    def _verify_prefix_caching(self) -> None:
        if not self.enable_prefix_caching:
            return

        if self.sliding_window is not None:
            raise NotImplementedError(
                "Prefix caching is not supported with sliding window. "
                "Run with --disable-sliding-window to use prefix caching.")
        # Reject every fp8 variant, not just the "fp8" alias, so that
        # fp8_e4m3 / fp8_e5m2 cannot bypass this guard.
        if self.cache_dtype in self._FP8_CACHE_DTYPES:
            raise NotImplementedError(
                "Prefix caching is not supported for fp8 cache_dtype. "
                "Run with --kv-cache-dtype auto to use prefix caching.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Reject (or warn about) swap space that is too large for the host.

        Raises:
            ValueError: If total swap exceeds 70% of CPU memory.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. %s", msg)


@dataclass
class TokenizerPoolConfig:
    """Settings for an optional pool of tokenizer workers.

    Args:
        pool_size: Number of tokenizer workers in the pool.
        pool_type: Either the string "ray" or a class object implementing the
            pool (annotated as a `BaseTokenizerGroup` subclass).
        extra_config: Pool-type-specific options; must be a dict. How it is
            interpreted depends on the pool type.

    Raises:
        ValueError: If `pool_type` is neither "ray" nor a class, or if
            `extra_config` is not a dict.
    """
    pool_size: int
    pool_type: Union[str, Type["BaseTokenizerGroup"]]
    extra_config: dict

    def __post_init__(self):
        accepted_type = (self.pool_type in ("ray", )
                         or isinstance(self.pool_type, type))
        if not accepted_type:
            raise ValueError(f"Unknown pool type: {self.pool_type}")
        if not isinstance(self.extra_config, dict):
            raise ValueError("extra_config must be a dictionary.")

    @classmethod
    def create_config(
        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Build a TokenizerPoolConfig, or return None when the pool is off.

        A falsy `tokenizer_pool_size` (e.g. 0) disables the pool and yields
        None.

        Args:
            tokenizer_pool_size: Number of tokenizer workers in the pool.
            tokenizer_pool_type: Type of the pool.
            tokenizer_pool_extra_config: Extra pool options, either a dict or
                a JSON string that is decoded here. None/empty means {}.
        """
        if not tokenizer_pool_size:
            return None

        if isinstance(tokenizer_pool_extra_config, str):
            extra = json.loads(tokenizer_pool_extra_config)
        else:
            extra = tokenizer_pool_extra_config or {}
        return cls(tokenizer_pool_size, tokenizer_pool_type, extra)


class LoadFormat(str, enum.Enum):
    """Weight formats accepted by `LoadConfig.load_format`.

    Mixing in `str` makes each member compare equal to its string value.
    """
    # Try safetensors first, fall back to the pytorch bin format.
    AUTO = "auto"
    # PyTorch bin format.
    PT = "pt"
    # Safetensors format.
    SAFETENSORS = "safetensors"
    # PyTorch format plus a numpy cache to speed up loading.
    NPCACHE = "npcache"
    # Random weights, mainly for profiling.
    DUMMY = "dummy"
    # CoreWeave's tensorizer library.
    TENSORIZER = "tensorizer"
    SHARDED_STATE = "sharded_state"
    BITSANDBYTES = "bitsandbytes"


@dataclass
class LoadConfig:
    """Configuration for loading model weights.

    Args:
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
            "tensorizer" will use CoreWeave's tensorizer library for
                fast weight loading.
            "sharded_state" and "bitsandbytes" are also accepted (see
                `LoadFormat`). String values are case-insensitive and are
                converted to `LoadFormat`; non-string values (e.g. a model
                loader) are kept without validation.
        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        model_loader_extra_config: Extra options for the model loader, as a
            dict or a JSON string (decoded on construction). None or an empty
            value is normalized to {}.
        ignore_patterns: The list of patterns to ignore when loading the model.
            Default to "original/**/*" to avoid repeated loading of llama's
            checkpoints.

    Raises:
        ValueError: If `load_format` is an unknown string.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(
        default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        extra_config = self.model_loader_extra_config
        if not extra_config:
            # Normalize None / "" / {} so consumers always get a dict; the
            # previous code computed `... or {}` but never stored it.
            self.model_loader_extra_config = {}
        elif isinstance(extra_config, str):
            self.model_loader_extra_config = json.loads(extra_config)
        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns)
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Convert a string `load_format` to `LoadFormat` and check ROCm."""
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            # Report lowercase values, the same spelling users pass in and
            # that `load_format` is compared against above.
            rocm_supported_load_format = [
                f.value for f in LoadFormat
                if f.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")


class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Pipeline-parallel degree.
        tensor_parallel_size: Tensor-parallel degree. The world size is
            ``pipeline_parallel_size * tensor_parallel_size``.
        worker_use_ray: Deprecated, use distributed_executor_backend instead.
            When true and no backend is given, the backend becomes "ray".
        max_parallel_loading_workers: Maximum number of multiple batches
            when load model sequentially. To avoid RAM OOM when using tensor
            parallel and large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL. Forced to True on ROCm (HIP).
        tokenizer_pool_config: Config for the tokenizer pool.
            If None, will use synchronous tokenization.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
            Requires the Ray backend.
        placement_group: ray distributed model workers placement group.
        distributed_executor_backend: Backend to use for distributed model
            workers: "ray", "mp" (multiprocessing), or an ExecutorBase
            subclass. When unset and the world size is greater than 1, "mp"
            is chosen unless the world size exceeds the locally visible CUDA
            devices (Ray is then required) or a Ray placement group is given
            or currently active, in which case "ray" is chosen.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: Optional[bool] = None,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
        ray_workers_use_nsight: bool = False,
        placement_group: Optional["PlacementGroup"] = None,
        distributed_executor_backend: Optional[Union[
            str, Type["ExecutorBase"]]] = None,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.distributed_executor_backend = distributed_executor_backend
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.tokenizer_pool_config = tokenizer_pool_config
        self.ray_workers_use_nsight = ray_workers_use_nsight
        self.placement_group = placement_group

        self.world_size = pipeline_parallel_size * self.tensor_parallel_size

        if worker_use_ray:
            self._apply_legacy_worker_use_ray()

        if self.distributed_executor_backend is None and self.world_size > 1:
            chosen_backend = self._detect_default_backend()
            self.distributed_executor_backend = chosen_backend
            logger.info("Defaulting to use %s for distributed inference",
                        chosen_backend)

        self._verify_args()
        self.rank = 0

    def _apply_legacy_worker_use_ray(self) -> None:
        """Translate the deprecated ``worker_use_ray=True`` flag."""
        if self.distributed_executor_backend is None:
            self.distributed_executor_backend = "ray"
            return
        if not self.use_ray:
            raise ValueError(f"worker-use-ray can't be used with "
                             f"distributed executor backend "
                             f"'{self.distributed_executor_backend}'.")

    def _detect_default_backend(self) -> str:
        """Return "ray" or "mp" for a multi-worker run with no backend set.

        Multiprocessing is preferred when the world size fits the locally
        visible CUDA devices and no Ray placement group is in play.
        """
        from vllm.executor import ray_utils

        ray_found = ray_utils.ray_is_available()
        if cuda_device_count_stateless() < self.world_size:
            # More workers than local devices: multi-node, which needs Ray.
            if not ray_found:
                raise ValueError("Unable to load Ray which is "
                                 "required for multi-node inference, "
                                 "please install Ray with `pip install "
                                 "ray`.") from ray_utils.ray_import_err
            return "ray"

        if not ray_found:
            return "mp"
        if self.placement_group:
            return "ray"

        from ray import is_initialized as ray_is_initialized
        if not ray_is_initialized():
            return "mp"

        from ray.util import get_current_placement_group
        return "ray" if get_current_placement_group() else "mp"

    @property
    def use_ray(self) -> bool:
        """Whether the configured backend runs workers through Ray."""
        backend = self.distributed_executor_backend
        if backend == "ray":
            return True
        return isinstance(backend, type) and backend.uses_ray

    def _verify_args(self) -> None:
        # Lazy import to avoid circular import
        from vllm.executor.executor_base import ExecutorBase

        backend = self.distributed_executor_backend
        if not (backend in ("ray", "mp", None) or
                (isinstance(backend, type)
                 and issubclass(backend, ExecutorBase))):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{backend}. Supported "
                "values are 'ray', 'mp' or custom ExecutorBase subclass.")

        if self.use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        if is_hip():
            self.disable_custom_all_reduce = True
            logger.info(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on AMD GPUs.")

        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")
751

752
753

class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Token budget for a single iteration. When
            None a default is derived: 512 with chunked prefill,
            ``max(max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)`` in
            embedding mode, otherwise ``max(max_model_len, 2048)``.
        max_num_seqs: Upper bound on sequences handled in one iteration.
        max_model_len: Maximum sequence length, prompt plus generated text.
        use_v2_block_manager: Select BlockSpaceManagerV2.
        num_lookahead_slots: Extra slots allocated per sequence per step
            beyond the known token ids; used by speculative decoding to hold
            KV activations of tokens that may or may not be accepted.
            Must be non-negative.
        delay_factor: Delay (this factor times the previous prompt latency)
            applied before scheduling the next prompt.
        enable_chunked_prefill: Allow prefills to be split according to the
            remaining max_num_batched_tokens.
        embedding_mode: Whether the running model is an embedding model.
        preemption_mode: Preempt by swapping or by recomputation. When not
            given, recomputation is the default because it is cheaper than
            swapping; sequence groups with several sequences (e.g. beam
            search) cannot be recomputed and fall back to swapping.

    Raises:
        ValueError: if the token budget is below max_model_len without
            chunked prefill, below max_num_seqs, or if num_lookahead_slots
            is negative.
    """

    def __init__(self,
                 max_num_batched_tokens: Optional[int],
                 max_num_seqs: int,
                 max_model_len: int,
                 use_v2_block_manager: bool = False,
                 num_lookahead_slots: int = 0,
                 delay_factor: float = 0.0,
                 enable_chunked_prefill: bool = False,
                 embedding_mode: Optional[bool] = False,
                 preemption_mode: Optional[str] = None) -> None:
        self.max_num_batched_tokens = (
            max_num_batched_tokens if max_num_batched_tokens is not None else
            self._default_max_num_batched_tokens(
                max_model_len, enable_chunked_prefill, embedding_mode))

        if enable_chunked_prefill:
            logger.info(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                self.max_num_batched_tokens)

        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.use_v2_block_manager = use_v2_block_manager
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
        self.chunked_prefill_enabled = enable_chunked_prefill
        self.embedding_mode = embedding_mode
        self.preemption_mode = preemption_mode

        self._verify_args()

    @staticmethod
    def _default_max_num_batched_tokens(max_model_len: int,
                                        enable_chunked_prefill: bool,
                                        embedding_mode: Optional[bool]) -> int:
        """Pick a token budget when the caller did not supply one."""
        if enable_chunked_prefill:
            # Best balance between ITL and TTFT measured on A100; this is
            # not tuned for throughput.
            return 512
        if embedding_mode:
            # Embedding models get a larger budget for higher throughput.
            return max(max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
        # Short-context models still get at least 2048 for throughput.
        return max(max_model_len, 2048)

    def _verify_args(self) -> None:
        budget = self.max_num_batched_tokens

        if budget < self.max_model_len and not self.chunked_prefill_enabled:
            raise ValueError(
                f"max_num_batched_tokens ({budget}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")

        if budget < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({budget}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")

        if self.num_lookahead_slots < 0:
            raise ValueError(
                "num_lookahead_slots "
                f"({self.num_lookahead_slots}) must be greater than or "
                "equal to 0.")

844

845
846
class DeviceConfig:
    """Resolve the device type and the ``torch.device`` used for inputs.

    Args:
        device: Explicit device type name, or "auto" (the default) to pick
            one from the platform probes (neuron, openvino, tpu, cpu, xpu,
            falling back to cuda).

    Attributes:
        device_type: The resolved device type string.
        device: ``torch.device("cpu")`` for neuron/openvino, ``None`` for
            tpu, otherwise ``torch.device(device_type)``.
    """

    def __init__(self, device: str = "auto") -> None:
        self.device_type = (self._detect_device_type()
                            if device == "auto" else device)

        if self.device_type in ("neuron", "openvino"):
            # Some device types require processing inputs on CPU.
            self.device = torch.device("cpu")
        elif self.device_type == "tpu":
            self.device = None
        else:
            self.device = torch.device(self.device_type)

    @staticmethod
    def _detect_device_type() -> str:
        """Return the first device type whose platform probe succeeds."""
        probes = (
            (is_neuron, "neuron"),
            (is_openvino, "openvino"),
            (is_tpu, "tpu"),
            (is_cpu, "cpu"),
            (is_xpu, "xpu"),
        )
        for probe, device_type in probes:
            if probe():
                return device_type
        # We don't call torch.cuda.is_available() here to
        # avoid initializing CUDA before workers are forked
        return "cuda"

877

878
879
880
881
882
883
884
885
886
887
888
889
890
class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals. Passing ``speculative_model="[ngram]"``
    selects n-gram prompt lookup instead of a separate draft model.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        speculative_draft_tensor_parallel_size: Optional[int],
        num_speculative_tokens: Optional[int],
        speculative_max_model_len: Optional[int],
        enable_chunked_prefill: bool,
        use_v2_block_manager: bool,
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: Optional[float],
        typical_acceptance_sampler_posterior_alpha: Optional[float],
        disable_logprobs: Optional[bool],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.

        Returns None when ``speculative_model`` is None. Otherwise the
        arguments are validated, defaults are filled in, and a
        SpeculativeConfig is returned.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model.
            target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            target_dtype (str): The data type used for the target model. Not
                currently read by this function.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided. The special value "[ngram]" selects n-gram
                prompt lookup and reuses the target model/parallel configs.
            speculative_draft_tensor_parallel_size (Optional[int]): The degree
                of the tensor parallelism for the draft model. Defaults to the
                target's tensor parallel size; only 1 may be set explicitly.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided. Defaults to ``n_predict`` from the draft
                model config when present, otherwise it is required.
            speculative_max_model_len (Optional[int]): The maximum model len of
                the speculative model. Used when testing the ability to skip
                speculation for some sequences.
            enable_chunked_prefill (bool): Whether vLLM is configured to use
                chunked prefill or not. Used for raising an error since it is
                not yet compatible with spec decode.
            use_v2_block_manager (bool): Whether vLLM is configured to use the
                v2 block manager or not. Used for raising an error since the v2
                block manager is required with spec decode.
            speculative_disable_by_batch_size (Optional[int]): Disable
                speculative decoding for new incoming requests when the number
                of enqueued requests is larger than this value, if provided.
                Must be at least 2.
            ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
                window, if provided. Required and > 0 for "[ngram]"; forced
                to 0 for a draft model.
            ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                window, if provided. Defaults to 1 for "[ngram]"; forced to 0
                for a draft model.
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (Optional[float]):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance. Defaults to
                0.09.
            typical_acceptance_sampler_posterior_alpha (Optional[float]):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler. Defaults to 0.3.
            disable_logprobs (Optional[bool]): If set to True, token log
                probabilities are not returned during speculative decoding.
                If set to False, token log probabilities are returned
                according to the log probability settings in SamplingParams.
                If not specified, it defaults to True.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                the necessary conditions are met, else None.

        Raises:
            ValueError: If the arguments are inconsistent or out of range.
        """

        if speculative_model is None:
            if num_speculative_tokens is not None:
                raise ValueError("num_speculative_tokens was provided without "
                                 "speculative_model.")
            return None

        if (speculative_disable_by_batch_size is not None
                and speculative_disable_by_batch_size < 2):
            raise ValueError("Expect the batch size threshold of disabling "
                             "speculative decoding is > 1, but got "
                             f"{speculative_disable_by_batch_size=}")

        if enable_chunked_prefill:
            raise ValueError(
                "Speculative decoding and chunked prefill are "
                f"currently mutually exclusive ({enable_chunked_prefill=}).")

        if not use_v2_block_manager:
            raise ValueError(
                "Speculative decoding requires usage of the V2 "
                "block manager. Enable it with --use-v2-block-manager.")

        # TODO: The user should be able to specify revision/quantization/max
        # model len for the draft model. It is not currently supported.
        draft_revision = None
        draft_code_revision = None
        draft_quantization = None

        if speculative_model == "[ngram]":
            if ngram_prompt_lookup_min is None:
                ngram_prompt_lookup_min = 1
            if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
                raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
            if ngram_prompt_lookup_min < 1:
                raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
            if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
                raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
                                 f"larger than {ngram_prompt_lookup_max=}")

            # TODO: we currently still need to extract vocab_size from the
            # target model config; in the future we may refactor this out and
            # set the draft-related configs to None here.
            draft_model_config = target_model_config
            draft_parallel_config = target_parallel_config
        else:
            ngram_prompt_lookup_max = 0
            ngram_prompt_lookup_min = 0
            draft_model_config = ModelConfig(
                model=speculative_model,
                tokenizer=target_model_config.tokenizer,
                tokenizer_mode=target_model_config.tokenizer_mode,
                trust_remote_code=target_model_config.trust_remote_code,
                dtype=target_model_config.dtype,
                seed=target_model_config.seed,
                revision=draft_revision,
                code_revision=draft_code_revision,
                tokenizer_revision=target_model_config.tokenizer_revision,
                max_model_len=None,
                quantization=draft_quantization,
                enforce_eager=target_model_config.enforce_eager,
                max_seq_len_to_capture=target_model_config.
                max_seq_len_to_capture,
                max_logprobs=target_model_config.max_logprobs,
            )

            draft_hf_config = draft_model_config.hf_config

            # Propagate the requested token count to draft configs that
            # expose a ``num_lookahead_tokens`` attribute.
            if (num_speculative_tokens is not None
                    and hasattr(draft_hf_config, "num_lookahead_tokens")):
                draft_hf_config.num_lookahead_tokens = num_speculative_tokens

            n_predict = getattr(draft_hf_config, "n_predict", None)
            if n_predict is not None:
                if num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    num_speculative_tokens = n_predict
                elif num_speculative_tokens > n_predict:
                    # Verify provided value doesn't exceed the maximum
                    # supported by the draft model.
                    raise ValueError(
                        "This speculative model supports a maximum of "
                        f"num_speculative_tokens={n_predict}, but "
                        f"{num_speculative_tokens=} was provided.")

            draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    speculative_max_model_len,
                    draft_model_config.max_model_len,
                    target_model_config.max_model_len,
                ))

            draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    target_parallel_config,
                    speculative_draft_tensor_parallel_size))

        if num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative_model unless the draft model config contains an "
                "n_predict parameter.")

        if typical_acceptance_sampler_posterior_threshold is None:
            typical_acceptance_sampler_posterior_threshold = 0.09
        if typical_acceptance_sampler_posterior_alpha is None:
            typical_acceptance_sampler_posterior_alpha = 0.3
        if disable_logprobs is None:
            disable_logprobs = True

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
            speculative_disable_by_batch_size,
            ngram_prompt_lookup_max,
            ngram_prompt_lookup_min,
            draft_token_acceptance_method=draft_token_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=\
                typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=\
                typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=disable_logprobs
        )

    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: Optional[int],
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.

        Raises:
            ValueError: if speculative_max_model_len exceeds either the draft
                or the target max model len.
        """

        if speculative_max_model_len is not None:

            if speculative_max_model_len > draft_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {draft_max_model_len=}")

            if speculative_max_model_len > target_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {target_max_model_len=}")

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: Optional[int]
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config, except the tp_size.
        When ``speculative_draft_tensor_parallel_size`` is None the target's
        tensor parallel size is reused; an explicit value must currently be 1.

        Raises:
            ValueError: if an explicit draft tensor parallel size other than 1
                is given.
        """
        if speculative_draft_tensor_parallel_size is None:
            speculative_draft_tensor_parallel_size = \
                  target_parallel_config.tensor_parallel_size
        elif speculative_draft_tensor_parallel_size != 1:
            # TODO(wooyeon): allow tp values larger than 1
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
                f"a value other than 1")

        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.
            pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=target_parallel_config.
            distributed_executor_backend,
            max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.
            disable_custom_all_reduce,
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=target_parallel_config.
            ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
    ):
        """Create a SpeculativeConfig object.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.
            speculative_disable_by_batch_size: Disable speculative
                decoding for new incoming requests when the number of
                enqueued requests is larger than this value.
            ngram_prompt_lookup_max: Max size of ngram token window
                (None is stored as 0).
            ngram_prompt_lookup_min: Min size of ngram token window
                (None is stored as 0).
            draft_token_acceptance_method (str): The method to use for
                accepting draft tokens. This can take two possible
                values 'rejection_sampler' and 'typical_acceptance_sampler'
                for RejectionSampler and TypicalAcceptanceSampler
                respectively.
            typical_acceptance_sampler_posterior_threshold (float):
                A threshold value that sets a lower bound on the posterior
                probability of a token in the target model for it to be
                accepted. This threshold is used only when we use the
                TypicalAcceptanceSampler for token acceptance.
            typical_acceptance_sampler_posterior_alpha (float):
                A scaling factor for the entropy-based threshold in the
                TypicalAcceptanceSampler.
            disable_logprobs: If set to True, token log probabilities will not
                be returned even if requested by sampling parameters. This
                reduces latency by skipping logprob calculation in proposal
                sampling, target sampling, and after accepted tokens are
                determined. If set to False, log probabilities will be
                returned.

        Raises:
            ValueError: via ``_verify_args`` when a setting is invalid.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens
        self.speculative_disable_by_batch_size = \
            speculative_disable_by_batch_size
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
        self.draft_token_acceptance_method = draft_token_acceptance_method
        self.typical_acceptance_sampler_posterior_threshold = \
            typical_acceptance_sampler_posterior_threshold
        self.typical_acceptance_sampler_posterior_alpha = \
            typical_acceptance_sampler_posterior_alpha
        self.disable_logprobs = disable_logprobs

        self._verify_args()

    def _verify_args(self) -> None:
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)

        # Validate draft token acceptance related settings.
        if (self.draft_token_acceptance_method is None):
            raise ValueError("draft_token_acceptance_method is not set. "
                             "Expected values are rejection_sampler or "
                             "typical_acceptance_sampler.")

        if (self.draft_token_acceptance_method != 'rejection_sampler'
                and self.draft_token_acceptance_method !=
                'typical_acceptance_sampler'):
            raise ValueError(
                "Expected draft_token_acceptance_method to be either "
                "rejection_sampler or typical_acceptance_sampler. Instead it "
                f"is {self.draft_token_acceptance_method}")

        # Zero is accepted, so the message states ">= 0" to match the check.
        if (self.typical_acceptance_sampler_posterior_threshold < 0
                or self.typical_acceptance_sampler_posterior_alpha < 0):
            raise ValueError(
                "Expected typical_acceptance_sampler_posterior_threshold "
                "and typical_acceptance_sampler_posterior_alpha to be >= 0. "
                "Instead found "
                f"typical_acceptance_sampler_posterior_threshold = "
                f"{self.typical_acceptance_sampler_posterior_threshold} and "
                f"typical_acceptance_sampler_posterior_alpha = "
                f"{self.typical_acceptance_sampler_posterior_alpha}")

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        # A positive ngram window identifies the "[ngram]" proposer.
        if self.ngram_prompt_lookup_max > 0:
            draft_model = "[ngram]"
        else:
            draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"


@dataclass
class LoRAConfig:
    """Configuration for LoRA adapters.

    Field values are validated in ``__post_init__``; the ``verify_with_*``
    methods check this config against the model and scheduler configs.
    """

    # Must be one of the ranks listed in ``__post_init__``.
    max_lora_rank: int
    # Must be >= 1.
    max_loras: int
    fully_sharded_loras: bool = False
    # Defaults to ``max_loras`` when None; must not be smaller than it.
    max_cpu_loras: Optional[int] = None
    # A torch dtype, a dtype name such as "float16", or None/"auto" to use
    # the model's dtype (resolved in ``verify_with_model_config``).
    lora_dtype: Optional[Union[torch.dtype, str]] = None
    # Must be one of the sizes listed in ``__post_init__``.
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256
    # Variable-length tuple of scaling factors.
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None

    def __post_init__(self):
        """Validate field values and default ``max_cpu_loras``.

        Raises:
            ValueError: If any field is outside its supported values.
        """
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        possible_max_ranks = (8, 16, 32, 64)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``lora_dtype`` against the model config.

        Also logs a warning for quantization methods other than awq/gptq.
        """
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            # Dtype names map to torch attributes, e.g. "bfloat16".
            self.lora_dtype = getattr(torch, self.lora_dtype)
        if model_config.quantization and model_config.quantization not in [
                "awq", "gptq"
        ]:
            # TODO support marlin and squeezellm
            logger.warning("%s quantization is not tested with LoRA yet.",
                           model_config.quantization)

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        """Reject scheduler settings that LoRA does not support.

        Raises:
            ValueError: If ``max_num_batched_tokens`` exceeds 65528 or
                chunked prefill is enabled.
        """
        if scheduler_config.max_num_batched_tokens > 65528:
            raise ValueError(
                "Due to limitations of the custom LoRA CUDA kernel, "
                "max_num_batched_tokens must be <= 65528 when "
                "LoRA is enabled.")
        if scheduler_config.chunked_prefill_enabled:
            raise ValueError("LoRA is not supported with chunked prefill yet.")


@dataclass
class PromptAdapterConfig:
    """Configuration for prompt adapters.

    Constructing it requires the ``peft`` package to be importable.
    """

    # Must be >= 1.
    max_prompt_adapters: int
    # Must be non-zero.
    max_prompt_adapter_token: int
    # Defaults to ``max_prompt_adapters`` when None.
    max_cpu_prompt_adapters: Optional[int] = None
    # A torch dtype, a dtype name such as "float16", or None/"auto" to use
    # the model's dtype (resolved in ``verify_with_model_config``).
    prompt_adapter_dtype: Optional[Union[torch.dtype, str]] = None

    def __post_init__(self):
        """Check that ``peft`` is importable and validate field values.

        Raises:
            ImportError: If ``peft`` cannot be imported.
            ValueError: If ``max_prompt_adapters`` < 1 or
                ``max_prompt_adapter_token`` is 0.
        """
        library_name = 'peft'
        try:
            __import__(library_name)
        except ImportError as e:
            raise ImportError(
                f"'{library_name}' is not installed for prompt adapter "
                f"support. Please install it using "
                f"'pip install {library_name}'.") from e

        if self.max_prompt_adapters < 1:
            raise ValueError(f"max_prompt_adapters "
                             f"({self.max_prompt_adapters}) must be >= 1.")
        if self.max_prompt_adapter_token == 0:
            raise ValueError("max_prompt_adapter_token must be set.")
        if self.max_cpu_prompt_adapters is None:
            self.max_cpu_prompt_adapters = self.max_prompt_adapters

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``prompt_adapter_dtype`` against the model config."""
        if self.prompt_adapter_dtype in (None, "auto"):
            self.prompt_adapter_dtype = model_config.dtype
        elif isinstance(self.prompt_adapter_dtype, str):
            # Dtype names map to torch attributes, e.g. "bfloat16".
            self.prompt_adapter_dtype = getattr(torch,
                                                self.prompt_adapter_dtype)


@dataclass
class MultiModalConfig:
    """Configs the input data format and how models should run for
    multimodal models. Currently defines no fields."""
    # TODO: Add configs to init vision tower or not.


# Lower-case dtype names accepted from users, mapped to torch dtypes.
# "half" and "float" are aliases of "float16" and "float32".
_STR_DTYPE_TO_TORCH_DTYPE = dict(
    half=torch.float16,
    float16=torch.float16,
    float=torch.float32,
    float32=torch.float32,
    bfloat16=torch.bfloat16,
)

# NOTE(review): presumably dtype names unsupported on ROCm (inferred from
# the name); currently empty.
_ROCM_NOT_SUPPORTED_DTYPE: List[str] = []


def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested ``dtype`` against the model config's dtype.

    Args:
        config: HF model config; its ``torch_dtype`` (float32 when missing
            or None) is the reference dtype.
        dtype: ``"auto"``, a dtype name (case-insensitive) found in
            ``_STR_DTYPE_TO_TORCH_DTYPE``, or a ``torch.dtype``.

    Returns:
        The torch dtype to use. ``"auto"`` keeps the config dtype, except
        that float32 configs are downcast to float16 (bfloat16 when
        ``config.model_type == "gemma2"``).

    Raises:
        ValueError: If ``dtype`` is an unknown name, or is neither a string
            nor a ``torch.dtype``.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type == "gemma2":
                    logger.info(
                        "For Gemma 2, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16.")
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype. Every cast is accepted; only the log level differs.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
    disable_sliding_window: bool,
    sliding_window_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length.

    The derived length is the smallest value among the known length keys
    in ``hf_config``. When ``disable_sliding_window`` is set, it is further
    capped by ``sliding_window_len``. For ``rope_scaling`` types other than
    su/longrope/llama3, it is then multiplied by the scaling factor.

    Args:
        hf_config: HF model config to read length keys and ``rope_scaling``
            from.
        max_model_len: User-requested maximum length, or None to use the
            derived length.
        disable_sliding_window: Whether the sliding window is disabled; if
            so, the length is capped at ``sliding_window_len``.
        sliding_window_len: The model's sliding window length, or None.

    Returns:
        The verified maximum length. If the config has no length key,
        ``max_model_len`` is returned as given when set, else 2048 is
        assumed.

    Raises:
        ValueError: If ``rope_scaling`` has neither a 'type' nor a
            'rope_type' key, or if ``max_model_len`` exceeds the derived
            length and is not covered by ``model_max_length``.
        NotImplementedError: If the sliding window is disabled together with
            a scaling ``rope_scaling`` or a ``model_max_length`` override.
    """
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Command-R
        "model_max_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    # Choose the smallest "max_length" from the possible keys; on ties the
    # earlier key is kept.
    max_len_key = None
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None and max_len < derived_max_model_len:
            max_len_key = key
            derived_max_model_len = max_len

    # If sliding window is manually disabled, max_length should be less
    # than the sliding window length in the model config.
    if (disable_sliding_window and sliding_window_len is not None
            and sliding_window_len < derived_max_model_len):
        max_len_key = "sliding_window"
        derived_max_model_len = sliding_window_len

    # If none of the keys were found in the config, use a default and
    # log a warning.
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            "%s. Assuming the model's maximum length is %d.", possible_keys,
            default_max_len)
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        if "type" in rope_scaling:
            rope_type = rope_scaling["type"]
        elif "rope_type" in rope_scaling:
            rope_type = rope_scaling["rope_type"]
        else:
            raise ValueError(
                "rope_scaling must have a 'type' or 'rope_type' key.")

        # The correct one should be "longrope", kept "su" here
        # to be backward compatible
        if rope_type not in ("su", "longrope", "llama3"):
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that supports rope_scaling
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "with rope_scaling. Please raise an issue so we can "
                    "investigate.")

            assert "factor" in rope_scaling
            scaling_factor = rope_scaling["factor"]
            if rope_type == "yarn":
                # YaRN scales from its own original length, not the
                # length derived above.
                derived_max_model_len = rope_scaling[
                    "original_max_position_embeddings"]
            derived_max_model_len *= scaling_factor

    # If the user specified a max length, make sure it is smaller than the
    # derived length from the HF model config.
    if max_model_len is None:
        max_model_len = int(derived_max_model_len)
    elif max_model_len > derived_max_model_len:
        # Some models might have a separate key for specifying model_max_length
        # that will be bigger than derived_max_model_len. We compare user input
        # with model_max_length and allow this override when it's smaller.
        model_max_length = getattr(hf_config, "model_max_length", None)
        if model_max_length is None or max_model_len > model_max_length:
            raise ValueError(
                f"User-specified max_model_len ({max_model_len}) is greater "
                "than the derived max_model_len "
                f"({max_len_key}={derived_max_model_len} or model_max_length="
                f"{model_max_length} in model's config.json). This may lead "
                "to incorrect model outputs or CUDA errors. Make sure the "
                "value is correct and within the model context size.")
        if disable_sliding_window:
            # TODO(robertgshaw): Find a model that has model_max_length
            # with sliding window to see if this case should be allowed.
            raise NotImplementedError(
                "Disabling sliding window is not supported for models "
                "with model_max_length in the config. Please raise an "
                "issue so we can investigate.")
    return int(max_model_len)


def get_served_model_name(
        model: str,
        served_model_name: Optional[Union[str, List[str]]]) -> str:
    """Return the name the model is served under.

    If ``served_model_name`` is a non-empty list, its first entry is used.
    If it is a non-empty string, it is used directly. When it is None, an
    empty string, or an empty list, the fallback is the ``model`` argument.
    """
    if not served_model_name:
        return model
    if isinstance(served_model_name, list):
        return served_model_name[0]
    return served_model_name


@dataclass
class DecodingConfig:
    """Dataclass which contains the decoding strategy of the engine"""

    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
    guided_decoding_backend: str = 'outlines'

    def __post_init__(self):
        """Reject unknown guided decoding backends.

        Raises:
            ValueError: If ``guided_decoding_backend`` is not supported.
        """
        valid_guided_backends = ['outlines', 'lm-format-enforcer']
        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")


@dataclass
class ObservabilityConfig:
    """Observability settings; currently only the OTLP traces endpoint."""

    # NOTE(review): presumably the OTLP endpoint traces are exported to
    # (inferred from the name); None leaves it unconfigured.
    otlp_traces_endpoint: Optional[str] = None

    def __post_init__(self):
        # Configuring an endpoint requires the OpenTelemetry packages.
        otel_missing = not is_otel_installed()
        if otel_missing and self.otlp_traces_endpoint is not None:
            raise ValueError("OpenTelemetry packages must be installed before "
                             "configuring 'otlp_traces_endpoint'")


@dataclass(frozen=True)
class EngineConfig:
    """Dataclass which contains all engine-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    load_config: LoadConfig
    lora_config: Optional[LoRAConfig]
    multimodal_config: Optional[MultiModalConfig]
    speculative_config: Optional[SpeculativeConfig]
    decoding_config: Optional[DecodingConfig]
    observability_config: Optional[ObservabilityConfig]
    prompt_adapter_config: Optional[PromptAdapterConfig]

    def __post_init__(self):
        """Verify configs are valid & consistent with each other.

        The optional LoRA and prompt-adapter configs are only verified when
        they are set (truthy).
        """
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)
        if self.prompt_adapter_config:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)

    def to_dict(self):
        """Return the configs as a dictionary, for use in **kwargs.

        The mapping is shallow: values are the config objects themselves.
        """
        return {f.name: getattr(self, f.name) for f in fields(self)}