config.py 43.3 KB
Newer Older
1
import enum
2
import json
3
import os
4
from dataclasses import dataclass, field, fields
5
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
6
7

import torch
8
from packaging.version import Version
9
from transformers import PretrainedConfig
10

Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
13
from vllm.transformers_utils.config import get_config, get_hf_text_config
14
15
from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
                        is_neuron)
16

17
18
19
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

20
    from vllm.model_executor.model_loader.loader import BaseModelLoader
21

22
23
# Module-level logger for this file, obtained from vLLM's init_logger helper.
logger = init_logger(__name__)

24
25
26
27
# If true, will load models from ModelScope instead of Hugging Face Hub.
VLLM_USE_MODELSCOPE = os.environ.get("VLLM_USE_MODELSCOPE",
                                     "False").lower() == "true"

28
_GB = 1 << 30
29

30
31

class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
            be used to load activation and weight scaling factors when the
            model dtype is FP8_E4M3 on ROCm.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. If None, it defaults to ``max_model_len``; it is
            always clamped to at most ``max_model_len``.
        max_logprobs: Maximum number of log probabilities. Stored as
            ``self.max_logprobs`` without validation here. Defaults to 5.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        quantization_param_path: Optional[str] = None,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_logprobs: int = 5,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture
        self.max_logprobs = max_logprobs

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
                                                     max_model_len)
        # The verifiers below normalize some fields in place
        # (tokenizer_mode, quantization, max_context_len_to_capture).
        self._verify_tokenizer_mode()
        self._verify_quantization()
        self._verify_cuda_graph()

    def _verify_tokenizer_mode(self) -> None:
        """Lower-case ``tokenizer_mode``; raise ValueError unless auto/slow."""
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        """Reconcile the `quantization` argument with the HF model config.

        Raises:
            ValueError: If the argument and the model config disagree, the
                method is unknown, or it is not supported on ROCm.
        """
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = ["gptq", "squeezellm"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()
            # compat: autogptq >=0.8.0 use checkpoint_format: str
            # compat: autogptq <=0.7.1 is_marlin_format: bool
            is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
                                or quant_cfg.get("is_marlin_format", False))

            # Use marlin if the GPTQ model is serialized in marlin format.
            if quant_method == "gptq" and is_format_marlin:
                logger.info("The model is serialized in Marlin format. "
                            "Using Marlin kernel.")
                quant_method = "marlin"
                if self.quantization == "gptq":
                    self.quantization = quant_method

            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if (is_hip()
                    and self.quantization not in rocm_supported_quantization):
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm.")
            if self.quantization != "marlin":
                logger.warning(
                    f"{self.quantization} quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.")

    def _verify_cuda_graph(self) -> None:
        """Default max_context_len_to_capture to, and clamp it by, the
        model's max length."""
        if self.max_context_len_to_capture is None:
            self.max_context_len_to_capture = self.max_model_len
        self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                              self.max_model_len)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that heads/layers split evenly over the parallel layout.

        Raises:
            ValueError: If attention heads are not divisible by the tensor
                parallel size, or hidden layers by the pipeline parallel size.
        """
        total_num_attention_heads = self.hf_text_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled.
        """

        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
        # addition to sliding window size. We check if that field is present
        # and if it's False, return None.
        if (hasattr(self.hf_text_config, "use_sliding_window")
                and not self.hf_text_config.use_sliding_window):
            return None
        return getattr(self.hf_text_config, "sliding_window", None)

    def get_vocab_size(self) -> int:
        """Return ``vocab_size`` from the HF text config."""
        return self.hf_text_config.vocab_size

    def get_hidden_size(self) -> int:
        """Return ``hidden_size`` from the HF text config."""
        return self.hf_text_config.hidden_size

    def get_head_size(self) -> int:
        """Return the per-head dimension (``head_dim`` if the config has it)."""
        if hasattr(self.hf_text_config, "head_dim"):
            return self.hf_text_config.head_dim
        # FIXME(woosuk): This may not be true for all models.
        return (self.hf_text_config.hidden_size //
                self.hf_text_config.num_attention_heads)

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            attn_config = self.hf_config.attn_config
            default_kv_heads = self.hf_config.num_attention_heads
            # attn_config may be a plain dict (e.g. for some MPT configs);
            # getattr() on a dict would silently ignore "kv_n_heads" and
            # fall back to the full attention-head count.
            if isinstance(attn_config, dict):
                return attn_config.get("kv_n_heads", default_kv_heads)
            return getattr(attn_config, "kv_n_heads", default_kv_heads)

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of hidden layers per pipeline stage."""
        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for kv cache storage ("auto" or "fp8").
        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
            profiled num_gpu_blocks if specified. Does nothing if None.
        sliding_window: Sliding window size, stored as ``self.sliding_window``
            (None if not set).
        enable_prefix_caching: Whether prefix caching is enabled; stored as-is.

    Raises:
        ValueError: If ``gpu_memory_utilization`` exceeds 1.0, or the cache
            dtype is unknown / unsupported by the installed CUDA toolkit.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
        cache_dtype: str,
        num_gpu_blocks_override: Optional[int] = None,
        sliding_window: Optional[int] = None,
        enable_prefix_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self.num_gpu_blocks_override = num_gpu_blocks_override
        self.cache_dtype = cache_dtype
        self.sliding_window = sliding_window
        self.enable_prefix_caching = enable_prefix_caching
        self._verify_args()
        self._verify_cache_dtype()

        # Will be set after profiling.
        self.num_gpu_blocks: Optional[int] = None
        self.num_cpu_blocks: Optional[int] = None

    def metrics_info(self):
        """Return all instance attributes as a ``str -> str`` mapping.

        Used to export cache_config as Prometheus metrics info.
        """
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self) -> None:
        """Reject a GPU memory utilization fraction above 1.0."""
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        """Validate ``cache_dtype``; "fp8" needs nvcc >= 11.8 off ROCm."""
        if self.cache_dtype == "auto":
            pass
        elif self.cache_dtype == "fp8":
            if not is_hip():
                nvcc_cuda_version = get_nvcc_cuda_version()
                if nvcc_cuda_version < Version("11.8"):
                    # Note the trailing space: the two literals are joined.
                    raise ValueError(
                        "FP8 is not supported when cuda version is "
                        "lower than 11.8.")
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "But it may cause slight accuracy drop without scaling "
                "factors. FP8_E5M2 (without scaling) is only supported on "
                "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
                "is instead supported for common inference criteria.")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check the total swap space against available CPU memory.

        Raises ValueError above 70% of CPU memory; warns above 40%.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)
359

360

361
362
363
@dataclass
class TokenizerPoolConfig:
    """Settings for an optional pool of tokenizer workers.

    Args:
        pool_size: How many tokenizer workers the pool runs.
        pool_type: Pool implementation; only "ray" is accepted.
        extra_config: Pool-type-specific options. Must be a dict; how it is
            interpreted depends on the pool type.
    """
    pool_size: int
    pool_type: str
    extra_config: dict

    def __post_init__(self):
        # Validate eagerly so a bad config fails at construction time.
        if self.pool_type != "ray":
            raise ValueError(f"Unknown pool type: {self.pool_type}")
        if not isinstance(self.extra_config, dict):
            raise ValueError("extra_config must be a dictionary.")

    @classmethod
    def create_config(
        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Build a TokenizerPoolConfig, or return None when pooling is off.

        A falsy ``tokenizer_pool_size`` (e.g. 0) disables the pool.

        Args:
            tokenizer_pool_size: Number of tokenizer workers in the pool.
            tokenizer_pool_type: Type of the pool.
            tokenizer_pool_extra_config: Additional config for the pool.
                A JSON string is parsed first; None is treated as an empty
                dict. Its meaning depends on the pool type.

        Raises:
            ValueError: If the pool type is unknown, the JSON is malformed,
                or the parsed extra config is not a dict.
        """
        if not tokenizer_pool_size:
            return None

        extra = tokenizer_pool_extra_config
        if isinstance(extra, str):
            extra = json.loads(extra)
        else:
            extra = extra or {}
        return cls(tokenizer_pool_size, tokenizer_pool_type, extra)


