cache.py 11.9 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import field
5
from typing import ClassVar, Literal
6

7
from pydantic import Field, SkipValidation, field_validator, model_validator
8
9
10

from vllm.config.utils import config
from vllm.logger import init_logger
11
12
13
14
from vllm.utils.torch_utils import (
    is_quantized_kv_cache,
    kv_cache_uses_per_token_head_scales,
)
15
16
17

# Module logger; the field validators of CacheConfig use it for the
# calculate_kv_scales deprecation warning and the cache dtype info messages.
logger = init_logger(__name__)

# Accepted values for `CacheConfig.cache_dtype`. "auto" uses the model's data
# type (see the field's docstring).
CacheDType = Literal[
    "auto",
    # 16-bit floating-point types.
    "float16",
    "bfloat16",
    # FP8 variants; per the cache_dtype docs, "fp8" means "fp8_e4m3".
    "fp8",
    "fp8_e4m3",
    "fp8_e5m2",
    "fp8_inc",
    "fp8_ds_mla",
    # TurboQuant formats.
    "turboquant_k8v4",
    "turboquant_4bit_nc",
    "turboquant_k3v4_nc",
    "turboquant_3bit_nc",
    # Per-token-head formats; presumably these are the dtypes for which
    # `kv_cache_uses_per_token_head_scales` is true (scales computed at runtime).
    "int8_per_token_head",
    "fp8_per_token_head",
]
# Accepted values for `CacheConfig.mamba_cache_dtype` and
# `CacheConfig.mamba_ssm_cache_dtype`.
MambaDType = Literal["auto", "float32", "float16"]
# Accepted values for `CacheConfig.mamba_cache_mode`.
MambaCacheMode = Literal["all", "align", "none"]
# Accepted values for `CacheConfig.prefix_caching_hash_algo`.
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
# Accepted values for `CacheConfig.kv_offloading_backend`.
KVOffloadingBackend = Literal["native", "lmcache"]


