config.py 42.2 KB
Newer Older
1
import enum
2
import json
3
import os
4
from dataclasses import dataclass, fields
5
from typing import TYPE_CHECKING, ClassVar, Optional, Union
6
7

import torch
8
from packaging.version import Version
9
from transformers import PretrainedConfig
10

Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.logger import init_logger
12
from vllm.transformers_utils.config import get_config, get_hf_text_config
13
14
from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
                        is_neuron)
15

16
17
18
if TYPE_CHECKING:
    # Imported for type annotations only, so Ray is not needed at runtime
    # just to import this module.
    from ray.util.placement_group import PlacementGroup

# Module-level logger shared by the config classes in this file.
logger = init_logger(__name__)

# 2**30: the number of bytes in one GiB.
_GB = 1 << 30
22

23
24

class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        max_logprobs: Stored unchanged on the config and not validated here.
            NOTE(review): presumably an upper bound on the number of log
            probabilities that can be requested; confirm against the users
            of this attribute.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_logprobs: int = 5,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture
        self.max_logprobs = max_logprobs

        if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
            from modelscope.hub.snapshot_download import snapshot_download

            if not os.path.exists(model):
                model_path = snapshot_download(model_id=model,
                                               cache_dir=download_dir,
                                               revision=revision)
            else:
                model_path = model
            # With ModelScope, the resolved path replaces the model,
            # download dir and tokenizer locations alike.
            self.model = model_path
            self.download_dir = model_path
            self.tokenizer = model_path

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
                                                     max_model_len)
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()
        self._verify_cuda_graph()

    def _verify_load_format(self) -> None:
        """Validate `load_format` and store it lower-cased.

        Raises:
            ValueError: If the format is unknown, is listed as unsupported
                on ROCm while running on ROCm, or is "pt" for a Mixtral
                architecture.
        """
        load_format = self.load_format.lower()
        supported_load_format = [
            "auto", "pt", "safetensors", "npcache", "dummy"
        ]
        # Currently empty: no load format is excluded on ROCm.
        rocm_not_supported_load_format = []
        if load_format not in supported_load_format:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f for f in supported_load_format
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load format are "
                f"{rocm_supported_load_format}")

        # TODO: Remove this check once HF updates the pt weights of Mixtral.
        architectures = getattr(self.hf_config, "architectures", [])
        if "MixtralForCausalLM" in architectures and load_format == "pt":
            raise ValueError(
                "Currently, the 'pt' format is not supported for Mixtral. "
                "Please use the 'safetensors' format instead. ")
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:
        """Validate `tokenizer_mode` and store it lower-cased.

        Raises:
            ValueError: If the mode is neither "auto" nor "slow".
        """
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        """Reconcile `quantization` with the HF config and validate it.

        The user-supplied method is lower-cased. If the HF config carries a
        `quantization_config`, its method is adopted when none was given and
        must otherwise match the given one.

        Raises:
            ValueError: If the argument and the model config disagree, the
                method is unknown, or it is unsupported on ROCm.
        """
        supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
        rocm_not_supported_quantization = ["awq", "marlin"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()
            # compat: autogptq >=0.8.0 use checkpoint_format: str
            # compat: autogptq <=0.7.1 is_marlin_format: bool
            is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
                                or quant_cfg.get("is_marlin_format", False))

            # Use marlin if the GPTQ model is serialized in marlin format.
            if quant_method == "gptq" and is_format_marlin:
                logger.info("The model is serialized in Marlin format. "
                            "Using Marlin kernel.")
                quant_method = "marlin"
                # An explicit "gptq" request is upgraded too, so the
                # consistency check below does not reject it.
                if self.quantization == "gptq":
                    self.quantization = quant_method

            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if is_hip(
            ) and self.quantization in rocm_not_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm.")
            if self.quantization != "marlin":
                logger.warning(
                    f"{self.quantization} quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.")

    def _verify_cuda_graph(self) -> None:
        """Default `max_context_len_to_capture` and cap it at max_model_len."""
        if self.max_context_len_to_capture is None:
            self.max_context_len_to_capture = self.max_model_len
        self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                              self.max_model_len)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the model dimensions divide evenly across workers.

        Raises:
            ValueError: If the attention head count is not divisible by the
                tensor parallel size, or the hidden layer count is not
                divisible by the pipeline parallel size.
        """
        total_num_attention_heads = self.hf_text_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled.
        """

        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
        # addition to sliding window size. We check if that field is present
        # and if it's False, return None.
        if (hasattr(self.hf_text_config, "use_sliding_window")
                and not self.hf_text_config.use_sliding_window):
            return None
        return getattr(self.hf_text_config, "sliding_window", None)

    def get_vocab_size(self) -> int:
        """Return `vocab_size` from the HF text config."""
        return self.hf_text_config.vocab_size

    def get_hidden_size(self) -> int:
        """Return `hidden_size` from the HF text config."""
        return self.hf_text_config.hidden_size

    def get_head_size(self) -> int:
        """Return `head_dim` if the config has it, else hidden_size // heads."""
        if hasattr(self.hf_text_config, "head_dim"):
            return self.hf_text_config.head_dim
        # FIXME(woosuk): This may not be true for all models.
        return (self.hf_text_config.hidden_size //
                self.hf_text_config.num_attention_heads)

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            return getattr(self.hf_config.attn_config, "kv_n_heads",
                           self.hf_config.num_attention_heads)

        # Checked in order; the first attribute that is present and not
        # None wins.
        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the hidden layer count divided by the pipeline parallel size."""
        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for kv cache storage.
        forced_num_gpu_blocks: Number of GPU blocks to use. This overrides the
            profiled num_gpu_blocks if specified. Does nothing if None.
        sliding_window: Stored unchanged on the config.
            NOTE(review): presumably the attention sliding-window size, or
            None when disabled; confirm against callers.
        enable_prefix_caching: Stored unchanged on the config.
            NOTE(review): presumably toggles prefix caching; confirm against
            callers.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
        cache_dtype: str,
        forced_num_gpu_blocks: Optional[int] = None,
        sliding_window: Optional[int] = None,
        enable_prefix_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        # The swap space is given in GiB but kept in bytes.
        self.swap_space_bytes = swap_space * _GB
        self.forced_num_gpu_blocks = forced_num_gpu_blocks
        self.cache_dtype = cache_dtype
        self.sliding_window = sliding_window
        self.enable_prefix_caching = enable_prefix_caching
        self._verify_args()
        self._verify_cache_dtype()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def metrics_info(self):
        """Return every instance attribute as a {name: str(value)} dict."""
        # convert cache_config to dict(key: str, value: str) for prometheus
        # metrics info
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self) -> None:
        """Reject a GPU memory utilization above 1.0.

        Raises:
            ValueError: If `gpu_memory_utilization` is greater than 1.0.
        """
        # Only values strictly above 1.0 are rejected; exactly 1.0 passes
        # even though the message reads "less than 1.0".
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        """Validate `cache_dtype` ("auto" or "fp8_e5m2").

        Raises:
            NotImplementedError: If "fp8_e5m2" is requested on ROCm.
            ValueError: If "fp8_e5m2" is requested with a detected CUDA
                version below 11.8, or the dtype is unknown.
        """
        if self.cache_dtype == "auto":
            pass
        elif self.cache_dtype == "fp8_e5m2":
            if is_hip():
                raise NotImplementedError(
                    "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
            nvcc_cuda_version = get_nvcc_cuda_version()
            # A falsy version skips the check instead of failing.
            if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
                raise ValueError(
                    "FP8 is not supported when cuda version is lower than 11.8."
                )
            logger.info(
                "Using fp8_e5m2 data type to store kv cache. It reduces "
                "the GPU memory footprint and boosts the performance. "
                "But it may cause slight accuracy drop. "
                "Currently we only support fp8 without scaling factors and "
                "make e5m2 as a default format.")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check the total swap space against the CPU memory.

        Raises:
            ValueError: If the swap space across the tensor parallel group
                exceeds 70% of the total CPU memory. Above 40% only a
                warning is logged.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)
407

408

409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
@dataclass
class TokenizerPoolConfig:
    """Settings describing a pool of tokenizer workers.

    Args:
        pool_size: How many tokenizer workers the pool holds.
        pool_type: Kind of pool. Only "ray" passes validation.
        extra_config: Extra options for the pool, as a dict. How they are
            interpreted depends on the pool type.
    """
    pool_size: int
    pool_type: str
    extra_config: dict

    def __post_init__(self):
        known_pool_types = ("ray", )
        if self.pool_type not in known_pool_types:
            raise ValueError(f"Unknown pool type: {self.pool_type}")
        if isinstance(self.extra_config, dict):
            return
        raise ValueError("extra_config must be a dictionary.")

    @classmethod
    def create_config(
        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Build a TokenizerPoolConfig, or return None when no pool is wanted.

        Args:
            tokenizer_pool_size: Number of tokenizer workers in the pool.
                A falsy value (e.g. 0) yields None.
            tokenizer_pool_type: Type of the pool.
            tokenizer_pool_extra_config: Additional config for the pool.
                A string is parsed as JSON; any other falsy value becomes
                an empty dict.
        """
        # No workers requested: there is no pool to configure.
        if not tokenizer_pool_size:
            return None

        extra = tokenizer_pool_extra_config
        if isinstance(extra, str):
            extra = json.loads(extra)
        elif not extra:
            extra = {}
        return cls(tokenizer_pool_size, tokenizer_pool_type, extra)


