config.py 64 KB
Newer Older
1
import enum
2
import json
3
from dataclasses import dataclass, field, fields
4
5
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
                    Union)
6
7

import torch
8
from transformers import PretrainedConfig, PreTrainedTokenizerBase
9

10
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
13
from vllm.model_executor.models import ModelRegistry
14
from vllm.tracing import is_otel_installed
15
from vllm.transformers_utils.config import get_config, get_hf_text_config
16
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
Woosuk Kwon's avatar
Woosuk Kwon committed
17
                        is_hip, is_neuron, is_tpu, is_xpu, print_warning_once,
18
                        update_environment_variables)
19

20
21
22
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

23
    from vllm.model_executor.model_loader.loader import BaseModelLoader
24

25
26
logger = init_logger(__name__)

27
_GB = 1 << 30
28
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
29

30
31

class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
            It is also used as the content for `model_name` tag in metrics
            output when `served_model_name` is not specified.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
            The value is case-insensitive and is lower-cased on validation.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        rope_scaling: Dictionary containing the scaling configuration for the
            RoPE embeddings. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        rope_theta: Optional RoPE theta value. It is forwarded to `get_config`
            together with `rope_scaling` when the HF config is loaded.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, the
            model `revision` is used.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, the method is taken from the HF config when it
            has a quantization/compression section; otherwise we assume the
            model weights are not quantized.
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
            be used to load activation and weight scaling factors when the
            model dtype is FP8_E4M3 on ROCm.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: DEPRECATED. Passing any value other than
            None raises a ValueError; use `max_seq_len_to_capture` instead.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. If None, it defaults to `max_model_len`; it is
            always capped at `max_model_len`.
        max_logprobs: Stored as `self.max_logprobs`. Presumably the maximum
            number of log probabilities a request may ask for; it is not
            enforced in this class.
        disable_sliding_window: Whether to disable sliding window. If True,
            we will disable the sliding window functionality of the model.
            If the model does not support sliding window, this argument is
            ignored. It is forced to True for Gemma 2 models that define a
            sliding window.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer.
        served_model_name: The model name used in metrics tag `model_name`,
            matches the model name exposed via the APIs. If multiple model
            names provided, the first name will be used. If not specified,
            the model name will be the same as `model`.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        rope_scaling: Optional[dict] = None,
        rope_theta: Optional[float] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        quantization_param_path: Optional[str] = None,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: Optional[int] = None,
        max_logprobs: int = 20,
        disable_sliding_window: bool = False,
        skip_tokenizer_init: bool = False,
        served_model_name: Optional[Union[str, List[str]]] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        # The tokenizer version is consistent with the model version by default.
        if tokenizer_revision is None:
            self.tokenizer_revision = revision
        else:
            self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture
        if self.max_context_len_to_capture is not None:
            raise ValueError("`max_context_len_to_capture` is deprecated. "
                             "Use `max_seq_len_to_capture` instead.")
        # The deprecated option is guaranteed to be None at this point, so
        # there is nothing to fall back to.
        self.max_seq_len_to_capture = max_seq_len_to_capture
        self.max_logprobs = max_logprobs
        self.disable_sliding_window = disable_sliding_window
        self.skip_tokenizer_init = skip_tokenizer_init

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision, rope_scaling, rope_theta)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

        # Gemma 2 interleaves sliding-window layers, which is unsupported, so
        # sliding window is disabled and the max length capped instead.
        if (not self.disable_sliding_window
                and self.hf_text_config.model_type == "gemma2"
                and self.hf_text_config.sliding_window is not None):
            print_warning_once(
                "Gemma 2 uses sliding window attention for every odd layer, "
                "which is currently not supported by vLLM. Disabling sliding "
                "window and capping the max length to the sliding window size "
                f"({self.hf_text_config.sliding_window}).")
            self.disable_sliding_window = True

        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.hf_text_config,
            max_model_len=max_model_len,
            disable_sliding_window=self.disable_sliding_window,
            sliding_window_len=self.get_hf_config_sliding_window())
        self.served_model_name = get_served_model_name(model,
                                                       served_model_name)
        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()
        self._verify_embedding_mode()
        self._verify_quantization()
        self._verify_cuda_graph()

    def _verify_tokenizer_mode(self) -> None:
        """Lower-case `tokenizer_mode` and reject values other than
        "auto"/"slow"."""
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_embedding_mode(self) -> None:
        """Set `embedding_mode` if any HF architecture is an embedding model.

        HF configs may carry `architectures = None`; that is treated the same
        as an empty list instead of failing on iteration.
        """
        architectures = getattr(self.hf_config, "architectures", None) or []
        self.embedding_mode = any(
            ModelRegistry.is_embedding_model(arch) for arch in architectures)

    def _parse_quant_hf_config(self):
        """Return the quantization section of the HF config, or None."""
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compress-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        return quant_cfg

    def _verify_quantization(self) -> None:
        """Reconcile the `quantization` argument with the HF config.

        Raises:
            ValueError: if the argument conflicts with the checkpoint's
                method, the method is unknown, or it is unsupported on ROCm.
        """
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = ["gptq", "squeezellm"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()

            # Detect which checkpoint is it
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization)
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if is_hip(
            ) and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm.")
            if (self.quantization
                    not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin")):
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.", self.quantization)

    def _verify_cuda_graph(self) -> None:
        """Default `max_seq_len_to_capture` to, and cap it at,
        `max_model_len`."""
        if self.max_seq_len_to_capture is None:
            self.max_seq_len_to_capture = self.max_model_len
        self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                          self.max_model_len)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that heads/layers split evenly across TP/PP ranks.

        Raises:
            ValueError: on an uneven split, or bitsandbytes with TP/PP > 1.
        """
        total_num_attention_heads = getattr(self.hf_text_config,
                                            "num_attention_heads", 0)
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = getattr(self.hf_text_config,
                                          "num_hidden_layers", 0)
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

        if self.quantization == "bitsandbytes" and (
                parallel_config.tensor_parallel_size > 1
                or parallel_config.pipeline_parallel_size > 1):
            raise ValueError(
                "BitsAndBytes quantization with TP or PP is not supported yet.")

    def get_hf_config_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled."""

        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
        # addition to sliding window size. We check if that field is present
        # and if it's False, return None.
        if (hasattr(self.hf_text_config, "use_sliding_window")
                and not self.hf_text_config.use_sliding_window):
            return None
        return getattr(self.hf_text_config, "sliding_window", None)

    def get_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled.
        """
        # If user disables sliding window, return None.
        if self.disable_sliding_window:
            return None
        # Otherwise get the value from the hf config.
        return self.get_hf_config_sliding_window()

    def get_vocab_size(self) -> int:
        """Return `vocab_size` from the HF text config."""
        return self.hf_text_config.vocab_size

    def get_hidden_size(self) -> int:
        """Return `hidden_size` from the HF text config."""
        return self.hf_text_config.hidden_size

    def get_head_size(self) -> int:
        """Return `head_dim`, or hidden_size // num_attention_heads."""
        if hasattr(self.hf_text_config, "head_dim"):
            return self.hf_text_config.head_dim
        # FIXME(woosuk): This may not be true for all models.
        return (self.hf_text_config.hidden_size //
                self.hf_text_config.num_attention_heads)

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type == "mpt":
            # MPT's attn_config is indexed like a mapping here.
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type == "dbrx":
            # DBRX's attn_config is read as an attribute-style object.
            return getattr(self.hf_config.attn_config, "kv_n_heads",
                           self.hf_config.num_attention_heads)

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_attention_heads(self,
                                parallel_config: "ParallelConfig") -> int:
        """Return the number of attention heads per tensor-parallel rank."""
        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
        return num_heads // parallel_config.tensor_parallel_size

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of hidden layers per pipeline-parallel rank."""
        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution. Values greater than 1.0 are rejected.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for kv cache storage. One of "auto", "fp8",
            "fp8_e4m3" or "fp8_e5m2".
        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
            profiled num_gpu_blocks if specified. Does nothing if None.
        sliding_window: Sliding window size, or None. Prefix caching cannot
            be enabled while this is set.
        enable_prefix_caching: Whether to enable prefix caching. Not supported
            together with a sliding window or any fp8 `cache_dtype`.
    """

    # Every accepted cache dtype that stores the KV cache in an fp8 format.
    # Shared by `_verify_cache_dtype` and `_verify_prefix_caching` so the two
    # checks cannot drift apart.
    _FP8_CACHE_DTYPES = ("fp8", "fp8_e4m3", "fp8_e5m2")

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
        cache_dtype: str,
        num_gpu_blocks_override: Optional[int] = None,
        sliding_window: Optional[int] = None,
        enable_prefix_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self.num_gpu_blocks_override = num_gpu_blocks_override
        self.cache_dtype = cache_dtype
        self.sliding_window = sliding_window
        self.enable_prefix_caching = enable_prefix_caching
        self._verify_args()
        self._verify_cache_dtype()
        self._verify_prefix_caching()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def metrics_info(self):
        """Return every instance attribute stringified, for Prometheus."""
        # convert cache_config to dict(key: str, value: str) for prometheus
        # metrics info
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self) -> None:
        """Reject a GPU memory utilization above 1.0."""
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        """Accept "auto" or an fp8 variant; raise ValueError otherwise."""
        if self.cache_dtype == "auto":
            pass
        elif self.cache_dtype in self._FP8_CACHE_DTYPES:
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "Meanwhile, it may cause accuracy drop without a proper "
                "scaling factor")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    def _verify_prefix_caching(self) -> None:
        """Raise NotImplementedError for unsupported prefix-caching combos.

        Every fp8 variant is rejected, not only the literal "fp8", because
        `_verify_cache_dtype` accepts all of them as fp8 cache storage.
        """
        if not self.enable_prefix_caching:
            return

        if self.sliding_window is not None:
            raise NotImplementedError(
                "Prefix caching is not supported with sliding window. "
                "Run with --disable-sliding-window to use prefix caching.")
        if self.cache_dtype in self._FP8_CACHE_DTYPES:
            raise NotImplementedError(
                "Prefix caching is not supported for fp8 cache_dtype. "
                "Run with --kv-cache-dtype auto to use prefix caching.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the swap space fits in host memory.

        Raises ValueError above 70% of total CPU memory and only warns above
        40%.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. %s", msg)
457

458

459
460
461
@dataclass
class TokenizerPoolConfig:
    """Settings for an optional pool of tokenizer workers.

    Args:
        pool_size: How many tokenizer workers the pool runs.
        pool_type: Which pool implementation to use; only "ray" is accepted.
        extra_config: Pool-type-specific options. Must be a dict; how it is
            interpreted depends on `pool_type`.
    """
    pool_size: int
    pool_type: str
    extra_config: dict

    def __post_init__(self):
        known_pool_types = ("ray", )
        if self.pool_type not in known_pool_types:
            raise ValueError(f"Unknown pool type: {self.pool_type}")
        if isinstance(self.extra_config, dict):
            return
        raise ValueError("extra_config must be a dictionary.")

    @classmethod
    def create_config(
        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Build a TokenizerPoolConfig, or None when no pool is requested.

        A falsy `tokenizer_pool_size` (e.g. 0) means synchronous
        tokenization, so no config is created.

        Args:
            tokenizer_pool_size: How many tokenizer workers the pool runs.
            tokenizer_pool_type: Which pool implementation to use.
            tokenizer_pool_extra_config: Pool-specific options, either a dict
                or a JSON string that is decoded first. A falsy non-string
                value becomes an empty dict.
        """
        if not tokenizer_pool_size:
            return None
        raw_extra = tokenizer_pool_extra_config
        extra = (json.loads(raw_extra)
                 if isinstance(raw_extra, str) else (raw_extra or {}))
        return cls(tokenizer_pool_size, tokenizer_pool_type, extra)


511
512
513
514
515
516
517
class LoadFormat(str, enum.Enum):
    """Accepted values for `LoadConfig.load_format`.

    Members mix in `str`, so each member compares equal to its lowercase
    string value (e.g. ``LoadFormat.AUTO == "auto"``).
    """
    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"
    DUMMY = "dummy"
    TENSORIZER = "tensorizer"
    SHARDED_STATE = "sharded_state"
    BITSANDBYTES = "bitsandbytes"
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571


@dataclass
class LoadConfig:
    """Configuration for loading model weights.

    Args:
        load_format: The format of the model weights to load. A string is
            lower-cased and converted to a `LoadFormat` member in
            `__post_init__`; any non-string value (e.g. a model loader
            instance) is kept as-is.
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
            "tensorizer" will use CoreWeave's tensorizer library for
                fast weight loading.
            "sharded_state" and "bitsandbytes" are also accepted; their
                behavior is defined by the corresponding model loader.
        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        model_loader_extra_config: Extra options for the model loader, either
            a dict or a JSON string (decoded in `__post_init__`). A falsy
            value such as None is normalized to an empty dict.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(
        default_factory=dict)

    def __post_init__(self):
        # Normalize the extra config and store the result back; previously the
        # `or {}` fallback was computed but discarded, leaving None in place.
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
            model_loader_extra_config = json.loads(model_loader_extra_config)
        self.model_loader_extra_config = model_loader_extra_config
        self._verify_load_format()

    def _verify_load_format(self) -> None:
        """Convert a string `load_format` to `LoadFormat` and check ROCm.

        Raises:
            ValueError: if the string is not a `LoadFormat` value, or the
                format is not supported on ROCm.
        """
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        # Placeholder list of formats that are unavailable on ROCm (empty).
        rocm_not_supported_load_format: List[str] = []
        if load_format in rocm_not_supported_load_format and is_hip():
            # Compare and report member *values* ("auto"), matching how
            # `load_format` itself is expressed, not member names ("AUTO").
            rocm_supported_load_format = [
                f.value for f in LoadFormat
                if f.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")


572
class ParallelConfig:
    """Settings that control distributed (tensor/pipeline parallel) execution.

    Args:
        pipeline_parallel_size: Pipeline parallel size. Values above 1 are
            currently rejected.
        tensor_parallel_size: Tensor parallel size.
        worker_use_ray: Deprecated, use distributed_executor_backend instead.
            When truthy it selects the "ray" backend and conflicts with any
            other explicitly requested backend.
        max_parallel_loading_workers: Maximum number of multiple batches
            when load model sequentially. To avoid RAM OOM when using tensor
            parallel and large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL. Forced on for multi-GPU runs on ROCm.
        tokenizer_pool_config: Config for the tokenizer pool.
            If None, will use synchronous tokenization.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
            Requires the "ray" backend.
        placement_group: ray distributed model workers placement group.
        distributed_executor_backend: Backend to use for distributed model
            workers, either "ray" or "mp" (multiprocessing). If unset and the
            world size exceeds 1, "ray" is chosen when the local CUDA devices
            are insufficient or a Ray placement group is active, otherwise
            "mp".
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: Optional[bool] = None,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
        ray_workers_use_nsight: bool = False,
        placement_group: Optional["PlacementGroup"] = None,
        distributed_executor_backend: Optional[str] = None,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.distributed_executor_backend = distributed_executor_backend
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.tokenizer_pool_config = tokenizer_pool_config
        self.ray_workers_use_nsight = ray_workers_use_nsight
        self.placement_group = placement_group
        self.world_size = pipeline_parallel_size * self.tensor_parallel_size

        # The deprecated flag is an alias for the "ray" backend.
        if worker_use_ray:
            requested = self.distributed_executor_backend
            if requested is None:
                self.distributed_executor_backend = "ray"
            elif requested != "ray":
                raise ValueError(f"worker-use-ray can't be used with "
                                 f"distributed executor backend "
                                 f"'{requested}'.")

        needs_default = (self.distributed_executor_backend is None
                         and self.world_size > 1)
        if needs_default:
            chosen_backend = self._pick_default_backend()
            self.distributed_executor_backend = chosen_backend
            logger.info("Defaulting to use %s for distributed inference",
                        chosen_backend)

        # If CUDA_VISIBLE_DEVICES is set on ROCm prior to vLLM init,
        # propagate changes to HIP_VISIBLE_DEVICES (conversion handled by
        # the update_environment_variables function)
        if is_hip() and envs.CUDA_VISIBLE_DEVICES:
            update_environment_variables(
                {"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})

        self._verify_args()

    def _pick_default_backend(self) -> str:
        """Choose "ray" or "mp" when no backend was requested.

        Multiprocessing is preferred when the world fits on the local CUDA
        devices and no Ray placement group is in use.
        """
        from vllm.executor import ray_utils
        has_ray = ray_utils.ray is not None

        if cuda_device_count_stateless() < self.world_size:
            if not has_ray:
                raise ValueError("Unable to load Ray which is "
                                 "required for multi-node inference")
            return "ray"
        if not has_ray:
            return "mp"
        if self.placement_group:
            return "ray"

        from ray import is_initialized as ray_is_initialized
        if not ray_is_initialized():
            return "mp"
        from ray.util import get_current_placement_group
        return "ray" if get_current_placement_group() else "mp"

    def _verify_args(self) -> None:
        """Validate the parallel settings and adjust custom all-reduce."""
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")
        if self.distributed_executor_backend not in ("ray", "mp", None):
            raise ValueError(
                "Unrecognized distributed executor backend. Supported values "
                "are 'ray' or 'mp'.")

        custom_ar_requested = (not self.disable_custom_all_reduce
                               and self.world_size > 1)
        if custom_ar_requested and is_hip():
            self.disable_custom_all_reduce = True
            logger.info(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on AMD GPUs.")
        elif custom_ar_requested and self.pipeline_parallel_size > 1:
            self.disable_custom_all_reduce = True
            logger.info(
                "Disabled the custom all-reduce kernel because it is not "
                "supported with pipeline parallelism.")

        if (self.ray_workers_use_nsight
                and self.distributed_executor_backend != "ray"):
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")
681

682
683

class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration. If None, a default is chosen: 512 when
            chunked prefill is enabled; otherwise
            max(max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS) in
            embedding mode; otherwise max(max_model_len, 2048).
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
        num_lookahead_slots: The number of slots to allocate per sequence per
            step, beyond the known token ids. This is used in speculative
            decoding to store KV activations of tokens which may or may not be
            accepted.
        delay_factor: Apply a delay (of delay factor multiplied by previous
            prompt latency) before scheduling next prompt.
        enable_chunked_prefill: If True, prefill requests can be chunked based
            on the remaining max_num_batched_tokens. Stored on the instance as
            ``chunked_prefill_enabled``.
        embedding_mode: Whether the running model is for embedding.
        preemption_mode: Whether to perform preemption by swapping or
            recomputation. If not specified, we determine the mode as follows:
            We use recomputation by default since it incurs lower overhead than
            swapping. However, when the sequence group has multiple sequences
            (e.g., beam search), recomputation is not currently supported. In
            such a case, we use swapping instead.

    Raises:
        ValueError: If max_num_batched_tokens is smaller than max_model_len
            while chunked prefill is disabled, if it is smaller than
            max_num_seqs, or if num_lookahead_slots is negative.
    """

    def __init__(self,
                 max_num_batched_tokens: Optional[int],
                 max_num_seqs: int,
                 max_model_len: int,
                 use_v2_block_manager: bool = False,
                 num_lookahead_slots: int = 0,
                 delay_factor: float = 0.0,
                 enable_chunked_prefill: bool = False,
                 embedding_mode: Optional[bool] = False,
                 preemption_mode: Optional[str] = None) -> None:
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
        elif enable_chunked_prefill:
            # This value gives the best balance between inter-token latency
            # (ITL) and time to first token (TTFT) on A100. Note that it is
            # not optimized for throughput.
            self.max_num_batched_tokens = 512
        elif embedding_mode:
            # For embedding, choose a specific value for higher throughput.
            self.max_num_batched_tokens = max(
                max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
        else:
            # If max_model_len is too short, use 2048 as the default value
            # for higher throughput.
            self.max_num_batched_tokens = max(max_model_len, 2048)

        if enable_chunked_prefill:
            logger.info("Chunked prefill is enabled (EXPERIMENTAL).")

        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.use_v2_block_manager = use_v2_block_manager
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
        self.chunked_prefill_enabled = enable_chunked_prefill
        self.embedding_mode = embedding_mode
        self.preemption_mode = preemption_mode

        self._verify_args()

    def _verify_args(self) -> None:
        """Check the token/sequence budgets for consistency.

        Raises:
            ValueError: On any of the inconsistencies listed in the class
                docstring.
        """
        # Without chunked prefill, a prompt must fit in a single batch, so a
        # smaller token budget would silently cap the usable sequence length.
        if (self.max_num_batched_tokens < self.max_model_len
                and not self.chunked_prefill_enabled):
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")

        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")

        if self.num_lookahead_slots < 0:
            raise ValueError(
                "num_lookahead_slots "
                f"({self.num_lookahead_slots}) must be greater than or "
                "equal to 0.")

773

774
775
class DeviceConfig:
    """Selects the device type and the ``torch.device`` used for inputs.

    Args:
        device: Device type name (e.g. "cuda", "cpu", "neuron", "tpu",
            "xpu"), or "auto" to detect it. Auto-detection checks neuron,
            tpu, cpu and xpu in that order and falls back to "cuda".

    Attributes:
        device_type: The resolved device type string.
        device: ``torch.device("cpu")`` for neuron (inputs are processed on
            CPU), ``None`` for tpu, otherwise ``torch.device(device_type)``.
    """

    def __init__(self, device: str = "auto") -> None:
        if device == "auto":
            # Automated device type detection
            if is_neuron():
                self.device_type = "neuron"
            elif is_tpu():
                self.device_type = "tpu"
            elif is_cpu():
                self.device_type = "cpu"
            elif is_xpu():
                self.device_type = "xpu"
            else:
                # We don't call torch.cuda.is_available() here to
                # avoid initializing CUDA before workers are forked
                self.device_type = "cuda"
        else:
            # Device type is assigned explicitly
            self.device_type = device

        # Some device types require processing inputs on CPU
        if self.device_type == "neuron":
            self.device = torch.device("cpu")
        elif self.device_type == "tpu":
            # No torch.device is assigned for TPU here.
            self.device = None
        else:
            # Set device with device type
            self.device = torch.device(self.device_type)

804

805
806
807
808
809
810
811
812
813
814
815
816
817
class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        speculative_draft_tensor_parallel_size: Optional[int],
        num_speculative_tokens: Optional[int],
        speculative_max_model_len: Optional[int],
        enable_chunked_prefill: bool,
        use_v2_block_manager: bool,
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.

        This function attempts to create a SpeculativeConfig object based on the
        provided parameters. If no speculative model is given it returns None;
        otherwise it validates the arguments and returns an instance of
        SpeculativeConfig.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model.
            target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            target_dtype (str): The data type used for the target model.
                NOTE: not read by this method; the draft model inherits
                ``target_model_config.dtype`` instead.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided. The special value "[ngram]" selects ngram
                prompt-lookup speculation instead of a draft model.
            speculative_draft_tensor_parallel_size (Optional[int]): The degree
                of the tensor parallelism for the draft model.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided. Will default to the number in the draft
                model config if present, otherwise is required.
            speculative_max_model_len (Optional[int]): The maximum model len of
                the speculative model. Used when testing the ability to skip
                speculation for some sequences.
            enable_chunked_prefill (bool): Whether vLLM is configured to use
                chunked prefill or not. Used for raising an error since it's
                not yet compatible with spec decode.
            use_v2_block_manager (bool): Whether vLLM is configured to use the
                v2 block manager or not. Used for raising an error since the v2
                block manager is required with spec decode.
            speculative_disable_by_batch_size (Optional[int]): Disable
                speculative decoding for new incoming requests when the number
                of enqueued requests is larger than this value, if provided.
            ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
                window, if provided.
            ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                window, if provided.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                a speculative model is given, else None.

        Raises:
            ValueError: If the arguments are inconsistent or request an
                unsupported combination (see the individual checks below).
        """
        if speculative_model is None:
            if num_speculative_tokens is not None:
                raise ValueError("num_speculative_tokens was provided without "
                                 "speculative_model.")
            return None

        if (speculative_disable_by_batch_size is not None
                and speculative_disable_by_batch_size < 2):
            raise ValueError("Expect the batch size threshold of disabling "
                             "speculative decoding is > 1, but got "
                             f"{speculative_disable_by_batch_size=}")

        if enable_chunked_prefill:
            raise ValueError(
                "Speculative decoding and chunked prefill are "
                f"currently mutually exclusive ({enable_chunked_prefill=}).")

        if not use_v2_block_manager:
            raise ValueError(
                "Speculative decoding requires usage of the V2 "
                "block manager. Enable it with --use-v2-block-manager.")

        # TODO: The user should be able to specify revision/quantization/max
        # model len for the draft model. It is not currently supported.
        draft_revision = None
        draft_code_revision = None
        draft_quantization = None

        if speculative_model == "[ngram]":
            if ngram_prompt_lookup_min is None:
                ngram_prompt_lookup_min = 1
            if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
                raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
            if ngram_prompt_lookup_min < 1:
                raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
            if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
                raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
                                 f"larger than {ngram_prompt_lookup_max=}")

            # TODO: we currently still need to extract vocab_size from the
            # target model config. In the future we may refactor this out and
            # set the draft-related config to None here.
            draft_model_config = target_model_config
            draft_parallel_config = target_parallel_config
        else:
            ngram_prompt_lookup_max = 0
            ngram_prompt_lookup_min = 0
            draft_model_config = ModelConfig(
                model=speculative_model,
                tokenizer=target_model_config.tokenizer,
                tokenizer_mode=target_model_config.tokenizer_mode,
                trust_remote_code=target_model_config.trust_remote_code,
                dtype=target_model_config.dtype,
                seed=target_model_config.seed,
                revision=draft_revision,
                code_revision=draft_code_revision,
                tokenizer_revision=target_model_config.tokenizer_revision,
                max_model_len=None,
                quantization=draft_quantization,
                enforce_eager=target_model_config.enforce_eager,
                max_seq_len_to_capture=target_model_config.
                max_seq_len_to_capture,
                max_logprobs=target_model_config.max_logprobs,
            )

            draft_hf_config = draft_model_config.hf_config
            if (draft_hf_config.model_type == "mlp_speculator"
                    and target_parallel_config.world_size != 1):
                # MLPSpeculator TP support will be added very soon
                raise ValueError(
                    "Speculative decoding with mlp_speculator models does not "
                    "yet support distributed inferencing (TP > 1).")

            if (num_speculative_tokens is not None
                    and hasattr(draft_hf_config, "num_lookahead_tokens")):
                draft_hf_config.num_lookahead_tokens = num_speculative_tokens

            n_predict = getattr(draft_hf_config, "n_predict", None)
            if n_predict is not None:
                if num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    num_speculative_tokens = n_predict
                elif num_speculative_tokens > n_predict:
                    # Verify provided value doesn't exceed the maximum
                    # supported by the draft model.
                    raise ValueError(
                        "This speculative model supports a maximum of "
                        f"num_speculative_tokens={n_predict}, but "
                        f"{num_speculative_tokens=} was provided.")

            draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    speculative_max_model_len,
                    draft_model_config.max_model_len,
                    target_model_config.max_model_len,
                ))

            draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    target_parallel_config,
                    speculative_draft_tensor_parallel_size))

        if num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with "
                "speculative_model unless the draft model config contains an "
                "n_predict parameter.")

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
            speculative_disable_by_batch_size,
            ngram_prompt_lookup_max,
            ngram_prompt_lookup_min,
        )

    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: Optional[int],
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.

        Raises:
            ValueError: If speculative_max_model_len exceeds either the draft
                or the target max model len.
        """
        if speculative_max_model_len is not None:

            if speculative_max_model_len > draft_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {draft_max_model_len=}")

            if speculative_max_model_len > target_max_model_len:
                raise ValueError(f"{speculative_max_model_len=} cannot be "
                                 f"larger than {target_max_model_len=}")

            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: Optional[int]
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config, except the tp_size.
        If speculative_draft_tensor_parallel_size is None the target's tensor
        parallel size is reused; an explicit value other than 1 is rejected.

        Raises:
            ValueError: If an explicit draft TP size other than 1 is given.
        """
        if speculative_draft_tensor_parallel_size is None:
            speculative_draft_tensor_parallel_size = (
                target_parallel_config.tensor_parallel_size)
        elif speculative_draft_tensor_parallel_size != 1:
            # TODO(wooyeon): allow tp values larger than 1
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
                "other value than 1")

        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.
            pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=target_parallel_config.
            distributed_executor_backend,
            max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.
            disable_custom_all_reduce,
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=target_parallel_config.
            ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
    ):
        """Create a SpeculativeConfig object.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.
            speculative_disable_by_batch_size: Disable speculative
                decoding for new incoming requests when the number of
                enqueued requests is larger than this value.
            ngram_prompt_lookup_max: Max size of ngram token window. None is
                stored as 0.
            ngram_prompt_lookup_min: Min size of ngram token window. None is
                stored as 0.

        Raises:
            ValueError: If num_speculative_tokens is not positive.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens
        self.speculative_disable_by_batch_size = (
            speculative_disable_by_batch_size)
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0

        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the token count and the draft model/parallel pairing."""
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        # A positive ngram window means ngram prompt lookup is in use; the
        # draft model config is then the target's and its name is misleading.
        if self.ngram_prompt_lookup_max > 0:
            draft_model = "[ngram]"
        else:
            draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"


1116
1117
1118
1119
@dataclass
class LoRAConfig:
    """Configuration for serving LoRA adapters.

    Validation of the static fields happens in ``__post_init__``; the
    ``verify_with_*`` methods check compatibility with other configs.
    """
    max_lora_rank: int
    max_loras: int
    fully_sharded_loras: bool = False
    # Defaults to max_loras when left as None (see __post_init__).
    max_cpu_loras: Optional[int] = None
    # May be None, "auto", a torch attribute name such as "float16", or a
    # torch.dtype; verify_with_model_config normalizes it to a torch.dtype.
    lora_dtype: Optional[Union[torch.dtype, str]] = None
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None

    def __post_init__(self):
        """Validate rank/vocab/count fields and default max_cpu_loras.

        Raises:
            ValueError: If max_lora_rank or lora_extra_vocab_size is not one of
                the supported values, if max_loras < 1, or if max_cpu_loras
                is smaller than max_loras.
        """
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        possible_max_ranks = (8, 16, 32, 64)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve lora_dtype against the model and warn on untested quant.

        None/"auto" adopt ``model_config.dtype``; any other string is looked
        up as an attribute of ``torch``.
        """
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)
        if model_config.quantization and model_config.quantization not in [
                "awq", "gptq"
        ]:
            # TODO support marlin and squeezellm
            logger.warning("%s quantization is not tested with LoRA yet.",
                           model_config.quantization)

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        """Reject scheduler settings the LoRA kernels cannot handle.

        Raises:
            ValueError: If max_num_batched_tokens exceeds 65528 or chunked
                prefill is enabled.
        """
        if scheduler_config.max_num_batched_tokens > 65528:
            raise ValueError(
                "Due to limitations of the custom LoRA CUDA kernel, "
                "max_num_batched_tokens must be <= 65528 when "
                "LoRA is enabled.")
        if scheduler_config.chunked_prefill_enabled:
            raise ValueError("LoRA is not supported with chunked prefill yet.")
1169
1170


1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
@dataclass
class VisionLanguageConfig:
    """Configs the input data format and how models should run for
    vision language models."""

    class ImageInputType(enum.Enum):
        """Image input type into the vision language model.

        An image roughly goes through the following transformation:
        Raw image --> pixel values --> image features --> image embeddings.

        The difference between different image input types is where the
        image encoder (pixel values --> image features) is run.
        Different image input types also correspond to different tensor shapes.

        For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
        IMAGE_FEATURES: (1, 576, 1024).
        """
        PIXEL_VALUES = enum.auto()
        IMAGE_FEATURES = enum.auto()

    image_input_type: ImageInputType
    # The input id corresponding to image token.
    image_token_id: int
    # Used for running `run_prefill_max_token`.
    # For models that support varying resolution, this corresponds to
    # worst case scenario (biggest supported resolution).
    image_input_shape: tuple
    # Number of image token placeholders inserted per image into the prompt
    # (see get_image_token_text).
    image_feature_size: int
    # The image processor to load from HuggingFace
    image_processor: Optional[str]
    image_processor_revision: Optional[str]

    @classmethod
    def get_image_input_enum_type(cls, value: str) -> ImageInputType:
        """Get the image input type from a string (case-insensitive).

        Raises:
            ValueError: If ``value`` does not name an ImageInputType member.
        """
        try:
            return cls.ImageInputType[value.upper()]
        except KeyError as e:
            raise ValueError(f"{value} is not a valid choice. "
                             f"Expecting to choose from "
                             f"{[x.name for x in cls.ImageInputType]}.") from e

    #TODO(ywang96): make this a cached property once we refactor the
    # VisionLanguageConfig class.
    def get_image_token_text(
            self, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
        """Get the image token placeholder text to be inserted into the
        text prompt and the string representation of the image token id.

        Returns:
            A pair ``(placeholder, token_str)`` where ``token_str`` is the
            decoded image token and ``placeholder`` is ``token_str`` repeated
            ``image_feature_size`` times.
        """
        image_token_str = tokenizer.decode(self.image_token_id)
        return image_token_str * self.image_feature_size, image_token_str

    def as_cli_args_dict(self) -> Dict[str, Any]:
        """Flatten vision language config to pure args.

        Compatible with what llm entrypoint expects: enum values become their
        lowercase member name, tuples become comma-joined strings, and a
        ``disable_image_processor`` flag is added (True when no image
        processor is configured).
        """
        result: Dict[str, Any] = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, enum.Enum):
                result[f.name] = value.name.lower()
            elif isinstance(value, tuple):
                result[f.name] = ",".join([str(item) for item in value])
            else:
                result[f.name] = value

        result["disable_image_processor"] = self.image_processor is None

        return result

1243

1244
1245
1246
1247
1248
1249
1250
1251
# Lowercase dtype names accepted by _get_and_verify_dtype, mapped to the
# torch dtype they select.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

# Presumably dtype names unsupported on ROCm (per the name); currently empty.
_ROCM_NOT_SUPPORTED_DTYPE: List[str] = []
1253

1254
1255
1256

def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the dtype to run the model in and log any cast it implies.

    Args:
        config: HF model config. Its ``torch_dtype`` attribute is taken as the
            checkpoint dtype (float32 when missing or None); ``model_type``
            is consulted only for the "auto" float32 case.
        dtype: "auto", a case-insensitive key of _STR_DTYPE_TO_TORCH_DTYPE,
            or a ``torch.dtype``. With "auto", float32 checkpoints are
            downcast (to bfloat16 for gemma2, float16 otherwise) and any
            other checkpoint dtype is kept.

    Returns:
        The resolved torch dtype.

    Raises:
        ValueError: If ``dtype`` is an unknown name or of an unsupported type.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type == "gemma2":
                    logger.info(
                        "For Gemma 2, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16.")
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            try:
                torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
            except KeyError:
                raise ValueError(f"Unknown dtype: {dtype}") from None
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype. Every cast is allowed; only the log level differs.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
1305
1306
1307
1308
1309


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
    disable_sliding_window: bool,
    sliding_window_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length.

    The length is derived from the HF config as the smallest value among a
    list of known "max length" attributes. It is optionally capped by the
    sliding window size (when sliding window is disabled) and multiplied by
    the ``rope_scaling`` factor for scaling types other than "su"/"longrope".
    A user-specified ``max_model_len`` is then validated against it.

    Args:
        hf_config: HuggingFace model config to read length attributes from.
        max_model_len: User-specified maximum length, or None to use the
            derived length.
        disable_sliding_window: Whether sliding window has been disabled; if
            so, the derived length is capped at ``sliding_window_len``.
        sliding_window_len: The model's sliding window size, if any.

    Returns:
        The maximum model length as an int. If no length key is found in the
        config and ``max_model_len`` is given, it is returned unchanged.

    Raises:
        NotImplementedError: If sliding window is disabled for a model that
            uses (non-"su"/"longrope") rope scaling, or whose accepted
            length relies on ``model_max_length``.
        ValueError: If ``max_model_len`` exceeds both the derived length and
            ``model_max_length``.
    """
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Command-R
        "model_max_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    # Choose the smallest "max_length" from the possible keys, remembering
    # which key supplied it so the error message below can name it.
    max_len_key = None
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None:
            if max_len < derived_max_model_len:
                max_len_key = key
            derived_max_model_len = min(derived_max_model_len, max_len)

    # If sliding window is manually disabled, max_length should be less
    # than the sliding window length in the model config.
    if disable_sliding_window and sliding_window_len is not None:
        if sliding_window_len < derived_max_model_len:
            max_len_key = "sliding_window"
        derived_max_model_len = min(derived_max_model_len, sliding_window_len)

    # If none of the keys were found in the config, use a default and
    # log a warning.
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            "%s. Assuming the model's maximum length is %d.", possible_keys,
            default_max_len)
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    # Scaling types "su" and "longrope" skip the length scaling below. The
    # correct name is "longrope"; "su" is kept for backward compatibility.
    if rope_scaling is not None:
        rope_type = rope_scaling["type"]
    if rope_scaling is not None and rope_type not in ("su", "longrope"):
        if disable_sliding_window:
            # TODO(robertgshaw): Find a model that supports rope_scaling
            # with sliding window to see if this case should be allowed.
            raise NotImplementedError(
                "Disabling sliding window is not supported for models "
                "with rope_scaling. Please raise an issue so we can "
                "investigate.")

        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        if rope_type == "yarn":
            # For "yarn", the factor is applied to the config's
            # original_max_position_embeddings instead of the length
            # derived from the keys above.
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
        derived_max_model_len *= scaling_factor

    # If the user specified a max length, make sure it is smaller than the
    # derived length from the HF model config.
    if max_model_len is None:
        max_model_len = int(derived_max_model_len)
    elif max_model_len > derived_max_model_len:
        # Some models might have a separate key for specifying model_max_length
        # that will be bigger than derived_max_model_len. We compare user input
        # with model_max_length and allow this override when it's smaller.
        model_max_length = getattr(hf_config, "model_max_length", None)
        if model_max_length is not None and max_model_len <= model_max_length:
            if disable_sliding_window:
                # TODO(robertgshaw): Find a model that has model_max_length
                # with sliding window to see if this case should be allowed.
                raise NotImplementedError(
                    "Disabling sliding window is not supported for models "
                    "with model_max_length in the config. Please raise an "
                    "issue so we can investigate.")
        else:
            raise ValueError(
                f"User-specified max_model_len ({max_model_len}) is greater "
                "than the derived max_model_len "
                f"({max_len_key}={derived_max_model_len} or model_max_length="
                f"{model_max_length} in model's config.json). This may lead "
                "to incorrect model outputs or CUDA errors. Make sure the "
                "value is correct and within the model context size.")
    return int(max_model_len)
1408
1409


1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
def get_served_model_name(model: str,
                          served_model_name: Optional[Union[str, List[str]]]):
    """Pick the name under which the model is served.

    Args:
        model: The model name or path; used as the fallback name.
        served_model_name: A single name, a list of names, or None.

    Returns:
        ``model`` when ``served_model_name`` is falsy (None, an empty
        string, or an empty list); the first entry when it is a non-empty
        list; otherwise ``served_model_name`` itself.
    """
    if served_model_name:
        # For a list of names, only the first entry is returned.
        return (served_model_name[0]
                if isinstance(served_model_name, list) else served_model_name)
    return model


1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
@dataclass
class DecodingConfig:
    """Dataclass which contains the decoding strategy of the engine"""

    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
    guided_decoding_backend: str = 'outlines'

    def __post_init__(self):
        """Validate that the guided decoding backend is supported.

        Raises:
            ValueError: If ``guided_decoding_backend`` is not one of the
                supported backends.
        """
        valid_guided_backends = ['outlines', 'lm-format-enforcer']
        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
            # Close the quote around the backend name and separate the two
            # clauses; the previous message rendered as "'foo,must be ...".
            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")


1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
@dataclass
class ObservabilityConfig:
    """Settings that control observability (tracing) for the engine."""

    # OTLP endpoint for exporting traces; None means no endpoint configured.
    otlp_traces_endpoint: Optional[str] = None

    def __post_init__(self):
        """Reject a traces endpoint when OpenTelemetry is not installed.

        Raises:
            ValueError: If ``otlp_traces_endpoint`` is set but the
                OpenTelemetry packages are unavailable.
        """
        otel_available = is_otel_installed()
        endpoint_requested = self.otlp_traces_endpoint is not None
        if endpoint_requested and not otel_available:
            raise ValueError("OpenTelemetry packages must be installed before "
                             "configuring 'otlp_traces_endpoint'")


1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
@dataclass(frozen=True)
class EngineConfig:
    """Dataclass which contains all engine-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    load_config: LoadConfig
    lora_config: Optional[LoRAConfig]
    vision_language_config: Optional[VisionLanguageConfig]
    speculative_config: Optional[SpeculativeConfig]
    decoding_config: Optional[DecodingConfig]
    observability_config: Optional[ObservabilityConfig]

    def __post_init__(self):
        """Verify configs are valid & consistent with each other.
        """
        # The model and cache configs are each checked against the parallel
        # config.
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        # LoRA is optional; its cross-checks only run when it is set.
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)

    def to_dict(self):
        """Return the configs as a dictionary, for use in **kwargs.

        The mapping is shallow: each dataclass field name maps to the config
        object stored in that field (nested configs are not converted).
        """
        # Loop variable is not named `field` to avoid shadowing
        # `dataclasses.field`.
        return {f.name: getattr(self, f.name) for f in fields(self)}