# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV cache specifications.

Dataclasses describing the KV cache format of individual layers (full,
sliding-window, chunked-local, MLA, Mamba, encoder-only and cross-attention
specs), a container for several same-type layer specs, and the objects that
describe KV cache tensors, KV cache groups and the overall KV cache config.
"""

import copy
from dataclasses import dataclass, fields, replace
from math import prod

import torch
from typing_extensions import Self

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import get_dtype_size

logger = init_logger(__name__)


@dataclass(frozen=True)
class KVCacheSpec:
    """
    Base class describing the KV cache format of a single layer.

    Subclasses are expected to override `page_size_bytes` and
    `max_memory_usage_bytes`; the base versions raise NotImplementedError.
    """

    # how many tokens one block holds
    block_size: int

    @property
    def page_size_bytes(self) -> int:
        """
        Size in bytes of one page, i.e. one block of `block_size` tokens.

        Returns:
            The page size
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Upper bound on the KV cache memory used by this spec.

        Args:
            vllm_config: The global vLLM configuration.

        Returns:
            The KV cache size in bytes
        """
        raise NotImplementedError

    def copy_with_new_block_size(self, block_size: int) -> Self:
        """
        Return a copy of this spec in which only `block_size` is replaced.
        """
        return replace(self, block_size=block_size)

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Collapse a list of identical KVCacheSpec objects into one.

        Returns:
            A deep copy of the first spec.

        Raises:
            AssertionError: if any spec differs from the first one.
        """
        reference = specs[0]
        assert all(spec == reference for spec in specs[1:]), (
            "All layers in the same KV cache group must be the same."
        )
        return copy.deepcopy(reference)


@dataclass(frozen=True)
class AttentionSpec(KVCacheSpec):
    """
    KV cache spec for attention layers whose page holds both keys and values
    (hence the factor of 2) for `num_kv_heads * head_size` elements per token.
    """

    num_kv_heads: int
    head_size: int
    dtype: torch.dtype

    @property
    def page_size_bytes(self) -> int:
        # Number of K plus V elements stored for one block of tokens.
        elems_per_block = 2 * self.block_size * self.num_kv_heads * self.head_size
        return elems_per_block * get_dtype_size(self.dtype)


@dataclass(frozen=True)
class FullAttentionSpec(AttentionSpec):
    """
    KV cache spec for full attention: blocks are allocated for every token.

    When the hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, the sliding window
    attention layers are regarded as full attention by the KV cache manager
    (blocks are allocated for all tokens), while still being computed as
    sliding window attention in the model runner.
    In this case, we use FullAttentionSpec and record the sliding window size.
    `sliding_window` defaults to None, meaning sliding window attention is not
    used.
    """

    sliding_window: int | None = None
    attention_chunk_size: int | None = None

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Return the bytes needed to hold `max_model_len` tokens on this rank.

        With decode/prefill context parallelism, the token budget is divided
        (rounded up) by `dcp_world_size * pcp_world_size`.
        """
        max_model_len = vllm_config.model_config.max_model_len
        dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
        pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
        # Note(hc): each context-parallel rank only needs to save
        # cdiv(max_model_len, dcp_world_size * pcp_world_size) tokens locally.
        cp_world_size = dcp_world_size * pcp_world_size
        if cp_world_size > 1:
            max_model_len = cdiv(max_model_len, cp_world_size)
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
        """
        Reduce a set of window sizes to a single value.

        Args:
            window_sizes: Distinct window sizes collected from the layers of
                one KV cache group. The set is not modified.

        Returns:
            None if the set is empty, otherwise its only element.

        Raises:
            ValueError: if the set contains more than one window size.
        """
        if not window_sizes:
            return None
        if len(window_sizes) == 1:
            # Read the element instead of pop() so the caller's set is intact.
            return next(iter(window_sizes))
        raise ValueError(
            "All attention layers in the same KV cache group must have the "
            "same window size."
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.

        The specs must agree on every AttentionSpec field and on their
        non-None window sizes, and a group may not mix sliding window and
        chunked local attention.

        Raises:
            AssertionError: if a spec is not a FullAttentionSpec, a spec is an
                MLAAttentionSpec, the AttentionSpec fields differ, or both a
                sliding window and an attention chunk size are present.
            ValueError: if the specs disagree on a window size.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be FullAttentionSpec."
        )

        # None means "not used" and is ignored when merging window sizes.
        sliding_windows = {
            spec.sliding_window for spec in specs if spec.sliding_window is not None
        }
        chunk_sizes = {
            spec.attention_chunk_size
            for spec in specs
            if spec.attention_chunk_size is not None
        }
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
        )
        merged_spec = cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            dtype=specs[0].dtype,
            sliding_window=cls.merge_window_sizes(sliding_windows),
            attention_chunk_size=cls.merge_window_sizes(chunk_sizes),
        )
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
        assert (merged_spec.sliding_window is not None) + (
            merged_spec.attention_chunk_size is not None
        ) <= 1, (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported."
        )
        return merged_spec


@dataclass(frozen=True)
class MLAAttentionSpec(FullAttentionSpec):
    """
    KV cache spec for MLA attention layers.

    The page size omits the factor of 2 used by AttentionSpec, and the
    "fp8_ds_mla" cache format uses a fixed per-token size.
    """

    # TODO(Lucas/Chen): less hacky way to do this
    cache_dtype_str: str | None = None

    @property
    def page_size_bytes(self) -> int:
        if self.cache_dtype_str != "fp8_ds_mla":
            return (
                self.block_size
                * self.num_kv_heads
                * self.head_size
                * get_dtype_size(self.dtype)
            )
        # "fp8_ds_mla" stores 656 bytes per token. See
        # `vllm/v1/attention/backends/mla/flashmla_sparse.py` for details.
        return self.block_size * 656

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge MLAAttentionSpec objects that share one `cache_dtype_str`.

        All fields other than `cache_dtype_str` are taken from the first spec.

        Raises:
            AssertionError: if a spec is not an MLAAttentionSpec or the specs
                use different `cache_dtype_str` values.
        """
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be MLAAttentionSpec."
        )
        dtype_strs = {spec.cache_dtype_str for spec in specs}
        assert len(dtype_strs) == 1, (
            "All attention layers in the same KV cache group must use the same "
            "quantization method."
        )
        first = specs[0]
        return cls(
            block_size=first.block_size,
            num_kv_heads=first.num_kv_heads,
            head_size=first.head_size,
            dtype=first.dtype,
            cache_dtype_str=dtype_strs.pop(),
        )