461
class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Will be set to
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.
        max_parallel_loading_workers: Maximum number of multiple batches
            when load model sequentially. To avoid RAM OOM when using tensor
            parallel and large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL.
        tokenizer_pool_config: Config for the tokenizer pool.
            If None, will use synchronous tokenization.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
        placement_group: Optional Ray placement group; stored unchanged on
            the config.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        tokenizer_pool_config: Optional["TokenizerPoolConfig"] = None,
        ray_workers_use_nsight: bool = False,
        placement_group: Optional["PlacementGroup"] = None,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.tokenizer_pool_config = tokenizer_pool_config
        self.ray_workers_use_nsight = ray_workers_use_nsight
        self.placement_group = placement_group

        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
        # More than one worker always forces Ray, whatever the caller passed.
        if self.world_size > 1:
            self.worker_use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the parallel settings and adjust dependent flags.

        Raises:
            NotImplementedError: If pipeline_parallel_size is greater than 1.
            ValueError: If nsight profiling is requested without Ray workers.
        """
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")
        if not self.disable_custom_all_reduce and self.world_size > 1:
            if is_hip():
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported on AMD GPUs.")
            elif self.pipeline_parallel_size > 1:
                # Currently unreachable: the guard at the top of this method
                # already raises for pipeline_parallel_size > 1.
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported with pipeline parallelism.")
        if self.ray_workers_use_nsight and not self.worker_use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")
524

525
526

class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration. If None, defaults to
            max(max_model_len, 2048).
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
        num_lookahead_slots: The number of slots to allocate per sequence per
            step, beyond the known token ids. This is used in speculative
            decoding to store KV activations of tokens which may or may not be
            accepted.
        delay_factor: Apply a delay (of delay factor multiplied by previous
            prompt latency) before scheduling next prompt.
        enable_chunked_prefill: If True, prefill requests can be chunked based
            on the remaining max_num_batched_tokens. Stored on the instance
            as `chunked_prefill_enabled`.
    """

    def __init__(
        self,
        max_num_batched_tokens: Optional[int],
        max_num_seqs: int,
        max_model_len: int,
        use_v2_block_manager: bool = False,
        num_lookahead_slots: int = 0,
        delay_factor: float = 0.0,
        enable_chunked_prefill: bool = False,
    ) -> None:
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
        else:
            # If max_model_len is too short, use 2048 as the default value for
            # higher throughput.
            self.max_num_batched_tokens = max(max_model_len, 2048)
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.use_v2_block_manager = use_v2_block_manager
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
        self.chunked_prefill_enabled = enable_chunked_prefill

        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the relationships between the scheduler limits.

        Raises:
            ValueError: If max_num_batched_tokens is smaller than
                max_model_len or max_num_seqs, or num_lookahead_slots is
                negative.
        """
        if self.max_num_batched_tokens < self.max_model_len:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")

        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")

        if self.num_lookahead_slots < 0:
            raise ValueError(
                "num_lookahead_slots "
                f"({self.num_lookahead_slots}) must be greater than or "
                "equal to 0.")

594

595
596
class DeviceConfig:
    """Configuration for the execution device.

    Args:
        device: Device type string. "auto" (the default) picks "neuron" if
            `is_neuron()` is true, else "cpu" if `is_cpu()` is true, else
            "cuda". Any other value is used as the device type as-is.

    Attributes set:
        device_type: The resolved device type string.
        device: The `torch.device` to use; CPU for the "neuron" device type,
            otherwise `torch.device(device_type)`.
    """

    def __init__(self, device: str = "auto") -> None:
        if device == "auto":
            # Automated device type detection
            if is_neuron():
                self.device_type = "neuron"
            elif is_cpu():
                self.device_type = "cpu"
            else:
                # We don't call torch.cuda.is_available() here to
                # avoid initializing CUDA before workers are forked
                self.device_type = "cuda"
        else:
            # Device type is assigned explicitly
            self.device_type = device

        # Some device types require processing inputs on CPU
        if self.device_type in ["neuron"]:
            self.device = torch.device("cpu")
        else:
            # Set device with device type
            self.device = torch.device(self.device_type)

619

620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        num_speculative_tokens: Optional[int],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.

        This function attempts to create a SpeculativeConfig object based on the
        provided parameters. If the necessary conditions are met, it returns an
        instance of SpeculativeConfig. Otherwise, it returns None.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model.
            target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            target_dtype (str): The data type used for the target model.
                Currently unused; the draft model takes its dtype from
                ``target_model_config.dtype``.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                the necessary conditions are met, else None.

        Raises:
            ValueError: If only one of ``speculative_model`` and
                ``num_speculative_tokens`` is provided.
        """

        if speculative_model is None and num_speculative_tokens is None:
            return None

        # Reject both half-specified cases. Previously only a missing
        # num_speculative_tokens was caught, so a token count given without
        # a draft model fell through to ModelConfig(model=None, ...).
        if speculative_model is None or num_speculative_tokens is None:
            raise ValueError(
                "Expected both speculative_model and "
                "num_speculative_tokens to be provided, but found "
                f"{speculative_model=} and {num_speculative_tokens=}.")

        # TODO: The user should be able to specify revision/quantization/max
        # model len for the draft model. It is not currently supported.
        draft_revision = None
        draft_code_revision = None
        draft_quantization = None
        draft_max_model_len = None

        draft_model_config = ModelConfig(
            model=speculative_model,
            tokenizer=target_model_config.tokenizer,
            tokenizer_mode=target_model_config.tokenizer_mode,
            trust_remote_code=target_model_config.trust_remote_code,
            download_dir=target_model_config.download_dir,
            load_format=target_model_config.load_format,
            dtype=target_model_config.dtype,
            seed=target_model_config.seed,
            revision=draft_revision,
            code_revision=draft_code_revision,
            tokenizer_revision=target_model_config.tokenizer_revision,
            max_model_len=draft_max_model_len,
            quantization=draft_quantization,
            enforce_eager=target_model_config.enforce_eager,
            max_context_len_to_capture=target_model_config.
            max_context_len_to_capture,
            max_logprobs=target_model_config.max_logprobs,
        )

        draft_parallel_config = (
            SpeculativeConfig.create_draft_parallel_config(
                target_parallel_config))

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
        )

    @staticmethod
    def create_draft_parallel_config(
            target_parallel_config: ParallelConfig) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config. In the future the
        draft worker can have a different parallel strategy, e.g. TP=1.
        """
        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.
            pipeline_parallel_size,
            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
            worker_use_ray=target_parallel_config.worker_use_ray,
            max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.
            disable_custom_all_reduce,
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=target_parallel_config.
            ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
    ):
        """Create a SpeculativeConfig object.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.

        Raises:
            ValueError: If ``num_speculative_tokens`` is not positive.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens

        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the token count and the draft model/parallel pairing."""
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"


773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
@dataclass
class LoRAConfig:
    """Configuration for LoRA adapters.

    Field values are validated in ``__post_init__``; cross-checks against the
    model and scheduler configs are done by the ``verify_with_*`` methods.
    """
    max_lora_rank: int
    max_loras: int
    max_cpu_loras: Optional[int] = None
    lora_dtype: Optional[torch.dtype] = None
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256

    def __post_init__(self):
        """Validate field values and default ``max_cpu_loras``.

        Raises:
            ValueError: If ``max_lora_rank`` or ``lora_extra_vocab_size`` is
                not one of the allowed values, if ``max_loras`` is below 1,
                or if ``max_cpu_loras`` is smaller than ``max_loras``.
        """
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        possible_max_ranks = (8, 16, 32, 64)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            # Unspecified: default to the same count as max_loras.
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve ``lora_dtype`` and check compatibility with the model.

        ``None`` or "auto" adopts ``model_config.dtype``; any other string is
        looked up as an attribute of the ``torch`` module.

        Raises:
            ValueError: If the model config has a quantization method set.
        """
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)
        if model_config.quantization is not None:
            raise ValueError(
                "LoRA is not supported with quantized models yet.")

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        """Check that the scheduler's token budget fits the LoRA kernel limit.

        Raises:
            ValueError: If ``max_num_batched_tokens`` exceeds 65528.
        """
        if scheduler_config.max_num_batched_tokens > 65528:
            raise ValueError(
                "Due to limitations of the custom LoRA CUDA kernel, "
                "max_num_batched_tokens must be <= 65528 when "
                "LoRA is enabled.")


