# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import field
5
from typing import ClassVar, Literal
6

7
from pydantic import Field, SkipValidation, field_validator, model_validator
8
9
10
11
12
13

from vllm.config.utils import config
from vllm.logger import init_logger

# Module-level logger; used by the CacheConfig field validators to emit the
# `calculate_kv_scales` deprecation warning and the fp8 cache-dtype info message.
logger = init_logger(__name__)

# Allowed values for `CacheConfig.cache_dtype`.
CacheDType = Literal[
    "auto",
    "float16",
    "bfloat16",
    "fp8",
    "fp8_e4m3",
    "fp8_e5m2",
    "fp8_inc",
    "fp8_ds_mla",
]
# Allowed values for `CacheConfig.mamba_cache_dtype` and
# `CacheConfig.mamba_ssm_cache_dtype`.
MambaDType = Literal["auto", "float32", "float16"]
# Allowed values for `CacheConfig.mamba_cache_mode`.
MambaCacheMode = Literal["all", "align", "none"]
# Allowed values for `CacheConfig.prefix_caching_hash_algo`.
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
# Allowed values for `CacheConfig.kv_offloading_backend`.
KVOffloadingBackend = Literal["native", "lmcache"]


@config
class CacheConfig:
    """Configuration for the KV cache."""

    DEFAULT_BLOCK_SIZE: ClassVar[int] = 16

    block_size: SkipValidation[int] = None  # type: ignore[assignment]
    """Size of a contiguous cache block in number of tokens.
    Accepts None (meaning "use `DEFAULT_BLOCK_SIZE`"). After construction,
    always an int."""
    user_specified_block_size: bool = field(default=False, init=False)
    """Whether block_size was explicitly provided. Derived automatically."""
    gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
    """The fraction of GPU memory to be used for the model executor. Must be
    greater than 0 and at most 1. For example, a value of 0.5 would imply 50%
    GPU memory utilization. If unspecified, will use the default value of 0.9.
    This is a per-instance limit, and only applies to the current vLLM
    instance. It does not matter if you have another vLLM instance running on
    the same GPU. For example, if you have two vLLM instances running on the
    same GPU, you can set the GPU memory utilization to 0.5 for each
    instance."""
    cache_dtype: CacheDType = "auto"
    """Data type for kv cache storage. If "auto", will use model data type.
    CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
    fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).
    Some models (namely DeepSeekV3.2) default to fp8; set this to bfloat16 to
    use bfloat16 instead. bfloat16 is an invalid option for models that do not
    default to fp8.
    """
    is_attention_free: bool = False
    """Whether the model is attention-free. This is primarily set in
    `ModelConfig` and that value should be manually duplicated here."""
    num_gpu_blocks_override: int | None = None
    """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
    if specified. Does nothing if `None`. Used for testing preemption."""
    sliding_window: int | None = None
    """Sliding window size for the KV cache. This is primarily set in
    `ModelConfig` and that value should be manually duplicated here."""
    enable_prefix_caching: bool = True
    """Whether to enable prefix caching."""
    prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
    """Set the hash algorithm for prefix caching:

    - "sha256" uses Pickle for object serialization before hashing. This is the current
      default, as SHA256 is the most secure choice to avoid potential hash collisions.
    - "sha256_cbor" provides a reproducible, cross-language compatible hash. It
      serializes objects using canonical CBOR and hashes them with SHA-256.
    - "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
      non-cryptographic hashing. Requires the optional ``xxhash`` package.
      IMPORTANT: Use of a hashing algorithm that is not considered cryptographically
      secure theoretically increases the risk of hash collisions, which can cause
      undefined behavior or even leak private information in multi-tenant environments.
      Even if collisions are still very unlikely, it is important to consider your
      security risk tolerance against the performance benefits before turning this on.
    - "xxhash_cbor" combines canonical CBOR serialization with xxHash for
      reproducible hashing. Requires the optional ``xxhash`` package."""
    calculate_kv_scales: bool = False
    """Deprecated: This option is deprecated and will be removed in v0.19.
    It enables dynamic calculation of `k_scale` and `v_scale` when
    kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
    checkpoint if available. Otherwise, the scales will default to 1.0."""
    kv_cache_dtype_skip_layers: list[str] = field(default_factory=list)
    """Layer patterns to skip KV cache quantization. Accepts layer indices
    (e.g., '0', '2', '4') or attention type names (e.g., 'sliding_window')."""
    cpu_kvcache_space_bytes: int | None = None
    """(CPU backend only) CPU key-value cache space."""
    mamba_page_size_padded: int | None = None
    """Optional override for mamba page size; used by hybrid mamba/attention
    models to ensure exact alignment with attention page size."""
    mamba_block_size: int | None = Field(default=None, gt=0)
    """Size of a contiguous cache block in number of tokens for mamba cache.
    Can be set only when prefix caching is enabled.
    Value must be a multiple of 8 to align with causal_conv1d kernel."""
    mamba_cache_dtype: MambaDType = "auto"
    """The data type to use for the Mamba cache (both the conv as well as the
    ssm state). If set to 'auto', the data type will be inferred from the model
    config."""
    mamba_ssm_cache_dtype: MambaDType = "auto"
    """The data type to use for the Mamba cache (ssm state only, conv state will
    still be controlled by mamba_cache_dtype). If set to 'auto', the data type
    for the ssm state will be determined by mamba_cache_dtype."""
    mamba_cache_mode: MambaCacheMode = "none"
    """The cache strategy for Mamba layers.

    - "none": set when prefix caching is disabled.
    - "all": cache the mamba state of all tokens at position i * block_size.
      This is the default behavior (for models that support it) when prefix
      caching is enabled.
    - "align": only cache the mamba state of the last token of each scheduler
      step and when the token is at position i * block_size.
    """

    # Will be set after profiling.
    num_gpu_blocks: int | None = field(default=None, init=False)
    """The number of blocks to allocate for GPU memory."""
    num_cpu_blocks: int | None = field(default=None, init=False)
    """The number of blocks to allocate for CPU memory."""

    kv_sharing_fast_prefill: bool = False
    """This feature is work in progress and no prefill optimization takes place
    with this flag enabled currently.

    In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
    some layers can skip tokens corresponding to prefill. This flag enables
    attention metadata for eligible layers to be overridden with metadata
    necessary for implementing this optimization in some models (e.g. Gemma3n).
    """

    kv_cache_memory_bytes: int | None = None
    """Size of KV Cache per GPU in bytes. By default, this is set to None
    and vLLM can automatically infer the kv cache size based on
    gpu_memory_utilization. However, users may want to manually specify
    the kv cache memory size. kv_cache_memory_bytes allows finer-grained
    control of how much memory gets used when compared with using
    gpu_memory_utilization. Note that when kv_cache_memory_bytes is set
    (not None), gpu_memory_utilization is ignored."""

    kv_offloading_size: float | None = None
    """Size of the KV cache offloading buffer in GiB. When TP > 1, this is
    the total buffer size summed across all TP ranks. By default, this is set
    to None, which means no KV offloading is enabled. When set, vLLM will
    enable KV cache offloading to CPU using the kv_offloading_backend."""

    kv_offloading_backend: KVOffloadingBackend = "native"
    """The backend to use for KV cache offloading. Supported backends include
    'native' (vLLM native CPU offloading), 'lmcache'.
    KV offloading is only activated when kv_offloading_size is set."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Returns:
            The hash produced by `hash_factors` over every field of this
            config except those listed in `ignored_factors` below.
        """
        ignored_factors = {
            # Runtime/derived knobs that don't affect compiled graph shape
            "gpu_memory_utilization",
            "is_attention_free",
            "num_gpu_blocks_override",
            "enable_prefix_caching",
            "prefix_caching_hash_algo",
            "cpu_kvcache_space_bytes",
            "mamba_page_size_padded",
            "user_specified_block_size",
            "_block_size_resolved",
            # Post-init/derived counters
            "num_gpu_blocks",
            "num_cpu_blocks",
            # WIP feature toggle not impacting compiled graph shape
            "kv_sharing_fast_prefill",
        }

        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)

    def metrics_info(self):
        """Return this config as a ``dict[str, str]`` for Prometheus info metrics.

        Every entry of the instance ``__dict__`` is included (private
        bookkeeping attributes too), with each value converted via ``str()``.
        """
        return {key: str(value) for key, value in self.__dict__.items()}

    _block_size_resolved: bool = field(default=False, init=False)
    """Guard against pydantic re-running _apply_block_size_default."""

    @model_validator(mode="after")
    def _apply_block_size_default(self) -> "CacheConfig":
        """Resolve `block_size` exactly once after validation.

        If `block_size` is None it is replaced with `DEFAULT_BLOCK_SIZE`;
        otherwise `user_specified_block_size` is set to True.
        """
        # Pydantic re-runs validators when CacheConfig is nested inside
        # another pydantic model (e.g. VllmConfig). Guard against that.
        if self._block_size_resolved:
            return self
        # object.__setattr__ writes the attribute directly, bypassing any
        # __setattr__ hook on the config class.
        object.__setattr__(self, "_block_size_resolved", True)
        if self.block_size is None:
            object.__setattr__(self, "block_size", self.DEFAULT_BLOCK_SIZE)
        else:
            object.__setattr__(self, "user_specified_block_size", True)
        return self

    @field_validator("calculate_kv_scales", mode="after")
    @classmethod
    def _warn_deprecated_calculate_kv_scales(cls, calculate_kv_scales: bool) -> bool:
        """Log a deprecation warning when `calculate_kv_scales` is enabled.

        The value is returned unchanged.
        """
        if calculate_kv_scales:
            logger.warning(
                "The `--calculate-kv-scales` option is deprecated and will "
                "be removed in v0.19. The scales will be loaded from the "
                "model checkpoint if available, otherwise they default to "
                "1.0."
            )
        return calculate_kv_scales

    @field_validator("cache_dtype", mode="after")
    @classmethod
    def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
        """Log an informational note when an fp8 KV cache dtype is selected.

        The value is returned unchanged.
        """
        if cache_dtype.startswith("fp8"):
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "Meanwhile, it may cause accuracy drop without a proper "
                "scaling factor."
            )
        return cache_dtype