413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
class LoadFormat(str, enum.Enum):
    """Formats in which model weights can be loaded.

    Members subclass ``str``, so each compares equal to its lowercase value.
    """

    AUTO = "auto"  # safetensors if available, else the PyTorch bin format
    PT = "pt"  # PyTorch bin format
    SAFETENSORS = "safetensors"  # safetensors format
    NPCACHE = "npcache"  # PyTorch format plus a numpy cache for faster loads
    DUMMY = "dummy"  # random weights, mainly for profiling
    TENSORIZER = "tensorizer"  # CoreWeave's tensorizer library


@dataclass
class LoadConfig:
    """Configuration for loading model weights.

    Args:
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
            "tensorizer" will use CoreWeave's tensorizer library for
                fast weight loading.
            Strings are lower-cased and converted to ``LoadFormat`` (unknown
            names raise ValueError). Non-string values, such as a
            ``BaseModelLoader``, are kept as-is without validation.
        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        model_loader_extra_config: Extra options for the model loader. A JSON
            string is parsed; None (or any other falsy value) becomes ``{}``.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(
        default_factory=dict)

    def __post_init__(self):
        # Always store the normalized value so downstream code can rely on a
        # parsed mapping instead of also handling None or a raw JSON string.
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
            model_loader_extra_config = json.loads(model_loader_extra_config)
        self.model_loader_extra_config = model_loader_extra_config
        self._verify_load_format()

    def _verify_load_format(self) -> None:
        """Normalize a string load_format to LoadFormat and check ROCm."""
        # Loader instances are used as-is; only strings (which includes
        # LoadFormat members, since LoadFormat subclasses str) are validated.
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            # Compare and report enum *values* (e.g. "pt"), matching the
            # lowercase strings held in rocm_not_supported_load_format.
            rocm_supported_load_format = [
                f.value for f in LoadFormat
                if f.value not in rocm_not_supported_load_format
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")


472
class ParallelConfig:
    """Settings describing how execution is distributed across devices.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups. Values
            above 1 currently raise NotImplementedError.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Forced to True
            whenever pipeline_parallel_size * tensor_parallel_size > 1.
        max_parallel_loading_workers: Maximum number of multiple batches
            when load model sequentially. To avoid RAM OOM when using tensor
            parallel and large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL. Switched on automatically for multi-GPU runs
            on AMD GPUs.
        tokenizer_pool_config: Config for the tokenizer pool.
            If None, will use synchronous tokenization.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
            Requires Ray workers; otherwise ValueError is raised.
        placement_group: Optional Ray placement group; stored unchanged.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
        ray_workers_use_nsight: bool = False,
        placement_group: Optional["PlacementGroup"] = None,
    ) -> None:
        world_size = pipeline_parallel_size * tensor_parallel_size

        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        # Any layout spanning more than one device is driven through Ray.
        self.worker_use_ray = worker_use_ray if world_size <= 1 else True
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.tokenizer_pool_config = tokenizer_pool_config
        self.ray_workers_use_nsight = ray_workers_use_nsight
        self.placement_group = placement_group
        self.world_size = world_size

        self._verify_args()

    def _verify_args(self) -> None:
        """Reject unsupported layouts and auto-disable custom all-reduce."""
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")

        custom_all_reduce_requested = (not self.disable_custom_all_reduce
                                       and self.world_size > 1)
        if custom_all_reduce_requested:
            if is_hip():
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported on AMD GPUs.")
            elif self.pipeline_parallel_size > 1:
                # Unreachable while pipeline parallelism is rejected above;
                # kept for when that restriction is lifted.
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported with pipeline parallelism.")

        if self.ray_workers_use_nsight and not self.worker_use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")
