# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
from dataclasses import dataclass, fields
from math import prod
from typing import Optional

import torch
from typing_extensions import Self

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv, get_dtype_size

# Logger for this module, created through vLLM's `init_logger` helper with the
# module's `__name__`.
logger = init_logger(__name__)


20
@dataclass(frozen=True)
class KVCacheSpec:
    """
    A base class for specifying the KV cache format of one layer.

    The base class only stores `block_size`; `page_size_bytes` and
    `max_memory_usage_bytes` raise `NotImplementedError` here and are meant to
    be overridden by subclasses.
    """

    # number of tokens in a block
    block_size: int

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            The page size

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: The vLLM configuration (unused by the base class).

        Returns:
            The KV cache size in bytes

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.

        Args:
            specs: A non-empty list of specs that must all compare equal.

        Returns:
            A deep copy of the first spec in the list.
        """
        # Without this guard an empty list would pass the equality check
        # vacuously and then fail with a bare IndexError on `specs[0]`.
        assert specs, "Cannot merge an empty list of KV cache specs."
        first = specs[0]
        assert all(spec == first for spec in specs[1:]), (
            "All layers in the same KV cache group must be the same.")
        return copy.deepcopy(first)

57

58
@dataclass(frozen=True)
class AttentionSpec(KVCacheSpec):
    """
    KV cache spec for attention layers.

    Adds the fields needed to compute the byte size of one page on top of
    the `block_size` inherited from `KVCacheSpec`.
    """
    # The four fields below are, together with `block_size`, the factors of
    # `page_size_bytes`.
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    # Whether the layer uses MLA; halves the page size (see below).
    use_mla: bool

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            `coef * block_size * num_kv_heads * head_size * dtype_size`,
            where `coef` is 1 for MLA and 2 otherwise.
        """
        # For MLA we only store a single latent vector
        coef = 1 if self.use_mla else 2
        return (coef * self.block_size * self.num_kv_heads * self.head_size *
                get_dtype_size(self.dtype))

72