@dataclass(frozen=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
    """KV cache spec for chunked local attention layers."""

    attention_chunk_size: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Return the worst-case KV cache bytes for this layer.

        With chunked prefill, KV cache is allocated for at most
        `attention_chunk_size` computed tokens plus the newly scheduled
        tokens, and never for more than `max_model_len` tokens.
        """
        scheduled_limit = vllm_config.scheduler_config.max_num_batched_tokens
        model_len_limit = vllm_config.model_config.max_model_len
        tokens_needed = min(
            self.attention_chunk_size + scheduled_limit, model_len_limit
        )
        num_blocks = cdiv(tokens_needed, self.block_size)
        return num_blocks * self.page_size_bytes


@dataclass(frozen=True)
class SlidingWindowSpec(AttentionSpec):
    """KV cache spec for sliding window attention layers."""

    sliding_window: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Return the worst-case KV cache bytes for this layer.

        Raises:
            AssertionError: if `decode_context_parallel_size` is not 1.
        """
        assert vllm_config.parallel_config.decode_context_parallel_size == 1, (
            "DCP not support sliding window."
        )
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window-1` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(
            self.sliding_window - 1 + max_num_batched_tokens, max_model_len
        )

        # +1 here because the stored tokens may not start at a block boundary,
        # so they can spill into one extra block. For example, with a block
        # size of 4, the 4-token window [CDEF] starting mid-block occupies two
        # blocks: [XXCD] [EF].
        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes


@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    """
    KV cache spec for Mamba-style state layers.

    A page stores one tensor per entry of `shapes`, whose element type is the
    entry at the same position in `dtypes`.
    """

    # shape of each state tensor stored in a page
    shapes: tuple[tuple[int, ...], ...]
    # dtype of each state tensor; must have the same length as `shapes`
    dtypes: tuple[torch.dtype, ...]
    # when set, used as the page size (must be >= the computed size)
    page_size_padded: int | None = None
    mamba_type: str = "mamba2"
    num_speculative_blocks: int = 0

    @property
    def page_size_bytes(self) -> int:
        """
        Return the bytes of one page: the summed size of all state tensors,
        or `page_size_padded` when it is set.

        Raises:
            ValueError: if `shapes` and `dtypes` have different lengths.
            AssertionError: if `page_size_padded` is smaller than the
                computed page size.
        """
        # strict=True: a length mismatch would otherwise silently drop the
        # trailing tensors and under-count the page size.
        page_size = sum(
            prod(shape) * get_dtype_size(dtype)
            for shape, dtype in zip(self.shapes, self.dtypes, strict=True)
        )
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Return the bytes of `cdiv(max_model_len, block_size)` pages."""
        max_model_len = vllm_config.model_config.max_model_len
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes


@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):
    """KV cache spec for encoder-only attention layers (no KV cache memory)."""

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Encoder-only layers keep no KV cache, so this is always 0 bytes."""
        return 0


@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
    """
    KV cache spec for the cross-attention layers of encoder-decoder models.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Return the bytes needed to cache the encoder states.

        The token budget is the scheduler's `max_num_encoder_input_tokens`
        (e.g., 1500 for Whisper), rounded up to whole blocks.
        """
        encoder_tokens = vllm_config.scheduler_config.max_num_encoder_input_tokens
        num_blocks = cdiv(encoder_tokens, self.block_size)
        return num_blocks * self.page_size_bytes


@dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec):
    """
    A KV cache spec for multiple layers with the same type of attention. Here,
    "same type" means the layers always need the same number of token slots.
    For example, sliding window attention layers with different window sizes
    are not the same type and should not be merged into one
    UniformTypeKVCacheSpecs.
    """

    # per-layer specs (keys are presumably layer names -- TODO confirm)
    kv_cache_specs: dict[str, KVCacheSpec]

    @property
    def page_size_bytes(self) -> int:
        """Sum of the page sizes of all contained specs."""
        return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values())

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Return the largest page count needed by any contained spec, multiplied
        by the combined page size of all specs.
        """
        max_num_pages = max(
            cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes)
            for spec in self.kv_cache_specs.values()
        )
        return max_num_pages * self.page_size_bytes

    @classmethod
    def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
        """
        Whether all layers have the same type of KV cache spec.

        Returns False for an empty mapping and when block sizes differ.

        Raises:
            NotImplementedError: if the type of the first spec has no branch
                below.
        """
        if not kv_cache_specs:
            # Nothing to compare; also avoids leaking StopIteration from next().
            return False
        block_sizes = {spec.block_size for spec in kv_cache_specs.values()}
        if len(block_sizes) > 1:
            # Different block sizes, not uniform.
            return False
        one_spec = next(iter(kv_cache_specs.values()))
        if isinstance(one_spec, FullAttentionSpec):
            return all(
                isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, CrossAttentionSpec):
            return all(
                isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, SlidingWindowSpec):
            return all(
                isinstance(spec, SlidingWindowSpec)
                and spec.sliding_window == one_spec.sliding_window
                for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, ChunkedLocalAttentionSpec):
            return all(
                isinstance(spec, ChunkedLocalAttentionSpec)
                and spec.attention_chunk_size == one_spec.attention_chunk_size
                for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, MambaSpec):
            return all(
                isinstance(spec, MambaSpec)
                and spec.num_speculative_blocks == one_spec.num_speculative_blocks
                for spec in kv_cache_specs.values()
            )
        else:
            # NOTE(Chen): Please add new branches for new KV cache spec types.
            raise NotImplementedError(
                f"Unsupported KV cache spec type: {type(one_spec)}"
            )

    @classmethod
    def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None:
        """
        Return a UniformTypeKVCacheSpecs object if all layers have the same
        type of KV cache spec. Return None if not, including when
        `kv_cache_specs` is empty.
        """
        if not cls.is_uniform_type(kv_cache_specs):
            return None
        block_size = next(iter(kv_cache_specs.values())).block_size
        return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)


@dataclass
class KVCacheTensor:
    """
    Tells the workers how to initialize one KV cache tensor and which layers
    share it.
    """

    # size of the KV cache tensor in bytes
    size: int
    # names of the layers that share this KV cache tensor
    shared_by: list[str]


@dataclass
class KVCacheGroupSpec:
    """
    A group of model layers that share one KV cache block table. The KV cache
    manager treats the whole group as a single layer.
    """

    # names of the model layers belonging to this group
    layer_names: list[str]
    # KV cache spec of this manager-level layer
    kv_cache_spec: KVCacheSpec


@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.

    For models with only one type of attention, `kv_cache_groups` holds a
    single group that contains all layers. For models with multiple types of
    attention, there are multiple groups; see
    `_get_kv_cache_config_uniform_page_size` for more details.
    """

    # number of KV cache blocks
    num_blocks: int
    # how the model runner should initialize the KV cache tensors for each layer
    kv_cache_tensors: list[KVCacheTensor]
    # the KV cache groups of the model
    kv_cache_groups: list[KVCacheGroupSpec]