535

536
537

class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration. If None, defaults to 768 with chunked prefill,
            otherwise to max(max_model_len, 2048).
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
        num_lookahead_slots: The number of slots to allocate per sequence per
            step, beyond the known token ids. This is used in speculative
            decoding to store KV activations of tokens which may or may not be
            accepted.
        delay_factor: Apply a delay (of delay factor multiplied by previous
            prompt latency) before scheduling next prompt.
        enable_chunked_prefill: If True, prefill requests can be chunked based
            on the remaining max_num_batched_tokens.

    Raises:
        ValueError: If the token budget is below max_model_len (without
            chunked prefill) or below max_num_seqs, or if
            num_lookahead_slots is negative.
    """

    def __init__(
        self,
        max_num_batched_tokens: Optional[int],
        max_num_seqs: int,
        max_model_len: int,
        use_v2_block_manager: bool = False,
        num_lookahead_slots: int = 0,
        delay_factor: float = 0.0,
        enable_chunked_prefill: bool = False,
    ) -> None:
        if max_num_batched_tokens is None:
            # Chunked prefill uses a well-tuned fixed budget; otherwise never
            # go below 2048 so short-context models still batch well.
            max_num_batched_tokens = (768 if enable_chunked_prefill else
                                      max(max_model_len, 2048))
        self.max_num_batched_tokens = max_num_batched_tokens

        if enable_chunked_prefill:
            logger.info("Chunked prefill is enabled (EXPERIMENTAL).")

        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.use_v2_block_manager = use_v2_block_manager
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
        self.chunked_prefill_enabled = enable_chunked_prefill

        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the token budget and the lookahead slot count."""
        budget = self.max_num_batched_tokens

        # Without chunking, a prompt longer than the budget can never run.
        if budget < self.max_model_len and not self.chunked_prefill_enabled:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")

        if budget < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")

        if self.num_lookahead_slots < 0:
            raise ValueError(
                "num_lookahead_slots "
                f"({self.num_lookahead_slots}) must be greater than or "
                "equal to 0.")

