kv_cache_interface.py 18.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import copy
5
from dataclasses import dataclass, fields, replace
Chen Zhang's avatar
Chen Zhang committed
6
from math import prod
7
8

import torch
9
from typing_extensions import Self
10

11
from vllm.config import VllmConfig
12
from vllm.logger import init_logger
13
from vllm.utils.math_utils import cdiv
14
from vllm.utils.torch_utils import get_dtype_size
15
16
17
18

# Module-level logger, named after this module via __name__.
logger = init_logger(__name__)


19
@dataclass(frozen=True)
class KVCacheSpec:
    """
    Base description of how a single layer lays out its KV cache.

    Subclasses provide the page size and the worst-case memory footprint.
    """

    # number of tokens in a block
    block_size: int

    @property
    def page_size_bytes(self) -> int:
        """
        Size in bytes of one page holding `block_size` tokens.

        Must be implemented by subclasses.

        Returns:
            The page size in bytes.
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Upper bound on the memory this KV cache may need, in bytes.

        Must be implemented by subclasses.

        Returns:
            The KV cache size in bytes.
        """
        raise NotImplementedError

    def copy_with_new_block_size(self, block_size: int) -> Self:
        """
        Return a copy of this spec with `block_size` replaced and every other
        field left unchanged.
        """
        return replace(self, block_size=block_size)

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Collapse a list of identical KVCacheSpec objects into one.

        Returns a deep copy of the first spec.

        Raises:
            AssertionError: if any spec differs from the first one.
        """
        reference = specs[0]
        assert all(other == reference for other in specs[1:]), (
            "All layers in the same KV cache group must be the same."
        )
        return copy.deepcopy(reference)

63

64
@dataclass(frozen=True, kw_only=True)
class AttentionSpec(KVCacheSpec):
    """
    KV cache spec shared by attention layers: `num_kv_heads` heads of
    `head_size` elements of `dtype`, optionally padded to a fixed page size.
    """

    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    # When set, overrides the computed page size; must be >= the real size.
    page_size_padded: int | None = None

    @property
    def page_size_bytes(self) -> int:
        """Page size in bytes, using `page_size_padded` when it is set."""
        unpadded = self.real_page_size_bytes
        if self.page_size_padded is None:
            return unpadded
        assert self.page_size_padded >= unpadded
        return self.page_size_padded

    @property
    def real_page_size_bytes(self) -> int:
        """
        Unpadded page size in bytes. The factor 2 accounts for keys and values
        each using `head_size` elements per KV head per token.
        """
        bytes_per_token = (
            2 * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)
        )
        return self.block_size * bytes_per_token
88

89

90
@dataclass(frozen=True, kw_only=True)
class FullAttentionSpec(AttentionSpec):
    """
    KV cache spec for full attention layers (blocks are allocated for every
    token).

    When the hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, the sliding window
    layers are regarded as full attention by the KV cache manager (blocks are
    allocated for all tokens) while still being computed as sliding window
    attention in the model runner. In this case a FullAttentionSpec is used
    and the sliding window size is recorded in `sliding_window`.
    """

    # Head size of the V cache. None means "same as head_size"; the value is
    # filled in by __post_init__.
    head_size_v: int | None = None

    sliding_window: int | None = None
    """
    Default to None for not using sliding window attention.
    """
    attention_chunk_size: int | None = None

    def __post_init__(self):
        # The dataclass is frozen, so the default must be written through
        # object.__setattr__ instead of normal attribute assignment.
        if self.head_size_v is None:
            object.__setattr__(self, "head_size_v", self.head_size)

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Upper bound in bytes: enough blocks for `max_model_len` tokens, reduced
        to each rank's share when context parallelism is enabled.
        """
        max_model_len = vllm_config.model_config.max_model_len
        dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
        pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
        cp_world_size = dcp_world_size * pcp_world_size
        # Note(hc): each context-parallel rank only needs to store
        # cdiv(max_model_len, dcp_world_size * pcp_world_size) tokens locally.
        if cp_world_size > 1:
            max_model_len = cdiv(max_model_len, cp_world_size)
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
        """
        Reduce the window sizes collected from a group of layers to one value.

        Returns None for an empty set, or the single element otherwise. The
        input set is not modified.

        Raises:
            ValueError: if the set holds more than one distinct window size.
        """
        if not window_sizes:
            return None
        if len(window_sizes) > 1:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size."
            )
        # Read the only element without mutating the caller's set.
        return next(iter(window_sizes))

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.

        Raises:
            AssertionError: if the specs are not all FullAttentionSpec, include
                an MLAAttentionSpec, disagree on any attention field (including
                `head_size_v`), or mix sliding window and chunked local
                attention.
            ValueError: if the layers disagree on the sliding window or the
                attention chunk size.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be FullAttentionSpec."
        )

        sliding_window = {
            spec.sliding_window for spec in specs if spec.sliding_window is not None
        }
        attention_chunk_size = {
            spec.attention_chunk_size
            for spec in specs
            if spec.attention_chunk_size is not None
        }
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
        )
        first = specs[0]
        merged_spec = cls(
            block_size=first.block_size,
            num_kv_heads=first.num_kv_heads,
            head_size=first.head_size,
            head_size_v=first.head_size_v,
            dtype=first.dtype,
            page_size_padded=first.page_size_padded,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        # Every layer must agree on all AttentionSpec fields. head_size_v is
        # declared on this class rather than on AttentionSpec, but it also
        # feeds real_page_size_bytes, so it must match as well.
        compared_fields = [f.name for f in fields(AttentionSpec)] + ["head_size_v"]
        for spec in specs:
            for name in compared_fields:
                assert getattr(spec, name) == getattr(merged_spec, name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
        assert (merged_spec.sliding_window is not None) + (
            merged_spec.attention_chunk_size is not None
        ) <= 1, (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported."
        )
        return merged_spec

    @property
    def real_page_size_bytes(self) -> int:
        """
        Unpadded page size in bytes: K entries of `head_size` plus V entries of
        `head_size_v` elements, per KV head, per token.
        """
        return (
            self.block_size
            * self.num_kv_heads
            * (self.head_size + self.head_size_v)
            * get_dtype_size(self.dtype)
        )

189

190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
@dataclass(frozen=True, kw_only=True)
class TQFullAttentionSpec(FullAttentionSpec):
    """FullAttentionSpec whose page size comes from a TQ slot size.

    Python equivalent of the C++ TQ4FullAttentionSpec. When `tq_slot_size` is
    positive, each token of each KV head takes `tq_slot_size` bytes of the
    page, instead of the head_size * dtype based formula. A non-positive value
    falls back to the FullAttentionSpec page size.

    NOTE(review): unlike the base formula, no separate K/V term appears here;
    presumably one slot covers both K and V. Confirm against the kernel.
    """

    tq_slot_size: int = 0

    @property
    def real_page_size_bytes(self) -> int:
        if self.tq_slot_size <= 0:
            return super().real_page_size_bytes
        return self.block_size * self.num_kv_heads * self.tq_slot_size

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge TQ specs with the FullAttentionSpec rules, and additionally
        require a single shared `tq_slot_size`.
        """
        merged = super().merge(specs)
        slot_size = specs[0].tq_slot_size
        assert all(spec.tq_slot_size == slot_size for spec in specs), (
            "All TQ layers in the same KV cache group must use the same tq_slot_size."
        )
        return replace(merged, tq_slot_size=slot_size)


216
@dataclass(frozen=True, kw_only=True)
class MLAAttentionSpec(FullAttentionSpec):
    """
    KV cache spec for MLA layers.

    Unlike AttentionSpec, the page size has no factor of 2 and ignores
    `head_size_v`: each token stores one `head_size`-wide entry per KV head,
    unless a special `cache_dtype_str` layout applies.
    """

    # TODO(Lucas/Chen): less hacky way to do this
    cache_dtype_str: str | None = None

    @property
    def real_page_size_bytes(self) -> int:
        if self.cache_dtype_str != "fp8_ds_mla":
            bytes_per_token = (
                self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)
            )
            return self.block_size * bytes_per_token
        # The fp8_ds_mla layout uses a fixed 656 bytes per token. See
        # `vllm/v1/attention/backends/mla/flashmla_sparse.py` for details.
        return self.block_size * 656

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge MLA specs into one. All specs must be MLAAttentionSpec and use
        the same `cache_dtype_str`. The remaining fields are taken from the
        first spec.
        """
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be MLAAttentionSpec."
        )
        dtype_strs = {spec.cache_dtype_str for spec in specs}
        assert len(dtype_strs) == 1, (
            "All attention layers in the same KV cache group must use the same "
            "quantization method."
        )
        first = specs[0]
        return cls(
            block_size=first.block_size,
            num_kv_heads=first.num_kv_heads,
            head_size=first.head_size,
            dtype=first.dtype,
            page_size_padded=first.page_size_padded,
            cache_dtype_str=dtype_strs.pop(),
        )


254
@dataclass(frozen=True, kw_only=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
    """KV cache spec for chunked local attention layers."""

    attention_chunk_size: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Upper bound in bytes for one chunked local attention layer."""
        model_len_cap = vllm_config.model_config.max_model_len
        scheduled_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # During chunked prefill, KV cache is allocated for at most
        # `attention_chunk_size` computed tokens plus the newly scheduled
        # tokens, and never for more than `max_model_len` tokens.
        chunk_tokens = min(
            self.attention_chunk_size + scheduled_tokens, model_len_cap
        )
        return cdiv(chunk_tokens, self.block_size) * self.page_size_bytes

272

273
@dataclass(frozen=True, kw_only=True)
class SlidingWindowSpec(AttentionSpec):
    """KV cache spec for sliding window attention layers."""

    sliding_window: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Upper bound in bytes for one sliding window layer.

        Raises:
            AssertionError: if decode context parallelism is enabled.
        """
        assert vllm_config.parallel_config.decode_context_parallel_size == 1, (
            "DCP not support sliding window."
        )
        model_len_cap = vllm_config.model_config.max_model_len
        scheduled_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # During chunked prefill, KV cache is kept for the last
        # `sliding_window - 1` computed tokens plus the newly scheduled tokens,
        # and never for more than `max_model_len` tokens.
        window_tokens = min(
            self.sliding_window - 1 + scheduled_tokens, model_len_cap
        )

        # One extra block, because the window need not start at a block
        # boundary. Example: block_size=4 and a 4-token window [CDEF] over the
        # sequence ABCDEF. The window spans two blocks, [XXCD] and [EF], even
        # though cdiv(4, 4) == 1.
        num_blocks = cdiv(window_tokens, self.block_size) + 1
        return num_blocks * self.page_size_bytes
297
298


299
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    """
    KV cache spec for Mamba layers.

    A page holds one state tensor per entry of `shapes`; the element type of
    each tensor is the matching entry of `dtypes`.
    """

    # Shape of each state tensor stored in one page.
    shapes: tuple[tuple[int, ...], ...]
    # Element dtype of each state tensor, paired positionally with `shapes`.
    dtypes: tuple[torch.dtype, ...]
    # When set, overrides the computed page size; must be >= the real size.
    page_size_padded: int | None = None
    mamba_type: str = "mamba2"
    mamba_cache_mode: str = "none"
    num_speculative_blocks: int = 0

    @property
    def page_size_bytes(self) -> int:
        """
        Page size in bytes: the summed byte size of all state tensors, or
        `page_size_padded` when it is set.
        """
        page_size = sum(
            prod(shape) * get_dtype_size(dtype)
            for shape, dtype in zip(self.shapes, self.dtypes)
        )
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Upper bound in bytes, depending on the configured mamba cache mode:
          - "all": one page per block of `max_model_len` tokens;
          - "align": two pages plus one per speculative block;
          - anything else: one page plus one per speculative block.

        NOTE(review): the mode is read from
        `vllm_config.cache_config.mamba_cache_mode`, not from
        `self.mamba_cache_mode`; confirm the two always agree.
        """
        cache_mode = vllm_config.cache_config.mamba_cache_mode
        if cache_mode == "all":
            max_model_len = vllm_config.model_config.max_model_len
            return cdiv(max_model_len, self.block_size) * self.page_size_bytes
        elif cache_mode == "align":
            return self.page_size_bytes * (2 + self.num_speculative_blocks)
        else:
            return self.page_size_bytes * (1 + self.num_speculative_blocks)
Chen Zhang's avatar
Chen Zhang committed
327
328


329
330
331
332
333
334
335
@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):
    """Attention spec for encoder-only layers; it reserves no KV cache memory."""

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # Encoder-only layers do not need KV cache, so nothing is reserved.
        no_kv_cache_bytes = 0
        return no_kv_cache_bytes


336
337
338
339
340
341
342
343
344
@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
    """
    KV cache spec for cross-attention layers in encoder-decoder models.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # Cross-attention caches the encoder states, so the bound comes from
        # the maximum encoder input length (e.g. 1500 for Whisper), not from
        # max_model_len.
        encoder_tokens = vllm_config.scheduler_config.max_num_encoder_input_tokens
        num_blocks = cdiv(encoder_tokens, self.block_size)
        return num_blocks * self.page_size_bytes


349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
@dataclass(frozen=True)
class SinkFullAttentionSpec(FullAttentionSpec):
    """FullAttentionSpec that additionally records a sink length."""

    sink_len: int | None = None

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of SinkFullAttentionSpec objects into a single
        SinkFullAttentionSpec object.

        All FullAttentionSpec.merge checks apply. In addition, every spec must
        share the same `sink_len`, which is carried over to the merged spec.

        Raises:
            AssertionError: if any FullAttentionSpec.merge check fails or the
                specs disagree on `sink_len`.
            ValueError: if the layers disagree on the sliding window or the
                attention chunk size.
        """
        # Reuse the parent's validation and merging (same pattern as
        # TQFullAttentionSpec) instead of duplicating it here.
        merged = super().merge(specs)
        sink_len = specs[0].sink_len
        # Previously the first layer's sink_len was taken without checking
        # the others; mixed values would silently produce a wrong spec.
        assert all(getattr(spec, "sink_len", None) == sink_len for spec in specs), (
            "All sink attention layers in the same KV cache group must use the "
            "same sink_len."
        )
        return replace(merged, sink_len=sink_len)


400
401
402
403
404
405
406
407
@dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec):
    """
    KV cache spec covering several layers that share one kind of attention.
    "One kind" means the layers always need the same number of token slots.
    For example, sliding window layers with different window sizes are
    different kinds and must not be combined into one UniformTypeKVCacheSpecs.
    """

    # Per-layer specs (keys are presumably layer names; verify with callers).
    kv_cache_specs: dict[str, KVCacheSpec]

    @property
    def page_size_bytes(self) -> int:
        """Combined page size of all contained specs, in bytes."""
        total = 0
        for layer_spec in self.kv_cache_specs.values():
            total += layer_spec.page_size_bytes
        return total

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Largest per-spec page count (each spec's own maximum usage divided by
        its page size, rounded up) multiplied by the combined page size.
        """
        per_spec_pages = [
            cdiv(
                layer_spec.max_memory_usage_bytes(vllm_config),
                layer_spec.page_size_bytes,
            )
            for layer_spec in self.kv_cache_specs.values()
        ]
        return max(per_spec_pages) * self.page_size_bytes

    @classmethod
    def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
        """
        Whether all specs share one block size and one kind of KV cache.

        The kind is decided by the first spec. Sliding window and chunked local
        specs must also match its window / chunk size; Mamba specs must match
        its `num_speculative_blocks`.

        Raises:
            NotImplementedError: if the first spec's type has no branch here.
        """
        all_specs = list(kv_cache_specs.values())
        if len({spec.block_size for spec in all_specs}) > 1:
            # Mixed block sizes are never uniform.
            return False
        first = next(iter(all_specs))

        if isinstance(first, FullAttentionSpec):
            return all(isinstance(spec, FullAttentionSpec) for spec in all_specs)
        if isinstance(first, CrossAttentionSpec):
            return all(isinstance(spec, CrossAttentionSpec) for spec in all_specs)
        if isinstance(first, SlidingWindowSpec):
            return all(
                isinstance(spec, SlidingWindowSpec)
                and spec.sliding_window == first.sliding_window
                for spec in all_specs
            )
        if isinstance(first, ChunkedLocalAttentionSpec):
            return all(
                isinstance(spec, ChunkedLocalAttentionSpec)
                and spec.attention_chunk_size == first.attention_chunk_size
                for spec in all_specs
            )
        if isinstance(first, MambaSpec):
            return all(
                isinstance(spec, MambaSpec)
                and spec.num_speculative_blocks == first.num_speculative_blocks
                for spec in all_specs
            )
        # NOTE(Chen): Please add new branches for new KV cache spec types.
        raise NotImplementedError(f"Unsupported KV cache spec type: {type(first)}")

    @classmethod
    def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None:
        """
        Build a UniformTypeKVCacheSpecs when all specs are of the same type
        (see `is_uniform_type`); otherwise return None.
        """
        if not cls.is_uniform_type(kv_cache_specs):
            return None
        first = next(iter(kv_cache_specs.values()))
        return cls(block_size=first.block_size, kv_cache_specs=kv_cache_specs)


477
478
479
@dataclass
class KVCacheTensor:
    """
    Tells the workers how to initialize one KV cache tensor and which layers
    share it.
    """

    # Size of the KV cache tensor in bytes.
    size: int
    # Names of the layers that share this KV cache tensor.
    shared_by: list[str]
485
486


487
488
489
490
491
492
@dataclass
class KVCacheGroupSpec:
    """
    Model layers that share one KV cache block table. The KV cache manager
    treats the whole group as a single layer.
    """

    # Names of the model layers belonging to this group.
    layer_names: list[str]
    # KV cache spec that the KV cache manager applies to the whole group.
    kv_cache_spec: KVCacheSpec


500
501
502
503
504
@dataclass
class KVCacheConfig:
    """
    Model-wide KV cache configuration: the number of blocks, the tensors the
    workers initialize, and how the layers are grouped.
    """

    num_blocks: int
    """The number of KV cache blocks"""
    kv_cache_tensors: list[KVCacheTensor]
    """How the model runner should initialize the KV cache tensors for each layer"""
    kv_cache_groups: list[KVCacheGroupSpec]
    """
    The KV cache groups of the model.

    A model with only one type of attention has a single group containing all
    layers. A model with several types of attention has multiple groups; see
    `_get_kv_cache_config_uniform_page_size` for details.
    """