# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Specifications of per-layer KV cache formats and of the overall KV cache
layout handed to the KV cache manager and model runner."""

import copy
from dataclasses import dataclass, fields
from math import prod

import torch
from typing_extensions import Self

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, get_dtype_size

logger = init_logger(__name__)


18
@dataclass(frozen=True)
class KVCacheSpec:
    """
    A base class for specifying the KV cache format of one layer.

    Subclasses override `page_size_bytes` and `max_memory_usage_bytes`, and
    may override `merge` to combine the specs of several layers.
    """

    # number of tokens in a block
    block_size: int

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            The page size

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: The engine configuration; implementations derive
                their limits (e.g. the maximum model length) from it.

        Returns:
            The KV cache size in bytes

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.

        The base implementation requires every spec to be equal to the first
        one and returns a deep copy of it.

        Args:
            specs: The specs of the layers in one KV cache group; must be
                non-empty.

        Raises:
            AssertionError: if the specs are not all equal.
        """
        assert all(spec == specs[0] for spec in specs[1:]), (
            "All layers in the same KV cache group must be the same."
        )
        return copy.deepcopy(specs[0])

56

57
@dataclass(frozen=True)
class AttentionSpec(KVCacheSpec):
    """
    KV cache spec for an attention layer described by its number of KV heads,
    head size and cache dtype.
    """

    num_kv_heads: int
    head_size: int
    dtype: torch.dtype

    @property
    def page_size_bytes(self) -> int:
        # The leading factor of 2 covers the key entry and the value entry
        # stored for every token.
        return (
            2
            * self.block_size
            * self.num_kv_heads
            * self.head_size
            * get_dtype_size(self.dtype)
        )
72

73

74
@dataclass(frozen=True)
class FullAttentionSpec(AttentionSpec):
    """
    KV cache spec for full attention layers, for which KV cache blocks are
    allocated for every token of the sequence.

    When hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, sliding window
    attention layers are regarded as full attention in the KV cache manager
    (blocks are allocated for all tokens), while computed as sliding window
    attention in the model runner.
    In this case, we use FullAttentionSpec and record the sliding window size
    in `sliding_window` (and, likewise, the chunk size of chunked local
    attention in `attention_chunk_size`). Both default to None, meaning the
    layer uses neither.
    """

    sliding_window: int | None = None
    attention_chunk_size: int | None = None

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Bytes needed to hold `max_model_len` tokens, split across decode
        context parallel (DCP) ranks when DCP is enabled.
        """
        max_model_len = vllm_config.model_config.max_model_len
        dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
        # Note(hc): each dcp rank only need save
        # (max_model_len//dcp_world_size) tokens locally.
        if dcp_world_size > 1:
            max_model_len = cdiv(max_model_len, dcp_world_size)
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
        """
        Collapse the window sizes collected from a KV cache group into one.

        Args:
            window_sizes: The distinct non-None window sizes of the group.
                The set is not modified.

        Returns:
            None if the set is empty, otherwise its single element.

        Raises:
            ValueError: if the set holds more than one distinct size.
        """
        if not window_sizes:
            return None
        if len(window_sizes) == 1:
            # Peek instead of pop() so the caller's set is left intact.
            return next(iter(window_sizes))
        raise ValueError(
            "All attention layers in the same KV cache group must have the "
            "same window size."
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.

        Layers whose `sliding_window` / `attention_chunk_size` is None are
        ignored when merging that field. A group that mixes full attention
        and sliding window layers therefore keeps the (single) window size of
        the sliding window layers.

        Raises:
            AssertionError: if a spec is not a FullAttentionSpec, is an
                MLAAttentionSpec, differs from the others in any AttentionSpec
                field, or if the group mixes sliding window and chunked local
                attention.
            ValueError: if the group has more than one distinct sliding window
                size or attention chunk size.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be FullAttentionSpec."
        )
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
        )
        sliding_windows = {
            spec.sliding_window for spec in specs if spec.sliding_window is not None
        }
        attention_chunk_sizes = {
            spec.attention_chunk_size
            for spec in specs
            if spec.attention_chunk_size is not None
        }
        first = specs[0]
        merged_spec = cls(
            block_size=first.block_size,
            num_kv_heads=first.num_kv_heads,
            head_size=first.head_size,
            dtype=first.dtype,
            sliding_window=cls.merge_window_sizes(sliding_windows),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_sizes),
        )
        # Only the window fields are merged above; every layer must agree on
        # the shared AttentionSpec fields (block size, heads, head size, dtype).
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
        assert (
            merged_spec.sliding_window is None
            or merged_spec.attention_chunk_size is None
        ), (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported."
        )
        return merged_spec

152

153
154
155
@dataclass(frozen=True)
class MLAAttentionSpec(FullAttentionSpec):
    """
    KV cache spec for MLA attention layers. `cache_dtype_str` can select a
    special cache layout (currently "fp8_ds_mla") that changes the page size.
    """

    # TODO(Lucas/Chen): less hacky way to do this
    cache_dtype_str: str | None = None

    @property
    def page_size_bytes(self) -> int:
        if self.cache_dtype_str == "fp8_ds_mla":
            # A fixed 656 bytes per token for this cache format, independent
            # of num_kv_heads / head_size / dtype. See
            # `vllm/v1/attention/backends/mla/flashmla_sparse.py` for details.
            return self.block_size * 656
        # Unlike AttentionSpec.page_size_bytes there is no factor of 2 here
        # (presumably MLA caches a single combined entry per token rather
        # than separate K and V; confirm against the MLA backends).
        return (
            self.block_size
            * self.num_kv_heads
            * self.head_size
            * get_dtype_size(self.dtype)
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of MLAAttentionSpec objects into a single
        MLAAttentionSpec object.

        The merged spec does not carry `sliding_window` or
        `attention_chunk_size`; both stay at their None defaults.

        Raises:
            AssertionError: if a spec is not an MLAAttentionSpec, the specs use
                different `cache_dtype_str` values, or they differ in any
                AttentionSpec field.
        """
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be MLAAttentionSpec."
        )
        cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
        assert len(cache_dtype_str_set) == 1, (
            "All attention layers in the same KV cache group must use the same "
            "quantization method."
        )
        first = specs[0]
        merged_spec = cls(
            block_size=first.block_size,
            num_kv_heads=first.num_kv_heads,
            head_size=first.head_size,
            dtype=first.dtype,
            cache_dtype_str=next(iter(cache_dtype_str_set)),
        )
        # As in FullAttentionSpec.merge, every layer must agree on the shared
        # AttentionSpec fields. Otherwise the merged page size would silently
        # follow specs[0] only.
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
        return merged_spec