613

614
615
class DeviceConfig:
    """Resolves the device type and the torch device used for inputs.

    Args:
        device: Explicit device type string, or "auto" to detect one
            (neuron first, then cpu, otherwise cuda).
    """

    def __init__(self, device: str = "auto") -> None:
        if device != "auto":
            # Device type is assigned explicitly
            device_type = device
        elif is_neuron():
            device_type = "neuron"
        elif is_cpu():
            device_type = "cpu"
        else:
            # torch.cuda.is_available() is deliberately not called here so
            # CUDA is not initialized before workers are forked.
            device_type = "cuda"
        self.device_type = device_type

        # Some device types (neuron) require processing inputs on CPU.
        torch_device_type = "cpu" if device_type == "neuron" else device_type
        self.device = torch.device(torch_device_type)

638

639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        num_speculative_tokens: Optional[int],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.

        This function attempts to create a SpeculativeConfig object based on the
        provided parameters. If neither speculative option is given it returns
        None; if both are given it returns an instance of SpeculativeConfig.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model. The draft model config copies its tokenizer, dtype,
                seed, eager/cuda-graph and logprob settings.
            target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            target_dtype (str): The data type used for the target model. Not
                currently read here; the draft model uses
                ``target_model_config.dtype`` instead.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                both speculative options are provided, else None.

        Raises:
            ValueError: If only one of ``speculative_model`` and
                ``num_speculative_tokens`` is provided.
        """

        if speculative_model is None and num_speculative_tokens is None:
            return None

        # The two options must be provided together. Check both directions
        # explicitly: an assert would surface as an AssertionError (or be
        # skipped entirely under ``python -O``) instead of a clear error.
        if speculative_model is None or num_speculative_tokens is None:
            raise ValueError(
                "Expected both speculative_model and "
                "num_speculative_tokens to be provided, but found "
                f"{speculative_model=} and {num_speculative_tokens=}.")

        # TODO: The user should be able to specify revision/quantization/max
        # model len for the draft model. It is not currently supported.
        draft_revision = None
        draft_code_revision = None
        draft_quantization = None
        draft_max_model_len = None

        draft_model_config = ModelConfig(
            model=speculative_model,
            tokenizer=target_model_config.tokenizer,
            tokenizer_mode=target_model_config.tokenizer_mode,
            trust_remote_code=target_model_config.trust_remote_code,
            dtype=target_model_config.dtype,
            seed=target_model_config.seed,
            revision=draft_revision,
            code_revision=draft_code_revision,
            tokenizer_revision=target_model_config.tokenizer_revision,
            max_model_len=draft_max_model_len,
            quantization=draft_quantization,
            enforce_eager=target_model_config.enforce_eager,
            max_context_len_to_capture=target_model_config.
            max_context_len_to_capture,
            max_logprobs=target_model_config.max_logprobs,
        )

        draft_parallel_config = (
            SpeculativeConfig.create_draft_parallel_config(
                target_parallel_config))

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
        )

    @staticmethod
    def create_draft_parallel_config(
            target_parallel_config: ParallelConfig) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config. In the future the
        draft worker can have a different parallel strategy, e.g. TP=1.
        """
        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.
            pipeline_parallel_size,
            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
            worker_use_ray=target_parallel_config.worker_use_ray,
            max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.
            disable_custom_all_reduce,
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=target_parallel_config.
            ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
    ):
        """Create a SpeculativeConfig object.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.

        Raises:
            ValueError: If ``num_speculative_tokens`` is not positive.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens

        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the token count and draft model/parallel compatibility."""
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"


