kv_cache_interface.py 13.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import copy
5
from dataclasses import dataclass, fields
Chen Zhang's avatar
Chen Zhang committed
6
from math import prod
7
from typing import Optional
8
9

import torch
10
from typing_extensions import Self
11

12
from vllm.config import VllmConfig
13
14
15
16
17
18
from vllm.logger import init_logger
from vllm.utils import cdiv, get_dtype_size

logger = init_logger(__name__)


19
@dataclass(frozen=True)
class KVCacheSpec:
    """
    Base description of the KV cache format used by a single layer.

    This class only stores the block size; subclasses are expected to
    provide the page size and the worst-case memory footprint.
    """

    # Number of tokens held by one block.
    block_size: int

    @property
    def page_size_bytes(self) -> int:
        """
        Size in bytes of one page holding `block_size` tokens.

        Returns:
            The page size in bytes.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Upper bound, in bytes, on the memory this KV cache can occupy.

        Args:
            vllm_config: The vLLM configuration the bound is derived from.

        Returns:
            The KV cache size in bytes.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Collapse a list of KVCacheSpec objects into a single KVCacheSpec.

        Every entry must compare equal to the first one; the result is a deep
        copy of that first entry.
        """
        reference = specs[0]
        assert all(other == reference for other in specs[1:]), (
            "All layers in the same KV cache group must be the same.")
        return copy.deepcopy(reference)

56

57
@dataclass(frozen=True)
class AttentionSpec(KVCacheSpec):
    """
    KV cache spec for attention layers, described by the number of KV heads,
    the size of each head and the element dtype.
    """
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype

    @property
    def page_size_bytes(self) -> int:
        """
        Size in bytes of one page holding `block_size` tokens.
        """
        elems_per_token = self.num_kv_heads * self.head_size
        # NOTE(review): the leading factor of 2 presumably accounts for
        # separate key and value entries -- confirm.
        return (2 * self.block_size * elems_per_token *
                get_dtype_size(self.dtype))

68