821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
@dataclass
class VisionLanguageConfig:
    """Describes the image input format and related sizes for running
    vision language models."""

    class ImageInputType(enum.Enum):
        """Where in the image pipeline the model's image input is taken from.

        An image roughly goes through the following transformation:
        Raw image --> pixel values --> image features --> image embeddings.

        The input types differ in where the image encoder
        (pixel values --> image features) is run, and they correspond to
        different tensor shapes.

        For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
        IMAGE_FEATURES: (1, 576, 1024).
        """
        PIXEL_VALUES = enum.auto()
        IMAGE_FEATURES = enum.auto()

    image_input_type: ImageInputType
    # The input id corresponding to image token.
    image_token_id: int
    # Used for running `run_prefill_max_token`.
    # For models that support varying resolution, this corresponds to
    # worst case scenario (biggest supported resolution).
    image_input_shape: tuple
    image_feature_size: int

    @classmethod
    def get_image_input_enum_type(
            cls, value: str) -> "VisionLanguageConfig.ImageInputType":
        """Look up an ImageInputType member by case-insensitive name.

        Raises:
            ValueError: If ``value`` does not name a member.
        """
        member_name = value.upper()
        try:
            return cls.ImageInputType[member_name]
        except KeyError as e:
            valid_names = [member.name for member in cls.ImageInputType]
            raise ValueError(f"{value} is not a valid choice. "
                             f"Expecting to choose from "
                             f"{valid_names}.") from e