793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
@dataclass
class LoRAConfig:
    """Configuration for serving LoRA adapters.

    Field values are validated in ``__post_init__``. Compatibility with the
    model and scheduler configs is checked by ``verify_with_model_config``
    and ``verify_with_scheduler_config``.
    """
    max_lora_rank: int
    max_loras: int
    # Defaults to max_loras when left as None; must be >= max_loras.
    max_cpu_loras: Optional[int] = None
    # None or "auto" resolves to the model dtype; any other string is looked
    # up as a torch attribute (e.g. "float16") in verify_with_model_config.
    lora_dtype: Optional[Union[torch.dtype, str]] = None
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256

    def __post_init__(self):
        """Validate ranks/sizes and fill in the default for max_cpu_loras.

        Raises:
            ValueError: If any field is outside its supported range.
        """
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        possible_max_ranks = (8, 16, 32, 64)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``lora_dtype`` against the model and warn on untested
        quantization methods."""
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)
        if model_config.quantization and model_config.quantization not in [
                "awq", "gptq"
        ]:
            # TODO support marlin and squeezellm
            # Lazy %-formatting: the message is only built if emitted.
            logger.warning("%s quantization is not tested with LoRA yet.",
                           model_config.quantization)

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        """Reject batch sizes the custom LoRA kernel cannot handle.

        Raises:
            ValueError: If max_num_batched_tokens exceeds 65528.
        """
        if scheduler_config.max_num_batched_tokens > 65528:
            raise ValueError(
                "Due to limitations of the custom LoRA CUDA kernel, "
                "max_num_batched_tokens must be <= 65528 when "
                "LoRA is enabled.")


844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
@dataclass
class VisionLanguageConfig:
    """Describe the image input format of a vision language model and how
    the model should be run on it."""

    class ImageInputType(enum.Enum):
        """Which stage of the image pipeline is fed to the model.

        Images move through the stages
        raw image --> pixel values --> image features --> image embeddings.
        The input types differ in where the image encoder (the
        pixel values --> image features step) is run, and each one implies a
        different tensor shape. For Llava, for example, PIXEL_VALUES is
        (1, 3, 336, 336) and IMAGE_FEATURES is (1, 576, 1024).
        """
        PIXEL_VALUES = enum.auto()
        IMAGE_FEATURES = enum.auto()

    image_input_type: ImageInputType
    # Token id that stands for the image inside the input ids.
    image_token_id: int
    # Shape used when running `run_prefill_max_token`. Models that accept
    # several resolutions should give the worst case (largest supported).
    image_input_shape: tuple
    image_feature_size: int

    @classmethod
    def get_image_input_enum_type(
            cls, value: str) -> "VisionLanguageConfig.ImageInputType":
        """Map a case-insensitive member name to an ImageInputType.

        Raises:
            ValueError: If ``value`` names no ImageInputType member.
        """
        member_name = value.upper()
        try:
            return cls.ImageInputType[member_name]
        except KeyError as err:
            choices = [member.name for member in cls.ImageInputType]
            raise ValueError(f"{value} is not a valid choice. "
                             f"Expecting to choose from "
                             f"{choices}.") from err


886
887
888
889
890
891
892
893
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

894
895
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]

896
897
898

def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested dtype against the checkpoint's dtype.

    Args:
        config: The model's HF config. Its ``torch_dtype`` attribute, when
            present and not None, is the checkpoint dtype; otherwise float32
            is assumed.
        dtype: A dtype name (case-insensitive: "auto" or a key of
            ``_STR_DTYPE_TO_TORCH_DTYPE``) or a ``torch.dtype``. "auto" keeps
            the checkpoint dtype, except that float32 checkpoints use float16.

    Returns:
        The torch dtype to use for model weights and activations.

    Raises:
        ValueError: If ``dtype`` is not recognized, or if float32 is requested
            while running on ROCm.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                # Following the common practice, we use float16 for float32
                # models.
                torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    if is_hip() and torch_dtype == torch.float32:
        rocm_supported_dtypes = [
            k for k in _STR_DTYPE_TO_TORCH_DTYPE
            if k not in _ROCM_NOT_SUPPORTED_DTYPE
        ]
        raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
                         f"Supported dtypes are {rocm_supported_dtypes}")

    # Verify the dtype. Upcasting to float32 and downcasting from float32
    # (to float16 or bfloat16) are allowed silently; any other cast, e.g.
    # between float16 and bfloat16, is allowed with a warning.
    if (torch_dtype != config_dtype and torch_dtype != torch.float32
            and config_dtype != torch.float32):
        logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length."""
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
961
962
        # ChatGLM2
        "seq_length",