69
@dataclass(frozen=True)
class FullAttentionSpec(AttentionSpec):
    """
    KV cache spec for attention layers whose blocks are allocated for all
    tokens.

    When hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, sliding window
    attention layers are regarded as full attention in KV cache manager
    (blocks are allocated for all tokens), while computed as sliding window
    attention in model runner.
    In this case, we use FullAttentionSpec and record the sliding window size.
    """
    # Recorded sliding window size. Defaults to None for not using sliding
    # window attention.
    sliding_window: Optional[int] = None
    # NOTE(review): presumably the chunk size recorded for chunked local
    # attention layers, analogous to `sliding_window` -- confirm.
    attention_chunk_size: Optional[int] = None

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Memory needed to hold `max_model_len` tokens, rounded up to whole
        blocks.

        Args:
            vllm_config: Supplies `model_config.max_model_len` and
                `parallel_config.decode_context_parallel_size`.

        Returns:
            The KV cache size in bytes.
        """
        max_model_len = vllm_config.model_config.max_model_len
        dcp_world_size = \
            vllm_config.parallel_config.decode_context_parallel_size
        # Note(hc): each dcp rank only need save
        # (max_model_len//dcp_world_size) tokens locally.
        if dcp_world_size > 1:
            max_model_len = cdiv(max_model_len, dcp_world_size)
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
        """
        Reduce the set of window sizes seen in a group to a single value.

        Args:
            window_sizes: The distinct non-None window sizes of the group.
                The set is not modified.

        Returns:
            None if the set is empty, otherwise its only element.

        Raises:
            ValueError: If the set contains more than one size.
        """
        if not window_sizes:
            return None
        if len(window_sizes) > 1:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size.")
        # Read the single element without consuming the caller's set.
        return next(iter(window_sizes))

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.

        All specs must agree on every AttentionSpec field, and the non-None
        sliding window sizes (respectively attention chunk sizes) must be
        identical. A merged spec may carry a sliding window or an attention
        chunk size, but not both.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "FullAttentionSpec.")
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge")

        # Layers that leave a size unset (None) do not take part in the
        # comparison of that size.
        sliding_window = {
            spec.sliding_window
            for spec in specs if spec.sliding_window is not None
        }
        attention_chunk_size = {
            spec.attention_chunk_size
            for spec in specs if spec.attention_chunk_size is not None
        }

        first = specs[0]
        merged_spec = cls(
            block_size=first.block_size,
            num_kv_heads=first.num_kv_heads,
            head_size=first.head_size,
            dtype=first.dtype,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        # The merged spec was built from the first entry, so comparing every
        # entry against it checks that all AttentionSpec fields agree.
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec.")
        assert (
            (merged_spec.sliding_window is not None) +
            (merged_spec.attention_chunk_size is not None) <= 1
        ), ("Model with both sliding window layers and chunked local attention "
            "layers is not supported.")
        return merged_spec

140

141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
@dataclass(frozen=True)
class MLAAttentionSpec(FullAttentionSpec):
    """
    FullAttentionSpec variant with its own page size formula and an optional
    cache dtype string.
    """
    # TODO(Lucas/Chen): less hacky way to do this
    cache_dtype_str: Optional[str] = None

    @property
    def page_size_bytes(self) -> int:
        """
        Size in bytes of one page holding `block_size` tokens.

        Unlike `AttentionSpec.page_size_bytes`, there is no factor of 2 here.
        """
        if self.cache_dtype_str != "fp8_ds_mla":
            return (self.block_size * self.num_kv_heads * self.head_size *
                    get_dtype_size(self.dtype))
        # See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
        #  for details.
        return self.block_size * 656

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of MLAAttentionSpec objects into a single
        MLAAttentionSpec object.

        All specs must share the same `cache_dtype_str`.
        """
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "MLAAttentionSpec.")
        distinct_cache_dtypes = {spec.cache_dtype_str for spec in specs}
        assert len(distinct_cache_dtypes) == 1, (
            "All attention layers in the same KV cache group must use the same "
            "quantization method.")
        # NOTE(review): the remaining fields are taken from the first spec
        # without being compared across specs, and `sliding_window` /
        # `attention_chunk_size` fall back to their defaults.
        first = specs[0]
        return cls(
            block_size=first.block_size,
            num_kv_heads=first.num_kv_heads,
            head_size=first.head_size,
            dtype=first.dtype,
            cache_dtype_str=distinct_cache_dtypes.pop(),
        )


