# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from dataclasses import field
6
from typing import TYPE_CHECKING, Any, Literal
7

8
from pydantic import Field, SkipValidation, field_validator
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from pydantic.dataclasses import dataclass

from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, get_cpu_memory

if not TYPE_CHECKING:
    # At runtime `ParallelConfig` is only referenced in annotations, so it is
    # aliased to `Any` instead of importing the real class (presumably to
    # avoid an import cycle between config modules — confirm before changing).
    ParallelConfig = Any
else:
    from vllm.config.parallel import ParallelConfig

# Logger shared by everything defined in this module.
logger = init_logger(__name__)

# Permitted KV-cache block sizes, measured in tokens per block.
BlockSize = Literal[1, 8, 16, 32, 64, 128]
# Permitted storage dtypes for the KV cache; "auto" means "use the model dtype".
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
# Permitted storage dtypes for the Mamba conv/ssm state caches.
MambaDType = Literal["auto", "float32"]
# Hash algorithms available for hashing blocks when prefix caching is enabled.
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]


@config
@dataclass
class CacheConfig:
    """Configuration for the KV cache.

    Field constraints declared with pydantic `Field(...)` (e.g. the range of
    `gpu_memory_utilization`) are validated when the config is constructed.
    """

    block_size: SkipValidation[BlockSize] = None  # type: ignore
    """Size of a contiguous cache block in number of tokens. On CUDA devices,
    only block sizes up to 32 are supported.

    This config has no static default. If left unspecified by the user, it will
    be set in `Platform.check_and_update_config()` based on the current
    platform."""
    gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
    """The fraction of GPU memory to be used for the model executor. It must
    lie in the half-open range (0, 1]. For example, a value of 0.5 would imply
    50% GPU memory utilization. If unspecified, will use the default value of
    0.9. This is a per-instance limit, and only applies to the current vLLM
    instance. It does not matter if you have another vLLM instance running on
    the same GPU. For example, if you have two vLLM instances running on the
    same GPU, you can set the GPU memory utilization to 0.5 for each
    instance."""
    swap_space: float = Field(default=4, ge=0)
    """Size of the CPU swap space per GPU (in GiB)."""
    cache_dtype: CacheDType = "auto"
    """Data type for kv cache storage. If "auto", will use model data type.
    CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
    fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).
    Some models (namely DeepSeekV3.2) default to fp8; set this to bfloat16 to
    use bfloat16 instead. Note that bfloat16 is an invalid option for models
    that do not default to fp8.
    """
    is_attention_free: bool = False
    """Whether the model is attention-free. This is primarily set in
    `ModelConfig` and that value should be manually duplicated here."""
    num_gpu_blocks_override: int | None = None
    """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
    if specified. Does nothing if `None`. Used for testing preemption."""
    sliding_window: int | None = None
    """Sliding window size for the KV cache. This is primarily set in
    `ModelConfig` and that value should be manually duplicated here."""
    enable_prefix_caching: bool | None = None
    """Whether to enable prefix caching. Enabled by default for V1."""
    prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
    """Set the hash algorithm for prefix caching:\n
    - "sha256" uses Pickle for object serialization before hashing.\n
    - "sha256_cbor" provides a reproducible, cross-language compatible hash. It
    serializes objects using canonical CBOR and hashes them with SHA-256."""
    cpu_offload_gb: float = Field(default=0, ge=0)
    """The space in GiB to offload to CPU, per GPU. Default is 0, which means
    no offloading. Intuitively, this argument can be seen as a virtual way to
    increase the GPU memory size. For example, if you have one 24 GB GPU and
    set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
    load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
    Note that this requires fast CPU-GPU interconnect, as part of the model is
    loaded from CPU memory to GPU memory on the fly in each model forward pass.
    """
    calculate_kv_scales: bool = False
    """This enables dynamic calculation of `k_scale` and `v_scale` when
    kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
    checkpoint if available. Otherwise, the scales will default to 1.0."""
    cpu_kvcache_space_bytes: int | None = None
    """(CPU backend only) CPU key-value cache space."""
    mamba_page_size_padded: int | None = None
    """Optional override for mamba page size; used by hybrid mamba/attention
    models to ensure exact alignment with attention page size."""
    mamba_block_size: int | None = None
    """Size of a contiguous cache block in number of tokens for mamba cache."""
    mamba_cache_dtype: MambaDType = "auto"
    """The data type to use for the Mamba cache (both the conv as well as the
    ssm state). If set to 'auto', the data type will be inferred from the model
    config."""
    mamba_ssm_cache_dtype: MambaDType = "auto"
    """The data type to use for the Mamba cache (ssm state only, conv state will
    still be controlled by mamba_cache_dtype). If set to 'auto', the data type
    for the ssm state will be determined by mamba_cache_dtype."""

    # Will be set after profiling.
    num_gpu_blocks: int | None = field(default=None, init=False)
    """The number of blocks to allocate for GPU memory."""
    num_cpu_blocks: int | None = field(default=None, init=False)
    """The number of blocks to allocate for CPU memory."""

    kv_sharing_fast_prefill: bool = False
    """This feature is work in progress and no prefill optimization takes place
    with this flag enabled currently.

    In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
    some layers can skip tokens corresponding to prefill. This flag enables
    attention metadata for eligible layers to be overridden with metadata
    necessary for implementing this optimization in some models (e.g. Gemma3n).
    """

    kv_cache_memory_bytes: int | None = None
    """Size of KV Cache per GPU in bytes. By default, this is set to None
    and vLLM can automatically infer the kv cache size based on
    gpu_memory_utilization. However, users may want to manually specify
    the kv cache memory size. kv_cache_memory_bytes allows more fine-grained
    control of how much memory gets used when compared with using
    gpu_memory_utilization. Note that kv_cache_memory_bytes
    (when not-None) ignores gpu_memory_utilization."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Returns:
            The hex MD5 digest of the `str()` of the factors list. MD5 is
            used purely as a fingerprint here (`usedforsecurity=False`).
        """
        factors: list[Any] = [
            self.cache_dtype,
            self.mamba_cache_dtype,
            self.mamba_ssm_cache_dtype,
        ]
        # `cpu_offload_gb` does not use `torch.compile` yet.
        return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()

    def metrics_info(self) -> dict[str, str]:
        """Return every instance attribute as a `str`-valued dict.

        Converts the cache config to a dict(key: str, value: str) for the
        Prometheus metrics info.
        """
        return {key: str(value) for key, value in self.__dict__.items()}

    @field_validator("cache_dtype", mode="after")
    @classmethod
    def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
        """Log an informational note when an fp8 KV-cache dtype is selected.

        The value is returned unchanged; this validator never rejects input.
        """
        if cache_dtype.startswith("fp8"):
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "Meanwhile, it may cause accuracy drop without a proper "
                "scaling factor."
            )
        return cache_dtype

    def verify_with_parallel_config(
        self,
        parallel_config: ParallelConfig,
    ) -> None:
        """Check that the total CPU swap space fits in host memory.

        The requested per-GPU `swap_space` (GiB) is multiplied by the tensor
        parallel size and compared against the total CPU memory.

        Args:
            parallel_config: Source of `tensor_parallel_size`.

        Raises:
            ValueError: If the swap space would use more than 70% of the
                total CPU memory. Between 40% and 70%, a warning is logged
                instead.
        """
        swap_space_bytes = self.swap_space * GiB_bytes
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = swap_space_bytes * num_gpus_per_node

        msg = (
            f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
            f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
            "is allocated for the swap space."
        )
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. %s", msg)