963
964
        # Command-R
        "model_max_length",
965
966
967
968
969
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
970
    max_len_key = None
971
    for key in possible_keys:
972
973
974
975
976
        max_len = getattr(hf_config, key, None)
        if max_len is not None:
            max_len_key = key if max_len < derived_max_model_len \
                else max_len_key
            derived_max_model_len = min(derived_max_model_len, max_len)
977
    if derived_max_model_len == float("inf"):
978
979
980
981
982
983
984
985
986
987
988
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            f"{possible_keys}. Assuming the model's maximum length is "
            f"{default_max_len}.")
        derived_max_model_len = default_max_len
989

990
991
992
993
    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
Antoni Baum's avatar
Antoni Baum committed
994
995
996
        if rope_scaling["type"] == "yarn":
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
997
998
        derived_max_model_len *= scaling_factor

999
    if max_model_len is None:
1000
        max_model_len = int(derived_max_model_len)
1001
    elif max_model_len > derived_max_model_len:
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
        # Some models might have a separate key for specifying model_max_length
        # that will be bigger than derived_max_model_len. We compare user input
        # with model_max_length and allow this override when it's smaller.
        model_max_length = getattr(hf_config, "model_max_length", None)
        if model_max_length is not None and max_model_len <= model_max_length:
            pass
        else:
            raise ValueError(
                f"User-specified max_model_len ({max_model_len}) is greater "
                "than the derived max_model_len "
                f"({max_len_key}={derived_max_model_len} or model_max_length="
                f"{model_max_length} in model's config.json). This may lead "
                "to incorrect model outputs or CUDA errors. Make sure the "
                "value is correct and within the model context size.")
1016
    return int(max_model_len)
1017
1018


1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
@dataclass
class DecodingConfig:
    """Dataclass which contains the decoding strategy of the engine.

    Raises:
        ValueError: On construction, if ``guided_decoding_backend`` is not one
            of the supported backends.
    """

    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
    guided_decoding_backend: str = 'outlines'

    def __post_init__(self):
        valid_guided_backends = ['outlines', 'lm-format-enforcer']
        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
            # Close the quote around the backend name and separate the two
            # clauses (previously rendered as "'foo,must be one of ...").
            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")


1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
@dataclass(frozen=True)
class EngineConfig:
    """Dataclass which contains all engine-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    load_config: LoadConfig
    lora_config: Optional[LoRAConfig]
    vision_language_config: Optional[VisionLanguageConfig]
    speculative_config: Optional[SpeculativeConfig]
    decoding_config: Optional[DecodingConfig]

    def __post_init__(self):
        """Verify configs are valid & consistent with each other.

        Checks the model and cache configs against the parallel config, and,
        when LoRA is enabled, the LoRA config against the model and scheduler
        configs.
        """
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)

    def to_dict(self):
        """Return the configs as a dictionary, for use in **kwargs.

        Keys are the dataclass field names; values are the (non-copied)
        config objects.
        """
        return {f.name: getattr(self, f.name) for f in fields(self)}