173
@dataclass(frozen=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
    """
    KV cache spec for attention layers with a fixed attention chunk size.
    """
    attention_chunk_size: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Memory needed for the largest number of tokens that can be cached at
        once, rounded up to whole blocks.

        Args:
            vllm_config: Supplies `model_config.max_model_len` and
                `scheduler_config.max_num_batched_tokens`.

        Returns:
            The KV cache size in bytes.
        """
        model_len_limit = vllm_config.model_config.max_model_len
        batched_token_limit = (
            vllm_config.scheduler_config.max_num_batched_tokens)

        # During chunked prefill, we allocate KV cache for at most
        # `self.attention_chunk_size` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        token_budget = self.attention_chunk_size + batched_token_limit
        num_tokens = min(token_budget, model_len_limit)

        num_blocks = cdiv(num_tokens, self.block_size)
        return num_blocks * self.page_size_bytes

191

192
@dataclass(frozen=True)
class SlidingWindowSpec(AttentionSpec):
    """
    KV cache spec for sliding window attention layers.
    """
    sliding_window: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Memory needed for the largest number of tokens that can be cached at
        once, rounded up to whole blocks plus one extra block.

        Args:
            vllm_config: Supplies `model_config.max_model_len`,
                `scheduler_config.max_num_batched_tokens` and
                `parallel_config.decode_context_parallel_size` (must be 1).

        Returns:
            The KV cache size in bytes.
        """
        dcp_world_size = \
            vllm_config.parallel_config.decode_context_parallel_size
        assert dcp_world_size == 1, "DCP not support sliding window."

        model_len_limit = vllm_config.model_config.max_model_len
        batched_token_limit = (
            vllm_config.scheduler_config.max_num_batched_tokens)

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window-1` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        token_budget = self.sliding_window - 1 + batched_token_limit
        num_tokens = min(token_budget, model_len_limit)

        # +1 here because the sliding window may not start from the beginning
        # of the block. For example, with block size 4, the 4-token window
        # [CDEF] at the end of a 6-token sequence spans two blocks
        # [XXCD] [EF], even though 4 tokens would fit in a single block.
        num_blocks = cdiv(num_tokens, self.block_size) + 1
        return num_blocks * self.page_size_bytes
215
216


217
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    """
    KV cache spec whose page is described by a list of (shape, dtype) pairs.
    """
    # `shapes` and `dtypes` are paired positionally when the page size is
    # computed, so both are variable-length tuples of equal length.
    shapes: tuple[tuple[int, ...], ...]
    dtypes: tuple[torch.dtype, ...]
    # If set, this value is reported as the page size instead of the computed
    # one; it must not be smaller than the computed size.
    page_size_padded: Optional[int] = None
    mamba_type: str = "mamba2"
    num_speculative_blocks: int = 0

    @property
    def page_size_bytes(self) -> int:
        """
        Size in bytes of one page.

        Returns:
            `page_size_padded` when it is set, otherwise the sum over all
            (shape, dtype) pairs of the element count times the element size.
        """
        page_size = sum(
            prod(shape) * get_dtype_size(dtype)
            for (shape, dtype) in zip(self.shapes, self.dtypes))
        if self.page_size_padded is not None:
            # Padding may only grow the page, never truncate it.
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Memory needed to cover `max_model_len` tokens, rounded up to whole
        blocks.

        Args:
            vllm_config: Supplies `model_config.max_model_len`.

        Returns:
            The cache size in bytes.
        """
        max_model_len = vllm_config.model_config.max_model_len
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes
Chen Zhang's avatar
Chen Zhang committed
238
239


240
241
242
243
244
245
246
247
@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):
    """
    Attention spec for encoder-only layers, which report zero KV cache usage.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Return 0: encoder-only layers do not need KV cache.

        Args:
            vllm_config: Unused.
        """
        return 0


248
249
250
251
252
253
254
255
256
@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
    """
    KV cache spec for cross-attention layers in encoder-decoder models.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Memory needed to cache the encoder states, rounded up to whole
        blocks.

        Args:
            vllm_config: Supplies
                `scheduler_config.max_num_encoder_input_tokens`.

        Returns:
            The KV cache size in bytes.
        """
        # For cross-attention, we need to cache encoder states
        # Get encoder length (e.g., 1500 for Whisper).
        scheduler_config = vllm_config.scheduler_config
        max_encoder_len = scheduler_config.max_num_encoder_input_tokens
        num_blocks = cdiv(max_encoder_len, self.block_size)
        return num_blocks * self.page_size_bytes


262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
@dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec):
    """
    A KV cache spec for multiple layers with the same type of attention. Here,
    same types means always need the same number of token slots. For example,
    sliding window attentions with different window sizes are not the same type
    and should not be merged into one UniformTypeKVCacheSpecs.
    """
    # Per-layer specs, keyed by a string.
    # NOTE(review): keys are presumably layer names -- confirm with callers.
    kv_cache_specs: dict[str, KVCacheSpec]

    @property
    def page_size_bytes(self) -> int:
        """
        Combined page size: the sum of the page sizes of all member specs.
        """
        return sum(spec.page_size_bytes
                   for spec in self.kv_cache_specs.values())

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The largest page count required by any member spec, multiplied by the
        combined page size.

        Args:
            vllm_config: Forwarded to each member spec.

        Returns:
            The KV cache size in bytes.
        """
        max_num_pages = max(
            cdiv(spec.max_memory_usage_bytes(vllm_config),
                 spec.page_size_bytes)
            for spec in self.kv_cache_specs.values())
        return max_num_pages * self.page_size_bytes

    @classmethod
    def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
        """
        Whether all layers have the same type of KV cache spec.

        Args:
            kv_cache_specs: The per-layer specs to inspect.

        Returns:
            True if all specs share one block size and one spec type (plus,
            depending on the type, the same sliding window, attention chunk
            size or number of speculative blocks). False otherwise, including
            when `kv_cache_specs` is empty.

        Raises:
            NotImplementedError: If the first spec is of a type that has no
                branch here.
        """
        if not kv_cache_specs:
            # Nothing to group; also avoids leaking StopIteration from the
            # lookup of the first spec.
            return False
        specs = list(kv_cache_specs.values())
        if len({spec.block_size for spec in specs}) > 1:
            # Different block sizes, not uniform.
            return False
        one_spec = specs[0]
        if isinstance(one_spec, FullAttentionSpec):
            return all(isinstance(spec, FullAttentionSpec) for spec in specs)
        elif isinstance(one_spec, CrossAttentionSpec):
            return all(isinstance(spec, CrossAttentionSpec) for spec in specs)
        elif isinstance(one_spec, SlidingWindowSpec):
            return all(
                isinstance(spec, SlidingWindowSpec)
                and spec.sliding_window == one_spec.sliding_window
                for spec in specs)
        elif isinstance(one_spec, ChunkedLocalAttentionSpec):
            return all(
                isinstance(spec, ChunkedLocalAttentionSpec)
                and spec.attention_chunk_size == one_spec.attention_chunk_size
                for spec in specs)
        elif isinstance(one_spec, MambaSpec):
            return all(
                isinstance(spec, MambaSpec) and spec.num_speculative_blocks ==
                one_spec.num_speculative_blocks for spec in specs)
        else:
            # NOTE(Chen): Please add new branches for new KV cache spec types.
            raise NotImplementedError(
                f"Unsupported KV cache spec type: {type(one_spec)}")

    @classmethod
    def from_specs(cls, kv_cache_specs: dict[str,
                                             KVCacheSpec]) -> Optional[Self]:
        """
        Return a UniformTypeKVCacheSpecs object if all layers have the same
        type of KV cache spec. Return None if not.
        """
        if not cls.is_uniform_type(kv_cache_specs):
            return None
        block_size = next(iter(kv_cache_specs.values())).block_size
        return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)


336
337
338
@dataclass
class KVCacheTensor:
    """
    Describes one KV cache tensor that the workers should initialize.
    """
    # Size of the KV cache tensor in bytes.
    size: int
    # Layer names that share the same KV cache tensor.
    shared_by: list[str]
343
344


345
346
347
348
349
350
351
352
353
354
355
356
@dataclass
class KVCacheGroupSpec:
    """
    A group of model layers sharing one KV cache block table.

    The KV cache manager regards the layers of a group as a single layer.
    """
    # Names of the model layers that belong to this group.
    layer_names: list[str]
    # KV cache spec of this manager layer.
    kv_cache_spec: KVCacheSpec


357
358
359
360
361
362
363
@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """
    # The number of KV cache blocks.
    num_blocks: int
    # How should model runner initialize the KV cache tensors for each layer.
    kv_cache_tensors: list[KVCacheTensor]
    # The kv cache groups of the model.
    # For models with only one type of attention, there is only one group that
    # contains all layers.
    # For models with multiple types of attention, there will be multiple
    # groups, see `_get_kv_cache_config_uniform_page_size` for more details.
    kv_cache_groups: list[KVCacheGroupSpec]