190
@dataclass(frozen=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
    """
    KV cache spec for chunked local attention layers. At most
    `attention_chunk_size` computed tokens (plus the newly scheduled ones)
    need KV cache at any time.
    """

    attention_chunk_size: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # During chunked prefill, we allocate KV cache for at most
        # `self.attention_chunk_size` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(
            self.attention_chunk_size + max_num_batched_tokens, max_model_len
        )

        return cdiv(num_tokens, self.block_size) * self.page_size_bytes

208

209
@dataclass(frozen=True)
class SlidingWindowSpec(AttentionSpec):
    """
    KV cache spec for sliding window attention layers. Only the last
    `sliding_window - 1` computed tokens (plus the newly scheduled ones) need
    KV cache at any time.
    """

    sliding_window: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        assert vllm_config.parallel_config.decode_context_parallel_size == 1, (
            "DCP not support sliding window."
        )
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window-1` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(
            self.sliding_window - 1 + max_num_batched_tokens, max_model_len
        )

        # +1 here because the sliding window may not start at a block
        # boundary. For example, with a block size of 4 and num_tokens of 4,
        # the window [CDEF] at the end of the 6-token sequence ABCDEF spans
        # two blocks, [XXCD] [EF] (X = a token already outside the window).
        # So 2 blocks are needed rather than cdiv(4, 4) = 1.
        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
233
234


235
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    """
    KV cache spec for Mamba-style layers. A page holds one tensor for each
    positionally paired entry of `shapes` and `dtypes`. The page size can
    optionally be padded up to `page_size_padded` bytes.
    """

    # Shape of each tensor stored in a page.
    shapes: tuple[tuple[int, ...], ...]
    # Dtype of each tensor; paired positionally with `shapes`.
    dtypes: tuple[torch.dtype, ...]
    # If set, reported as the page size instead of the computed one; must be
    # at least the unpadded size.
    page_size_padded: int | None = None
    mamba_type: str = "mamba2"
    num_speculative_blocks: int = 0

    @property
    def page_size_bytes(self) -> int:
        # strict=True: a shapes/dtypes length mismatch would otherwise be
        # truncated silently by zip() and under-count the page size.
        page_size = sum(
            prod(shape) * get_dtype_size(dtype)
            for shape, dtype in zip(self.shapes, self.dtypes, strict=True)
        )
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        max_model_len = vllm_config.model_config.max_model_len
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes


259
260
261
262
263
264
265
@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):
    """
    Spec for encoder-only attention layers, which reserve no KV cache memory.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # Nothing is cached for encoder-only layers, regardless of the config.
        reserved_bytes = 0
        return reserved_bytes


266
267
268
269
270
271
272
273
274
@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
    """
    KV cache spec for cross-attention layers in encoder-decoder models.

    These layers cache encoder states, so their memory is bounded by the
    scheduler's encoder input token limit rather than by the model length.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # Upper bound on encoder tokens that may need caching
        # (e.g. 1500 for Whisper).
        encoder_token_limit = (
            vllm_config.scheduler_config.max_num_encoder_input_tokens
        )
        blocks_needed = cdiv(encoder_token_limit, self.block_size)
        return blocks_needed * self.page_size_bytes


279
280
281
282
283
284
285
286
@dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec):
    """
    A KV cache spec for multiple layers with the same type of attention. Here,
    same types means always need the same number of token slots. For example,
    sliding window attentions with different window sizes are not the same type
    and should not be merged into one UniformTypeKVCacheSpecs.
    """

    # Per-layer specs (keys are presumably layer names; confirm with callers).
    kv_cache_specs: dict[str, KVCacheSpec]

    @property
    def page_size_bytes(self) -> int:
        # One page of this spec bundles one page of every contained layer.
        return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values())

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # The layer needing the most pages decides the page count. Each of
        # those pages is as large as the combined page of all layers.
        max_num_pages = max(
            cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes)
            for spec in self.kv_cache_specs.values()
        )
        return max_num_pages * self.page_size_bytes

    @classmethod
    def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
        """
        Whether all layers have the same type of KV cache spec.

        Returns False for an empty mapping, since there is no spec to compare
        against. Also returns False when layers use different block sizes.

        Raises:
            NotImplementedError: if the first spec's type has no branch below.
        """
        if not kv_cache_specs:
            # Previously next(iter(...)) raised StopIteration here.
            return False
        specs = list(kv_cache_specs.values())
        if len({spec.block_size for spec in specs}) > 1:
            # Different block sizes, not uniform.
            return False
        one_spec = specs[0]
        if isinstance(one_spec, FullAttentionSpec):
            return all(isinstance(spec, FullAttentionSpec) for spec in specs)
        elif isinstance(one_spec, CrossAttentionSpec):
            return all(isinstance(spec, CrossAttentionSpec) for spec in specs)
        elif isinstance(one_spec, SlidingWindowSpec):
            return all(
                isinstance(spec, SlidingWindowSpec)
                and spec.sliding_window == one_spec.sliding_window
                for spec in specs
            )
        elif isinstance(one_spec, ChunkedLocalAttentionSpec):
            return all(
                isinstance(spec, ChunkedLocalAttentionSpec)
                and spec.attention_chunk_size == one_spec.attention_chunk_size
                for spec in specs
            )
        elif isinstance(one_spec, MambaSpec):
            return all(
                isinstance(spec, MambaSpec)
                and spec.num_speculative_blocks == one_spec.num_speculative_blocks
                for spec in specs
            )
        else:
            # NOTE(Chen): Please add new branches for new KV cache spec types.
            raise NotImplementedError(
                f"Unsupported KV cache spec type: {type(one_spec)}"
            )

    @classmethod
    def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None:
        """
        Return a UniformTypeKVCacheSpecs object if all layers have the same
        type of KV cache spec. Return None if not, including for an empty
        mapping.
        """
        if not cls.is_uniform_type(kv_cache_specs):
            return None
        block_size = next(iter(kv_cache_specs.values())).block_size
        return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)


356
357
358
@dataclass
class KVCacheTensor:
    """
    A class for specifying how the workers should initialize the KV cache.
    """

    size: int  # size of the KV cache tensor in bytes
    shared_by: list[str]  # layer names that share the same KV cache tensor
364
365


366
367
368
369
370
371
@dataclass
class KVCacheGroupSpec:
    """
    Represents a group of model layers that share the same KV cache block table.
    These layers are regarded as one layer in the KV cache manager.
    """

    # The names of model layers in this group
    layer_names: list[str]
    # The KV cache spec of the group, i.e. of the single "layer" the KV cache
    # manager sees for it
    kv_cache_spec: KVCacheSpec


379
380
381
382
383
@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """

    # The number of KV cache blocks.
    num_blocks: int
    # How the model runner should initialize the KV cache tensors for each
    # layer.
    kv_cache_tensors: list[KVCacheTensor]
    # The KV cache groups of the model.
    # For models with only one type of attention, there is only one group that
    # contains all layers.
    # For models with multiple types of attention, there will be multiple
    # groups, see `_get_kv_cache_config_uniform_page_size` for more details.
    kv_cache_groups: list[KVCacheGroupSpec]