863
864
865
866
867
868
869
870
# Maps the accepted dtype strings to torch dtypes. "half" and "float" are
# aliases for "float16" and "float32" respectively.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

871
872
# Dtype strings left out of the "supported dtypes" list in the ROCm error
# message raised by _get_and_verify_dtype.
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]

873
874
875

def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested dtype against the model config's dtype.

    Args:
        config: Model config; its ``torch_dtype`` attribute is read and
            treated as float32 when missing or None.
        dtype: Either a ``torch.dtype``, one of the keys of
            ``_STR_DTYPE_TO_TORCH_DTYPE`` (case-insensitive), or "auto".
            "auto" uses the config dtype, except that float32 configs are
            mapped to float16.

    Returns:
        The resolved ``torch.dtype``.

    Raises:
        ValueError: If ``dtype`` is an unknown string or an unsupported type,
            or if the resolved dtype is float32 while ``is_hip()`` is true.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                # Following the common practice, we use float16 for float32
                # models.
                torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    if is_hip() and torch_dtype == torch.float32:
        # Only the key names are needed for the message.
        rocm_supported_dtypes = [
            name for name in _STR_DTYPE_TO_TORCH_DTYPE
            if name not in _ROCM_NOT_SUPPORTED_DTYPE
        ]
        raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
                         f"Supported dtypes are {rocm_supported_dtypes}")

    # Verify the dtype. Upcasting to float32 and downcasting from float32 are
    # allowed silently; only a float16 <-> bfloat16 style cast is worth a
    # warning.
    involves_float32 = torch.float32 in (torch_dtype, config_dtype)
    if torch_dtype != config_dtype and not involves_float32:
        # Casting between float16 and bfloat16 is allowed with a warning.
        logger.warning(f"Casting {config_dtype} to {torch_dtype}.")

    return torch_dtype
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length.

    The derived maximum is the smallest value found among a set of known
    length attributes on ``hf_config``, scaled by ``rope_scaling["factor"]``
    when ``rope_scaling`` is set.

    Args:
        hf_config: Model config whose length attributes are inspected.
        max_model_len: User-specified maximum length, or None to use the
            derived value.

    Returns:
        The maximum model length to use.

    Raises:
        ValueError: If ``max_model_len`` exceeds the derived maximum and is
            not covered by the config's ``model_max_length``.
    """
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Command-R
        "model_max_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    # Track the smallest limit and which key supplied it; ties keep the
    # earlier key.
    max_len_key = None
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None and max_len < derived_max_model_len:
            max_len_key = key
            derived_max_model_len = max_len

    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            f"{possible_keys}. Assuming the model's maximum length is "
            f"{default_max_len}.")
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        if rope_scaling["type"] == "yarn":
            # For yarn the scaling is applied to the original length rather
            # than to the value derived above.
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
        derived_max_model_len *= scaling_factor

    if max_model_len is None:
        max_model_len = derived_max_model_len
    elif max_model_len > derived_max_model_len:
        # Some models might have a separate key for specifying model_max_length
        # that will be bigger than derived_max_model_len. We compare user input
        # with model_max_length and allow this override when it's smaller.
        model_max_length = getattr(hf_config, "model_max_length", None)
        override_allowed = (model_max_length is not None
                            and max_model_len <= model_max_length)
        if not override_allowed:
            raise ValueError(
                f"User-specified max_model_len ({max_model_len}) is greater "
                "than the derived max_model_len "
                f"({max_len_key}={derived_max_model_len} or model_max_length="
                f"{model_max_length} in model's config.json). This may lead "
                "to incorrect model outputs or CUDA errors. Make sure the "
                "value is correct and within the model context size.")
    return int(max_model_len)
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026


@dataclass(frozen=True)
class EngineConfig:
    """Immutable bundle of every engine-related configuration object, so the
    individual configs can be passed around the codebase as one value.
    """

    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    lora_config: Optional[LoRAConfig]
    vision_language_config: Optional[VisionLanguageConfig]
    speculative_config: Optional[SpeculativeConfig]

    def __post_init__(self):
        """Cross-check the bundled configs against one another."""
        parallel = self.parallel_config
        self.model_config.verify_with_parallel_config(parallel)
        self.cache_config.verify_with_parallel_config(parallel)

        lora = self.lora_config
        if lora:
            lora.verify_with_model_config(self.model_config)
            lora.verify_with_scheduler_config(self.scheduler_config)

    def to_dict(self):
        """Return a field-name -> config mapping, suitable for **kwargs."""
        return {
            spec.name: getattr(self, spec.name)
            for spec in fields(self)
        }