@config
class CacheConfig:
    """Configuration for the KV cache."""

    DEFAULT_BLOCK_SIZE: ClassVar[int] = 16

    block_size: SkipValidation[int] = None  # type: ignore[assignment]
    """Size of a contiguous cache block in number of tokens.
    Accepts None (meaning "use DEFAULT_BLOCK_SIZE"). After construction, always
    int."""
    user_specified_block_size: bool = field(default=False, init=False)
    """Whether block_size was explicitly provided. Derived automatically."""
    user_specified_mamba_block_size: bool = field(default=False, init=False)
    """Whether mamba_block_size was explicitly provided. Derived automatically."""
    gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
    """The fraction of GPU memory to be used for the model executor, which must
    be greater than 0 and at most 1. For example, a value of 0.5 would imply 50%
    GPU memory utilization. If unspecified, will use the default value of 0.9.
    This is a per-instance limit, and only applies to the current vLLM instance.
    It does not matter if you have another vLLM instance running on the same
    GPU. For example, if you have two vLLM instances running on the same GPU,
    you can set the GPU memory utilization to 0.5 for each instance."""
    cache_dtype: CacheDType = "auto"
    """Data type for kv cache storage. If "auto", will use model data type.
    CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
    fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).
    Some models (namely DeepSeekV3.2) default to fp8; set this to bfloat16 to
    use bfloat16 instead. Setting bfloat16 this way is an invalid option for
    models that do not default to fp8.
    """
    is_attention_free: bool = False
    """Whether the model is attention-free. This is primarily set in
    `ModelConfig` and that value should be manually duplicated here."""
    num_gpu_blocks_override: int | None = None
    """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
    if specified. Does nothing if `None`. Used for testing preemption."""
    sliding_window: int | None = None
    """Sliding window size for the KV cache. This is primarily set in
    `ModelConfig` and that value should be manually duplicated here."""
    enable_prefix_caching: bool = True
    """Whether to enable prefix caching."""
    prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
    """Set the hash algorithm for prefix caching:

    - "sha256" uses Pickle for object serialization before hashing. This is the
      current default, as SHA256 is the most secure choice to avoid potential
      hash collisions.
    - "sha256_cbor" provides a reproducible, cross-language compatible hash. It
      serializes objects using canonical CBOR and hashes them with SHA-256.
    - "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
      non-cryptographic hashing. Requires the optional ``xxhash`` package.
      IMPORTANT: Use of a hashing algorithm that is not considered
      cryptographically secure theoretically increases the risk of hash
      collisions, which can cause undefined behavior or even leak private
      information in multi-tenant environments. Even if collisions are still
      very unlikely, it is important to consider your security risk tolerance
      against the performance benefits before turning this on.
    - "xxhash_cbor" combines canonical CBOR serialization with xxHash for
      reproducible hashing. Requires the optional ``xxhash`` package. Since it
      also uses xxHash, the caveat above applies here as well."""
    calculate_kv_scales: bool = False
    """Deprecated: This option is deprecated and will be removed in v0.19.
    It enables dynamic calculation of `k_scale` and `v_scale` when
    kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
    checkpoint if available. Otherwise, the scales will default to 1.0."""
    kv_cache_dtype_skip_layers: list[str] = field(default_factory=list)
    """Layer patterns to skip KV cache quantization. Accepts layer indices
    (e.g., '0', '2', '4') or attention type names (e.g., 'sliding_window')."""
    mamba_page_size_padded: int | None = None
    """Optional override for mamba page size; used by hybrid mamba/attention
    models to ensure exact alignment with attention page size."""
    mamba_block_size: int | None = Field(default=None, gt=0)
    """Size of a contiguous cache block in number of tokens for mamba cache.
    Can be set only when prefix caching is enabled.
    Value must be a multiple of 8 to align with causal_conv1d kernel."""
    mamba_cache_dtype: MambaDType = "auto"
    """The data type to use for the Mamba cache (both the conv as well as the
    ssm state). If set to 'auto', the data type will be inferred from the model
    config."""
    mamba_ssm_cache_dtype: MambaDType = "auto"
    """The data type to use for the Mamba cache (ssm state only, conv state will
    still be controlled by mamba_cache_dtype). If set to 'auto', the data type
    for the ssm state will be determined by mamba_cache_dtype."""
    mamba_cache_mode: MambaCacheMode = "none"
    """The cache strategy for Mamba layers:

    - "none": set when prefix caching is disabled.
    - "all": cache the mamba state of all tokens at position i * block_size.
      This is the default behavior (for models that support it) when prefix
      caching is enabled.
    - "align": only cache the mamba state of the last token of each scheduler
      step and when the token is at position i * block_size.
    """

    # Will be set after profiling.
    num_gpu_blocks: int | None = field(default=None, init=False)
    """The number of blocks to allocate for GPU memory."""
    num_cpu_blocks: int | None = field(default=None, init=False)
    """The number of blocks to allocate for CPU memory."""

    kv_sharing_fast_prefill: bool = False
    """This feature is work in progress and no prefill optimization takes place
    with this flag enabled currently.

    In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
    some layers can skip tokens corresponding to prefill. This flag enables
    attention metadata for eligible layers to be overridden with metadata
    necessary for implementing this optimization in some models (e.g. Gemma3n).
    """

    kv_cache_memory_bytes: int | None = None
    """Size of KV Cache per GPU in bytes. By default, this is set to None
    and vllm can automatically infer the kv cache size based on
    gpu_memory_utilization. However, users may want to manually specify
    the kv cache memory size. kv_cache_memory_bytes allows more fine-grain
    control of how much memory gets used when compared with using
    gpu_memory_utilization. Note that kv_cache_memory_bytes
    (when not-None) ignores gpu_memory_utilization."""

    kv_offloading_size: float | None = None
    """Size of the KV cache offloading buffer in GiB. When TP > 1, this is
    the total buffer size summed across all TP ranks. By default, this is set
    to None, which means no KV offloading is enabled. When set, vLLM will
    enable KV cache offloading to CPU using the kv_offloading_backend."""

    kv_offloading_backend: KVOffloadingBackend = "native"
    """The backend to use for KV cache offloading. Supported backends include
    'native' (vLLM native CPU offloading), 'lmcache'.
    KV offloading is only activated when kv_offloading_size is set."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        ignored_factors = {
            # Runtime/derived knobs that don't affect compiled graph shape
            "gpu_memory_utilization",
            "is_attention_free",
            "num_gpu_blocks_override",
            "enable_prefix_caching",
            "prefix_caching_hash_algo",
            "mamba_page_size_padded",
            "user_specified_block_size",
            "user_specified_mamba_block_size",
            "_block_size_resolved",
            # Post-init/derived counters
            "num_gpu_blocks",
            "num_cpu_blocks",
            # WIP feature toggle not impacting compiled graph shape
            "kv_sharing_fast_prefill",
        }

        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)

    def metrics_info(self) -> dict[str, str]:
        """Return this config as a flat ``{attribute name: str(value)}`` dict.

        Every entry of the instance ``__dict__`` is included, not only the
        user-facing fields. Values are stringified so the result can be
        exported as Prometheus metrics info.
        """
        # convert cache_config to dict(key: str, value: str) for prometheus
        # metrics info
        return {key: str(value) for key, value in self.__dict__.items()}

    _block_size_resolved: bool = field(default=False, init=False)
    """Guard against pydantic re-running _apply_block_size_default."""

    @model_validator(mode="after")
    def _apply_block_size_default(self) -> "CacheConfig":
        """Resolve ``block_size`` and record which block sizes were user-set.

        A ``None`` ``block_size`` is replaced with ``DEFAULT_BLOCK_SIZE``.
        Otherwise ``user_specified_block_size`` is set. Likewise,
        ``user_specified_mamba_block_size`` is set when ``mamba_block_size`` is
        not None. The work is done at most once per instance; repeat
        invocations return early via ``_block_size_resolved``.
        """
        # Pydantic re-runs validators when CacheConfig is nested inside
        # another pydantic model (e.g. VllmConfig). Guard against that.
        if self._block_size_resolved:
            return self
        # object.__setattr__ writes the attribute directly, presumably to
        # bypass any validate-on-assignment hook on the config class.
        object.__setattr__(self, "_block_size_resolved", True)
        if self.block_size is None:
            object.__setattr__(self, "block_size", self.DEFAULT_BLOCK_SIZE)
        else:
            object.__setattr__(self, "user_specified_block_size", True)

        if self.mamba_block_size is not None:
            object.__setattr__(self, "user_specified_mamba_block_size", True)

        return self

    @field_validator("calculate_kv_scales", mode="after")
    @classmethod
    def _warn_deprecated_calculate_kv_scales(cls, calculate_kv_scales: bool) -> bool:
        """Log a deprecation warning when ``calculate_kv_scales`` is enabled.

        The value is returned unchanged.
        """
        if calculate_kv_scales:
            logger.warning(
                "The `--calculate-kv-scales` option is deprecated and will "
                "be removed in v0.19. The scales will be loaded from the "
                "model checkpoint if available, otherwise they default to "
                "1.0."
            )
        return calculate_kv_scales

    @field_validator("cache_dtype", mode="after")
    @classmethod
    def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
        """Log an info message describing the chosen quantized KV cache dtype.

        Per-token-head dtypes are checked first, because their scales are
        computed at runtime. Other quantized dtypes get an accuracy caveat.
        Unquantized dtypes log nothing. The value is returned unchanged.
        """
        if kv_cache_uses_per_token_head_scales(cache_dtype):
            logger.info(
                "Using %s data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "Dynamic per-token-head scales will be computed at runtime.",
                str(cache_dtype),
            )
        elif is_quantized_kv_cache(cache_dtype):
            logger.info(
                "Using %s data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "Meanwhile, it may cause accuracy drop without a proper "
                "scaling factor",
                str(cache_dtype),
            )
        return cache_dtype