73
@dataclass(frozen=True)
class FullAttentionSpec(AttentionSpec):
    """
    KV cache spec for layers whose blocks are allocated for all tokens.

    When hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, sliding
    window attention are regarded as full attention in KV cache manager
    (blocks are allocated for all tokens), while computed as sliding window
    attention in model runner.
    In this case, we use FullAttentionSpec and record the sliding window size.
    """
    # Default to None for not using sliding window attention.
    sliding_window: Optional[int] = None
    # Chunk size recorded for chunked local attention layers; None when unset.
    attention_chunk_size: Optional[int] = None

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: Only `model_config.max_model_len` is read.

        Returns:
            The bytes of the pages needed to hold `max_model_len` tokens.
        """
        max_model_len = vllm_config.model_config.max_model_len
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
        """
        Collapse a set of window sizes into a single optional value.

        Args:
            window_sizes: The distinct window sizes found in a group. The set
                is left unmodified.

        Returns:
            None for an empty set, otherwise the single size it contains.

        Raises:
            ValueError: If the set holds more than one distinct size.
        """
        if not window_sizes:
            return None
        if len(window_sizes) > 1:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size.")
        # Read the only element without popping it, so that the caller's set
        # is not mutated as a side effect.
        return next(iter(window_sizes))

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.

        Args:
            specs: A non-empty list of FullAttentionSpec objects that agree on
                every `AttentionSpec` field.

        Returns:
            A new spec built from the first spec's `AttentionSpec` fields and
            the merged `sliding_window` / `attention_chunk_size` values.

        Raises:
            ValueError: If the specs carry more than one distinct sliding
                window size or attention chunk size.
        """
        # Without this guard an empty list would only fail later with a bare
        # IndexError on `specs[0]`.
        assert specs, "Cannot merge an empty list of KV cache specs."
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "FullAttentionSpec.")

        # Layers that leave a value unset (None) do not take part in the merge.
        sliding_window = {
            spec.sliding_window
            for spec in specs if spec.sliding_window is not None
        }
        attention_chunk_size = {
            spec.attention_chunk_size
            for spec in specs if spec.attention_chunk_size is not None
        }

        first = specs[0]
        merged_spec = cls(
            block_size=first.block_size,
            num_kv_heads=first.num_kv_heads,
            head_size=first.head_size,
            dtype=first.dtype,
            use_mla=first.use_mla,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        # Every spec must match the merged one on the fields declared by
        # AttentionSpec (including the inherited ones).
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec.")
        assert (merged_spec.sliding_window is None
                or merged_spec.attention_chunk_size is None), (
                    "Model with both sliding window layers and chunked local "
                    "attention layers is not supported.")
        return merged_spec

137

138
@dataclass(frozen=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
    """
    KV cache spec for attention layers that carry an `attention_chunk_size`.
    """
    # Chunk size used to bound the number of cached tokens (see below).
    attention_chunk_size: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: `model_config.max_model_len` and
                `scheduler_config.max_num_batched_tokens` are read.

        Returns:
            The bytes of the pages needed to hold
            `min(attention_chunk_size + max_num_batched_tokens, max_model_len)`
            tokens.
        """
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)

        # During chunked prefill, we allocate KV cache for at most
        # `self.attention_chunk_size` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(self.attention_chunk_size + max_num_batched_tokens,
                         max_model_len)

        return cdiv(num_tokens, self.block_size) * self.page_size_bytes

156

157
@dataclass(frozen=True)
class SlidingWindowSpec(AttentionSpec):
    """
    KV cache spec for sliding window attention layers.

    MLA is rejected at construction time.
    """
    # Size of the sliding window, in tokens.
    sliding_window: int

    def __post_init__(self):
        """Validate the spec right after dataclass initialization."""
        assert not self.use_mla, "MLA is not supported for sliding window"

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: `model_config.max_model_len` and
                `scheduler_config.max_num_batched_tokens` are read.

        Returns:
            The bytes of the pages needed for
            `min(sliding_window - 1 + max_num_batched_tokens, max_model_len)`
            tokens, plus one extra page for block misalignment.
        """
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window-1` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
                         max_model_len)

        # +1 here because the sliding window may not start from the beginning
        # of the block. For example, if the block size is 4 and num_token
        # is 4, we need two blocks [XXCD] [EF] to store the sliding
        # window [CDEF] of 4 tokens.
        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
181
182


183
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    """
    KV cache spec whose page is a collection of tensors, each described by a
    shape and a dtype.
    """
    # One shape per tensor stored in a page.
    shapes: tuple[tuple[int, ...], ...]
    # One dtype per entry of `shapes`, paired element-wise with it.
    dtypes: tuple[torch.dtype, ...]
    # Optional padded page size in bytes; when set it replaces the computed
    # size and must not be smaller than it.
    page_size_padded: Optional[int] = None
    # Identifier of the mamba variant; not read by this class.
    mamba_type: str = "mamba2"

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page in bytes.

        Returns:
            `page_size_padded` when it is set, otherwise the summed byte size
            of all `(shape, dtype)` pairs.
        """
        # NOTE: `zip` stops at the shorter input, so `shapes` and `dtypes` are
        # expected to have the same length.
        page_size = sum(
            prod(shape) * get_dtype_size(dtype)
            for (shape, dtype) in zip(self.shapes, self.dtypes))
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: The vLLM configuration (unused here).

        Returns:
            The same value as `page_size_bytes`.
        """
        # We allocate 1 block for each request now, so max_memory_usage_bytes is
        # the same as page_size_bytes.
        # Need to update this when supporting prefix caching.
        return self.page_size_bytes


207
208
209
210
211
212
213
214
@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):
    """
    KV cache spec for encoder-only attention layers, which report zero KV
    cache memory usage.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: The vLLM configuration (unused here).

        Returns:
            Always 0.
        """
        # Encoder-only layers do not need KV cache
        return 0


215
216
217
218
219
220
221
222
223
224
225
226
227
228
@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
    """
    KV cache spec used by the cross-attention layers of encoder-decoder
    models.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: Its `model_config` is passed to the multimodal
                registry to look up the maximum encoder length.

        Returns:
            The bytes of the pages needed to hold the maximum encoder length.
        """
        # Cross-attention caches the encoder states, so the bound comes from
        # the encoder length (e.g., 1500 for Whisper) rather than from the
        # decoder sequence length.
        model_config = vllm_config.model_config
        encoder_len = MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(
            model_config)
        num_pages = cdiv(encoder_len, self.block_size)
        return num_pages * self.page_size_bytes


229
230
231
@dataclass
class KVCacheTensor:
    """
    A class for specifying how the workers should initialize the KV cache.
    """
    # size of the KV cache tensor in bytes
    size: int
    # layer names that share the same KV cache tensor
    shared_by: list[str]
236
237


238
239
240
241
242
243
244
245
246
247
248
249
@dataclass
class KVCacheGroupSpec:
    """
    Represents a group of model layers that share the same KV cache block table.
    These layers are regarded as one layer in the KV cache manager.

    The group pairs the names of its member layers with the single
    `KVCacheSpec` that describes all of them.
    """
    # The names of model layers in this group
    layer_names: list[str]
    # The KV cache spec of this manager layer, i.e. the one spec shared by
    # every layer listed in `layer_names`
    kv_cache_spec: KVCacheSpec


250
251
252
253
254
255
256
@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """
    # The number of KV cache blocks
    num_blocks: int
    # How should model runner initialize the KV cache tensors for each layer
    kv_cache_tensors: list[KVCacheTensor]
    # The kv cache groups of the model.
    # For models with only one type of attention, there is only one group that
    # contains all layers.
    # For models with multiple types of attention, there will be multiple
    # groups, see `_get_kv_cache_config_uniform_page_size` for more details.
    kv_cache_groups: list[KVCacheGroupSpec]